import os
import sys

sys.path.insert(1, os.path.join(sys.path[0], "../utils"))
import numpy as np
import argparse
import h5py
import math
import time
import logging
import matplotlib.pyplot as plt

import torch

torch.backends.cudnn.benchmark = True
torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data

from utilities import get_filename
from models import *
import config


class Transfer_Cnn14(nn.Module):
    """Classifier for a new task that reuses a pretrained Cnn14 as a sub module.

    The wrapped ``Cnn14`` is built with the AudioSet class count (527) so that
    an AudioSet checkpoint can be loaded into it unchanged; a new linear layer
    ``fc_transfer`` maps the 2048-dimensional embedding to ``classes_num``.
    """

    def __init__(
        self,
        sample_rate,
        window_size,
        hop_size,
        mel_bins,
        fmin,
        fmax,
        classes_num,
        freeze_base,
    ):
        """Build the base network and the transfer layer.

        Args:
            sample_rate, window_size, hop_size, mel_bins, fmin, fmax:
                Forwarded unchanged to the ``Cnn14`` constructor.
            classes_num: Number of output classes of the new task.
            freeze_base: If truthy, every parameter of the base network gets
                ``requires_grad = False`` so only ``fc_transfer`` is trained.
        """
        super().__init__()
        audioset_classes_num = 527

        self.base = Cnn14(
            sample_rate,
            window_size,
            hop_size,
            mel_bins,
            fmin,
            fmax,
            audioset_classes_num,
        )

        # Transfer to another task layer
        self.fc_transfer = nn.Linear(2048, classes_num, bias=True)

        if freeze_base:
            # Freeze AudioSet pretrained layers
            for param in self.base.parameters():
                param.requires_grad = False

        self.init_weights()

    def init_weights(self):
        """Initialise the transfer layer with the project's ``init_layer``."""
        init_layer(self.fc_transfer)

    def load_from_pretrain(self, pretrained_checkpoint_path):
        """Load pretrained weights into the base network.

        Args:
            pretrained_checkpoint_path: Path to a checkpoint whose ``"model"``
                entry is a state dict compatible with ``self.base``.
        """
        # Deserialize onto the CPU so a checkpoint saved from CUDA tensors can
        # still be loaded on a CPU-only host; ``load_state_dict`` copies the
        # values into the existing parameters, and the caller moves the model
        # to its target device afterwards.
        checkpoint = torch.load(pretrained_checkpoint_path, map_location="cpu")
        self.base.load_state_dict(checkpoint["model"])

    def forward(self, input, mixup_lambda=None):
        """Run the base network and classify its embedding.

        Input: (batch_size, data_length)

        Returns:
            The dict produced by the base network, with ``"clipwise_output"``
            replaced by the log-softmax (over the last dimension) of the
            transfer layer applied to ``"embedding"``.
        """
        output_dict = self.base(input, mixup_lambda)
        embedding = output_dict["embedding"]

        # Log-probabilities over the new task's classes.
        clipwise_output = torch.log_softmax(self.fc_transfer(embedding), dim=-1)
        output_dict["clipwise_output"] = clipwise_output

        return output_dict


def train(args):
    """Build the requested model and optionally load pretrained weights.

    Args:
        args: Parsed command-line namespace providing ``sample_rate``,
            ``window_size``, ``hop_size``, ``mel_bins``, ``fmin``, ``fmax``,
            ``model_type``, ``pretrained_checkpoint_path``, ``freeze_base``
            and ``cuda``.

    The model is wrapped in ``torch.nn.DataParallel`` and moved to the GPU
    when ``--cuda`` was requested and CUDA is available.
    """

    # Arguments & parameters
    sample_rate = args.sample_rate
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    pretrained_checkpoint_path = args.pretrained_checkpoint_path
    freeze_base = args.freeze_base
    device = "cuda" if (args.cuda and torch.cuda.is_available()) else "cpu"

    classes_num = config.classes_num
    # A missing or empty checkpoint path means "train from scratch".
    pretrain = bool(pretrained_checkpoint_path)

    # Model
    # NOTE(security): eval() executes whatever string is passed as
    # --model_type; only run this script with trusted command-line input.
    Model = eval(model_type)
    model = Model(
        sample_rate,
        window_size,
        hop_size,
        mel_bins,
        fmin,
        fmax,
        classes_num,
        freeze_base,
    )

    # Load pretrained model
    if pretrain:
        logging.info("Load pretrained model from {}".format(pretrained_checkpoint_path))
        model.load_from_pretrain(pretrained_checkpoint_path)

    # Parallel
    print("GPU number: {}".format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if "cuda" in device:
        model.to(device)

    # Only report success when a checkpoint was actually loaded; previously
    # this message was printed even when no checkpoint path was given.
    if pretrain:
        print("Load pretrained model successfully!")


if __name__ == "__main__":
    # Command-line entry point: a single "train" sub-command.
    parser = argparse.ArgumentParser(description="Example of parser. ")
    subparsers = parser.add_subparsers(dest="mode")

    # Train
    parser_train = subparsers.add_parser("train")

    # Mandatory integer options, registered in their original order.
    for required_int_flag in (
        "--sample_rate",
        "--window_size",
        "--hop_size",
        "--mel_bins",
        "--fmin",
        "--fmax",
    ):
        parser_train.add_argument(required_int_flag, type=int, required=True)

    parser_train.add_argument("--model_type", type=str, required=True)
    parser_train.add_argument("--pretrained_checkpoint_path", type=str)

    # Boolean switches; both are off unless given on the command line.
    for switch_flag in ("--freeze_base", "--cuda"):
        parser_train.add_argument(switch_flag, action="store_true", default=False)

    # Parse arguments
    args = parser.parse_args()
    args.filename = get_filename(__file__)

    # Anything other than the "train" sub-command (including none) is an error.
    if args.mode != "train":
        raise Exception("Error argument!")

    train(args)
