from utils.utils import setup_seed
from dataset.av_dataset import AVDataset_CD, AVDataset_KS, AVDataset_AVE, AVDataset_MOSI, AVDataset_MELD, AVDataset_IEMOCAP
import copy
from torch.utils.data import DataLoader
from models.models import AVClassifierGOALcmumosi
from sklearn import metrics
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import torch
from min_norm_solvers import MinNormSolver
import numpy as np
from tqdm import tqdm
import argparse
import os
import pickle
from operator import mod

import logging
import datetime
import sys


def _str2bool(value):
    """Convert a command-line string such as ``True``/``false``/``1``/``no`` to a bool.

    ``argparse`` with ``type=bool`` (or with no ``type`` at all) keeps or
    converts the raw string, and every non-empty string -- including
    ``"False"`` -- is truthy. Boolean options therefore need an explicit
    converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling (argparse turns this into a usage error).
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_arguments(argv=None):
    """Build the CLI parser and parse the training options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like the original behaviour.

    Returns:
        argparse.Namespace with the parsed options. Boolean-valued options
        (``--train``, ``--use_tensorboard``) are real ``bool`` values.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='IEMOCAP', type=str,
                        help='MOSI,MELD, IEMOCAP')
    parser.add_argument('--model', default='model', type=str)
    parser.add_argument('--n_classes', default=6, type=int)
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--optimizer', default='sgd',
                        type=str, choices=['sgd', 'adam'])
    parser.add_argument('--learning_rate', default=0.004, type=float, help='initial learning rate')
    parser.add_argument('--lr_decay_step', default=30, type=int, help='where learning rate decays')
    parser.add_argument('--lr_decay_ratio', default=0.1,
                        type=float, help='decay coefficient')
    parser.add_argument('--ckpt_path', default='log_cd',
                        type=str, help='path to save trained models')
    # _str2bool so that `--train False` really disables training.
    parser.add_argument('--train', default=True, type=_str2bool,
                        help='turn on train mode')
    # NOTE(review): this help text looks copy-pasted from --train, and the
    # flag is not read by the training loop in this file -- confirm intent.
    parser.add_argument('--clip_grad', action='store_true',
                        help='turn on train mode')
    parser.add_argument('--use_tensorboard', default=True,
                        type=_str2bool, help='whether to visualize')
    parser.add_argument('--tensorboard_path', default='log_cd',
                        type=str, help='path to save tensorboard logs')
    parser.add_argument('--random_seed', default=0, type=int)
    parser.add_argument('--gpu_ids', default='1',
                        type=str, help='GPU ids')

    return parser.parse_args(argv)



def train_epoch(args, epoch, model, device, dataloader, optimizer, scheduler, writer=None):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch is unpacked as ``(spec, images, text1, text2, label)``. The
    model returns four logit tensors: the fused output plus audio, visual
    and text outputs. The optimised loss is the unweighted sum of the
    cross-entropy of all four against ``label``.

    Args:
        args, epoch, writer: Accepted for interface compatibility; not used here.
        model: Network called as ``model(spec, images, [text1, text2])``.
        device: Device that ``spec``, ``images`` and ``label`` are moved to.
        dataloader: Iterable of training batches (must support ``len``).
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once at the end of the
            epoch (skipped when ``None``).

    Returns:
        float: Mean of the per-batch summed loss, or ``0.0`` when the
        dataloader is empty.
    """
    criterion = nn.CrossEntropyLoss()
    model.train()
    print("Start training ... ")

    total_loss = 0.0

    for step, (spec, images, text1, text2, label) in tqdm(enumerate(dataloader)):
        optimizer.zero_grad()
        images = images.to(device)
        spec = spec.to(device)
        label = label.to(device)
        # The model takes both text inputs as one two-element list; they are
        # passed as-is (not moved to `device` here).
        text = [text1, text2]
        out, out_a, out_v, out_t = model(spec.float(), images.float(), text)

        # Fused loss plus one auxiliary loss per modality output, equally weighted.
        loss = (criterion(out, label) + criterion(out_a, label)
                + criterion(out_v, label) + criterion(out_t, label))

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # The scheduler was previously never stepped, so the configured StepLR
    # decay (--lr_decay_step / --lr_decay_ratio) never took effect. StepLR's
    # step_size counts calls to step(), so step once per epoch.
    if scheduler is not None:
        scheduler.step()

    num_batches = len(dataloader)
    return total_loss / num_batches if num_batches else 0.0


def valid(args, model, device, dataloader):
    """Evaluate ``model`` on ``dataloader`` with gradients disabled.

    Batches are unpacked as ``(spec, images, text1, text2, label)``. The model
    output is indexed as ``[fused, audio, visual, text]`` logits.

    Args:
        args: Accepted for interface compatibility; not used here.
        model: Network called as ``model(spec, images, [text1, text2])``.
        device: Device that ``spec``, ``images`` and ``label`` are moved to.
        dataloader: Iterable of evaluation batches.

    Returns:
        tuple: ``(acc, acc_a, acc_v, acc_t, f1_macro)``. The four accuracies
        are overall (not per-class) accuracies of the fused, audio, visual and
        text outputs. ``f1_macro`` is the macro-averaged F1 of the fused output.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.eval()

    n_total = 0
    n_correct = [0, 0, 0, 0]  # fused, audio, visual, text
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for step, (spec, images, text1, text2, label) in tqdm(enumerate(dataloader)):
            images = images.to(device)
            spec = spec.to(device)
            label = label.to(device)
            text = [text1, text2]
            outputs = model(spec.float(), images.float(), text)

            batch_labels = label.cpu().numpy()
            # One argmax + device->host copy per output per batch. The old
            # code repeated this per sample, four times over.
            head_preds = [torch.argmax(outputs[k], dim=1).cpu().numpy() for k in range(4)]

            n_total += len(batch_labels)
            for k, preds in enumerate(head_preds):
                n_correct[k] += int((preds == batch_labels).sum())

            all_labels.extend(batch_labels)
            all_preds.extend(head_preds[0])

    if n_total == 0:
        raise ValueError("validation dataloader yielded no samples")

    f1_macro = metrics.f1_score(all_labels, all_preds, average='macro')

    acc, acc_a, acc_v, acc_t = (correct / n_total for correct in n_correct)
    return acc, acc_a, acc_v, acc_t, f1_macro

