import argparse
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import logging
from datetime import datetime
import matplotlib.pyplot as plt
from dataset.CremadDataset import CremadDataset
from dataset.AVEDataset import AVEDataset
from dataset.KSDataset import KSDataset
from models.basic_model import AVClassifier
from utils.utils import setup_seed, weight_init

def _str2bool(value):
    """Convert a command-line token such as 'true'/'False'/'1'/'no' to a bool.

    argparse's ``type=bool`` simply calls ``bool(token)``, which is True for
    every non-empty string (including 'False'); this converter parses the text.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value))


def get_arguments():
    """Build the command-line parser and return the parsed arguments.

    Arguments are read from ``sys.argv``. ``--ckpt_path`` is the only required
    option; every other option has a default.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='CREMAD', type=str,
                        help='KineticSound, CREMAD, AVE')
    parser.add_argument('--fusion_method', default='concat', type=str,
                        choices=['sum', 'concat', 'film'])
    parser.add_argument('--fps', default=1, type=int)
    parser.add_argument('--use_video_frames', default=3, type=int)
    parser.add_argument('--audio_path', default='/', type=str)
    parser.add_argument('--visual_path', default='/', type=str)
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--optimizer', default='sgd', type=str, choices=['sgd', 'adam'])
    parser.add_argument('--learning_rate', default=0.001, type=float, help='initial learning rate')
    parser.add_argument('--lr_decay_step', default=70, type=int, help='where learning rate decays')
    parser.add_argument('--lr_decay_ratio', default=0.1, type=float, help='decay coefficient')
    parser.add_argument('--ckpt_path', required=True, type=str, help='path to save trained models')
    parser.add_argument('--train', action='store_true', help='turn on train mode')
    # nargs='?' + const=True lets the option be used both as a bare flag
    # (`--use_tensorboard`) and with an explicit value (`--use_tensorboard False`).
    parser.add_argument('--use_tensorboard', default=False, type=_str2bool, nargs='?', const=True,
                        help='whether to visualize')
    parser.add_argument('--tensorboard_path', type=str, help='path to save tensorboard logs')
    parser.add_argument('--random_seed', default=0, type=int)
    parser.add_argument('--gpu_ids', default='0, 1', type=str, help='GPU ids')
    parser.add_argument('--drop_mode', default='spatial', choices=['spatial', 'channel'])
    parser.add_argument('--drop_init', default=0.25, type=float)
    parser.add_argument('--p_max', default=0.3, type=float)
    parser.add_argument('--gradcam_drop', action='store_true')
    return parser.parse_args()

def train_epoch(args, epoch, model, device, dataloader, optimizer, scheduler, writer=None):
    """Run one training epoch and return the mean losses.

    For every batch the fused logits are trained with cross-entropy. Per-modality
    logits are additionally derived from the fusion layer's weights; they are used
    only for monitoring and, when ``args.gradcam_drop`` is set, to update the drop
    probabilities of ``drop_v`` / ``drop_a`` on the underlying model.

    Args:
        args: namespace providing ``fusion_method``, ``gradcam_drop`` and, when
            ``gradcam_drop`` is set, ``p_max`` and ``use_tensorboard``.
        epoch: zero-based epoch index, used for the TensorBoard step counter.
        model: the network, either bare or wrapped in ``torch.nn.DataParallel``.
            Its forward must return ``(a_feat, v_feat, logits)``.
        device: device the batch tensors are moved to.
        dataloader: sized iterable yielding ``(spec, image, label)`` batches.
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler stepped once at the end of the epoch.
        writer: optional TensorBoard writer; scalars are logged only when it is
            not None and ``args.use_tensorboard`` is truthy.

    Returns:
        tuple[float, float, float]: mean fused, audio-branch and visual-branch loss
        over the epoch (all 0.0 if the loader is empty).
    """
    criterion = nn.CrossEntropyLoss()
    softmax = nn.Softmax(dim=1)
    # Sub-modules live on the wrapped network when DataParallel is used.
    net = model.module if isinstance(model, torch.nn.DataParallel) else model

    model.train()
    print("Start training ... ")
    loss_sum, loss_a_sum, loss_v_sum = 0., 0., 0.
    for step, (spec, image, label) in enumerate(dataloader):
        spec = spec.to(device)
        image = image.to(device)
        label = label.to(device)

        optimizer.zero_grad()

        a_feat, v_feat, logits = model(
            spec.unsqueeze(1).float(),
            image.float(),
            apply_drop=args.gradcam_drop
        )

        loss = criterion(logits, label)
        loss.backward()

        # The unimodal quantities below never receive gradients (only `loss` is
        # back-propagated), so compute them without building an autograd graph.
        with torch.no_grad():
            if args.fusion_method == 'sum':
                fusion = net.fusion_module
                out_v = torch.mm(v_feat, fusion.fc_y.weight.t()) + fusion.fc_y.bias
                out_a = torch.mm(a_feat, fusion.fc_x.weight.t()) + fusion.fc_x.bias
            elif args.fusion_method == 'concat':
                W = net.fusion_module.fc_out.weight
                b = net.fusion_module.fc_out.bias
                # First half of the input columns is applied to a_feat, second
                # half to v_feat; the bias is shared equally between the two.
                half = W.size(1) // 2
                out_v = torch.mm(v_feat, W[:, half:].t()) + b / 2
                out_a = torch.mm(a_feat, W[:, :half].t()) + b / 2
            elif args.fusion_method in ('tbn', 'psp'):
                out_a = net.audio_fc(a_feat)
                out_v = net.visual_fc(v_feat)
            else:
                fc = net.fusion_module.fc_out
                out_v = fc(v_feat)
                out_a = fc(a_feat)

            loss_v = criterion(out_v, label)
            loss_a = criterion(out_a, label)

            if args.gradcam_drop:
                # Summed probability assigned to the true class by each branch.
                target = label.long().view(-1, 1)
                tiny = 1e-8  # keeps the ratio finite if a score underflows to 0
                score_v = softmax(out_v).gather(1, target).sum().clamp_min(tiny)
                score_a = softmax(out_a).gather(1, target).sum().clamp_min(tiny)
                ratio_v = score_v / score_a
                ratio_a = 1 / ratio_v

                p_max = args.p_max
                # sigmoid(log(r)) == r / (1 + r), hence p_v + p_a == p_max.
                p_v = float(torch.sigmoid(torch.log(ratio_v))) * p_max
                p_a = float(torch.sigmoid(torch.log(ratio_a))) * p_max

                net.drop_v.update_p(p_v)
                net.drop_a.update_p(p_a)

                if args.use_tensorboard and writer is not None:
                    it = epoch * len(dataloader) + step
                    writer.add_scalar('drop_p_v', p_v, it)
                    writer.add_scalar('drop_p_a', p_a, it)

        optimizer.step()

        loss_sum += loss.item()
        loss_a_sum += loss_a.item()
        loss_v_sum += loss_v.item()

    scheduler.step()
    # Avoid ZeroDivisionError for an empty loader; the sums are 0.0 in that case.
    N = max(len(dataloader), 1)
    return loss_sum / N, loss_a_sum / N, loss_v_sum / N