# Number of output classes for each known dataset name.
_N_CLASSES = {
    'VGGSound': 309,
    'KineticSound': 31,
    'UCF101': 101,
    'CREMAD': 6,
    'AVE': 28,
    'MOSI': 3,
    'MELD': 3,
    'IEMOCAP': 4,
}

# Dataset name -> dataset class, constructed as cls(mode='train'/'test').
_DATASET_CLASSES = {
    'CREMAD': AVDataset_CD,
    'KineticSound': AVDataset_KS,
    'AVE': AVDataset_AVE,
    'MOSI': AVDataset_MOSI,
    'MELD': AVDataset_MELD,
    'IEMOCAP': AVDataset_IEMOCAP,
}


def _setup_logging(args):
    """Configure INFO logging to stdout and to a timestamped file under ./logs."""
    log_dir = './logs'
    os.makedirs(log_dir, exist_ok=True)

    current_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    args_str = f"{args.dataset}_{args.model}_lr{args.learning_rate}_bs{args.batch_size}"
    log_filename = os.path.join(log_dir, f"goal{current_time}_{args_str}.log")

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_filename),
            logging.StreamHandler(sys.stdout)
        ]
    )


def main():
    """Parse CLI options, build model/data/optimizer, then train and evaluate.

    After every epoch the model is evaluated on the test split. Whenever the
    fused accuracy improves, a checkpoint (model, optimizer and scheduler
    state) is written under ``args.ckpt_path``.

    Raises:
        ValueError: if ``--dataset`` has no dataset loader in this script.
    """
    args = get_arguments()
    _setup_logging(args)
    logging.info(f"{args}")

    # Fail fast with a clear message. Previously an unsupported dataset
    # surfaced much later as UnboundLocalError on `train_dataset`.
    if args.dataset not in _DATASET_CLASSES:
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; expected one of {sorted(_DATASET_CLASSES)}")

    args.n_classes = _N_CLASSES.get(args.dataset, args.n_classes)

    # CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is initialised.
    # setup_seed is opaque from here and may touch CUDA, so set it first.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids
    setup_seed(args.random_seed)

    gpu_ids = list(range(torch.cuda.device_count()))

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = AVClassifierGOALcmumosi(args)
    model.to(device)

    # With no CUDA device, DataParallel just forwards to the wrapped module.
    model = torch.nn.DataParallel(model, device_ids=gpu_ids)

    dataset_cls = _DATASET_CLASSES[args.dataset]
    train_dataset = dataset_cls(mode='train')
    test_dataset = dataset_cls(mode='test')

    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=0, pin_memory=False)

    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size,
                                 shuffle=False, num_workers=0)

    if args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=1e-4)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(
            model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-08)

    scheduler = optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step, args.lr_decay_ratio)

    print(len(train_dataloader))

    if not args.train:
        return

    best_acc = -1

    for epoch in range(args.epochs):
        logging.info('Epoch: {}: '.format(epoch))

        batch_loss = train_epoch(
            args, epoch, model, device, train_dataloader, optimizer, scheduler, writer=None)

        acc, acc_a, acc_v, acc_t, f1_macro = valid(args, model, device, test_dataloader)

        if acc > best_acc:
            best_acc = float(acc)

            # makedirs also handles nested checkpoint paths.
            os.makedirs(args.ckpt_path, exist_ok=True)

            model_name = 'best_model_{}_of_{}_{}_epoch{}_batch{}_lr{}.pth'.format(
                args.model, args.optimizer,  args.dataset, args.epochs, args.batch_size, args.learning_rate)

            saved_dict = {'saved_epoch': epoch,
                          'acc': acc,
                          'model': model.state_dict(),
                          'optimizer': optimizer.state_dict(),
                          'scheduler': scheduler.state_dict()}

            save_dir = os.path.join(args.ckpt_path, model_name)
            torch.save(saved_dict, save_dir)

            logging.info('The best model has been saved at {}.'.format(save_dir))
            logging.info("Acc: {:.4f}, Acc_a: {:.4f}, Acc_v: {:.4f},Acc_t: {:.4f},f1-macro: {:.4f}".format(acc, acc_a, acc_v,acc_t, f1_macro))
        else:
            logging.info("Loss: {:.4f}, Acc: {:.4f}, Acc_a: {:.4f}, Acc_v: {:.4f},Acc_t: {:.4f},macro:{:.4f},Best Acc: {:.4f}".format(
                batch_loss, acc, acc_a, acc_v,acc_t, f1_macro, best_acc))


if __name__ == "__main__":
    # Script entry point: parse CLI arguments, then run the train/evaluate loop.
    main()