def valid(args, model, device, dataloader):
    """Evaluate the model and return fused / audio-branch / visual-branch accuracy.

    Per-modality logits are derived from the fusion layer's weights in the same
    way as during training. The model is switched to eval mode and left there.

    Args:
        args: namespace providing ``dataset`` and ``fusion_method``.
        model: the network, either bare or wrapped in ``torch.nn.DataParallel``.
            Its forward must return ``(a_feat, v_feat, logits)``.
        device: device the batch tensors are moved to.
        dataloader: iterable yielding ``(spec, image, label)`` batches.

    Returns:
        tuple[float, float, float]: fraction of samples classified correctly by
        the fused, audio-only and visual-only logits (all 0.0 if no samples).

    Raises:
        NotImplementedError: if ``args.dataset`` is not a supported dataset name.
    """
    # Kept for backward compatibility: unknown dataset names are rejected up front.
    if args.dataset not in ('KineticSound', 'CREMAD', 'AVE'):
        raise NotImplementedError('Incorrect dataset name {}'.format(args.dataset))

    # Sub-modules live on the wrapped network when DataParallel is used.
    net = model.module if isinstance(model, torch.nn.DataParallel) else model
    fusion = net.fusion_module

    total = 0
    correct = 0
    correct_a = 0
    correct_v = 0

    with torch.no_grad():
        model.eval()

        for spec, image, label in dataloader:
            spec = spec.to(device)
            image = image.to(device)
            label = label.to(device)

            a_feat, v_feat, out = model(spec.unsqueeze(1).float(), image.float())

            if args.fusion_method == 'sum':
                out_v = v_feat @ fusion.fc_y.weight.T + fusion.fc_y.bias
                out_a = a_feat @ fusion.fc_x.weight.T + fusion.fc_x.bias
            elif args.fusion_method == 'concat':
                W, b = fusion.fc_out.weight, fusion.fc_out.bias
                # Split at the midpoint of the input width instead of a fixed
                # 512 so any feature size works (matches train_epoch).
                half = W.size(1) // 2
                out_v = v_feat @ W[:, half:].T + b / 2
                out_a = a_feat @ W[:, :half].T + b / 2
            else:
                fc = fusion.fc_out
                out_v = fc(v_feat)
                out_a = fc(a_feat)

            # Softmax is monotonic, so the argmax of the logits is the prediction.
            target = label.view(-1)
            total += target.size(0)
            correct += (out.argmax(dim=1) == target).sum().item()
            correct_a += (out_a.argmax(dim=1) == target).sum().item()
            correct_v += (out_v.argmax(dim=1) == target).sum().item()

    if total == 0:
        return 0.0, 0.0, 0.0
    return correct / total, correct_a / total, correct_v / total


def main():
    """Script entry point: train the classifier or evaluate a saved checkpoint.

    With ``--train`` the model is trained for ``args.epochs`` epochs, validated on
    the test split after each epoch, and the best-scoring weights are written to
    the ``args.ckpt_path`` directory. Without ``--train``, ``args.ckpt_path`` is
    loaded as a checkpoint file and evaluated once on the test split.

    Raises:
        NotImplementedError: if ``args.dataset`` is not a supported dataset name.
        ValueError: if TensorBoard is enabled without ``--tensorboard_path``, or
            the loaded checkpoint does not match the requested configuration.
    """
    args = get_arguments()
    print(args)

    # Without a configured handler the root logger drops INFO records, so the
    # per-epoch logging.info lines below would never be shown.
    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')

    # Set device visibility before anything can initialise CUDA; whitespace is
    # stripped because the default value ('0, 1') contains a space.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids.replace(' ', '')
    setup_seed(args.random_seed)
    gpu_ids = list(range(torch.cuda.device_count()))

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = AVClassifier(args)

    model.apply(weight_init)
    model.to(device)

    model = torch.nn.DataParallel(model, device_ids=gpu_ids)
    model.to(device)

    # Honour --optimizer; previously SGD was used regardless of the option.
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=1e-4)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step, args.lr_decay_ratio)

    if args.dataset == 'KineticSound':
        train_dataset = KSDataset(args, mode='train')
        test_dataset = KSDataset(args, mode='test')
    elif args.dataset == 'CREMAD':
        train_dataset = CremadDataset(args, mode='train')
        test_dataset = CremadDataset(args, mode='test')
    elif args.dataset == 'AVE':
        train_dataset = AVEDataset(args, mode='train')
        test_dataset = AVEDataset(args, mode='test')
    else:
        raise NotImplementedError('Incorrect dataset name {}! '
                                  'Only support KineticSound, CREMAD and AVE for now!'.format(args.dataset))

    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=32, pin_memory=True)

    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size,
                                 shuffle=False, num_workers=32, pin_memory=True)

    # The parser may not define `modulation`; fall back to a tag derived from
    # the drop setting so log and checkpoint names stay meaningful.
    explicit_modulation = getattr(args, 'modulation', None)
    run_tag = explicit_modulation or ('gradcam_drop' if args.gradcam_drop else 'normal')

    if args.train:
        writer = None
        if args.use_tensorboard:
            if not args.tensorboard_path:
                raise ValueError('--tensorboard_path is required when --use_tensorboard is enabled')
            writer_path = os.path.join(args.tensorboard_path, args.dataset)
            os.makedirs(writer_path, exist_ok=True)
            log_name = '{}_{}'.format(args.fusion_method, run_tag)
            # One writer for the whole run; it is also handed to train_epoch so
            # the per-step drop probabilities can be logged.
            writer = SummaryWriter(os.path.join(writer_path, log_name))

        best_acc = 0.0
        try:
            for epoch in range(args.epochs):
                print('Epoch: {}: '.format(epoch))

                batch_loss, batch_loss_a, batch_loss_v = train_epoch(args, epoch, model, device,
                                                                     train_dataloader, optimizer, scheduler,
                                                                     writer=writer)
                acc, acc_a, acc_v = valid(args, model, device, test_dataloader)

                if writer is not None:
                    writer.add_scalars('Loss', {'Total Loss': batch_loss,
                                                'Audio Loss': batch_loss_a,
                                                'Visual Loss': batch_loss_v}, epoch)

                    writer.add_scalars('Evaluation', {'Total Accuracy': acc,
                                                      'Audio Accuracy': acc_a,
                                                      'Visual Accuracy': acc_v}, epoch)

                if acc > best_acc:
                    best_acc = float(acc)

                    # Persist the best weights using the keys the evaluation
                    # branch below reads ('modulation', 'fusion', 'model').
                    os.makedirs(args.ckpt_path, exist_ok=True)
                    save_path = os.path.join(
                        args.ckpt_path,
                        'best_model_of_dataset_{}_{}_{}.pth'.format(args.dataset, args.fusion_method, run_tag))
                    torch.save({'saved_epoch': epoch,
                                'modulation': run_tag,
                                'fusion': args.fusion_method,
                                'acc': acc,
                                'model': model.state_dict(),
                                'optimizer': optimizer.state_dict(),
                                'scheduler': scheduler.state_dict()}, save_path)
                    print('The best model has been saved at {}.'.format(save_path))

                    print("Loss: {:.3f}, Acc: {:.3f}".format(batch_loss, acc))
                else:
                    print("Loss: {:.3f}, Acc: {:.3f}, Best Acc: {:.3f}".format(batch_loss, acc, best_acc))

                print("Audio Acc: {:.3f}， Visual Acc: {:.3f} ".format(acc_a, acc_v))
                logging.info(f'Epoch {epoch} | acc = {acc:.4f} | audio_acc = {acc_a:.4f} | video_acc = {acc_v:.4f}')
        finally:
            if writer is not None:
                writer.close()

    else:
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        loaded_dict = torch.load(args.ckpt_path, map_location=device)
        modulation = loaded_dict.get('modulation')
        fusion = loaded_dict['fusion']
        state_dict = loaded_dict['model']

        # Only compare the modulation when it was explicitly provided in args.
        if explicit_modulation is not None and modulation != explicit_modulation:
            raise ValueError('inconsistency between modulation method of loaded model and args !')
        if fusion != args.fusion_method:
            raise ValueError('inconsistency between fusion method of loaded model and args !')

        # load_state_dict returns a report of missing/unexpected keys, not the
        # module, so its result must not be assigned back to `model`.
        model.load_state_dict(state_dict)
        print('Trained model loaded!')

        acc, acc_a, acc_v = valid(args, model, device, test_dataloader)
        print('Accuracy: {}, accuracy_a: {}, accuracy_v: {}'.format(acc, acc_a, acc_v))


# Run the train/evaluate entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
