from utils.utils import setup_seed
from dataset.av_dataset import AVDataset_mmimdb
import copy
from torch.utils.data import DataLoader
from models.models import AVClassifierregressgoal
from sklearn import metrics
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import torch
from min_norm_solvers import MinNormSolver
import numpy as np
from tqdm import tqdm
import argparse
import os
import pickle
from operator import mod

import logging
import datetime
import sys


def get_arguments():
    """Parse the command-line options of the training script.

    The boolean options ``--train`` and ``--use_tensorboard`` take an explicit
    value (``true``/``false``, ``yes``/``no``, ``1``/``0``, case-insensitive).
    An empty string is also accepted as "false" so that the historical
    ``--use_tensorboard ""`` workaround keeps working.

    Returns:
        argparse.Namespace: the options parsed from ``sys.argv``.

    Raises:
        SystemExit: raised by argparse on invalid options or values.
    """

    def str2bool(value):
        # ``type=bool`` (or no type at all with a ``True`` default) makes every
        # non-empty string truthy, so "--train False" used to leave training
        # switched on. Parse the text explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='MMIMDB', type=str,
                        help='MMIMDB, KineticSound, CREMAD, VGGSound, AVE')
    parser.add_argument('--model', default='model', type=str)
    parser.add_argument('--n_classes', default=6, type=int)
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--optimizer', default='sgd',
                        type=str, choices=['sgd', 'adam'])
    parser.add_argument('--learning_rate', default=0.0004, type=float, help='initial learning rate')
    parser.add_argument('--lr_decay_step', default=30, type=int, help='where learning rate decays')
    parser.add_argument('--lr_decay_ratio', default=0.1,
                        type=float, help='decay coefficient')
    parser.add_argument('--ckpt_path', default='log_cd',
                        type=str, help='path to save trained models')
    parser.add_argument('--train', default=True, type=str2bool,
                        help='turn on train mode')
    # NOTE(review): no code visible in this file reads args.clip_grad --
    # confirm it is consumed elsewhere.
    parser.add_argument('--clip_grad', action='store_true',
                        help='turn on gradient clipping')
    parser.add_argument('--use_tensorboard', default=True,
                        type=str2bool, help='whether to visualize')
    parser.add_argument('--tensorboard_path', default='log_cd',
                        type=str, help='path to save tensorboard logs')
    parser.add_argument('--random_seed', default=0, type=int)
    parser.add_argument('--gpu_ids', default='2',
                        type=str, help='GPU ids')

    return parser.parse_args()



def train_epoch(args, epoch, model, device, dataloader, optimizer, scheduler, writer=None):
    """Run one training pass over ``dataloader`` and return the mean batch loss.

    For every batch the model is called as ``model(images.float(), [text1, text2])``
    and must return three outputs ``(out, out_v, out_t)``. The optimised
    objective is the sum of the MSE of each of the three outputs against the
    same ``label``.

    Args:
        args: parsed command-line options (not read here; kept for the
            existing call signature).
        epoch: index of the current epoch (not read here).
        model: the network to train; switched to train mode.
        device: device that ``images`` and ``label`` are moved to.
        dataloader: iterable yielding ``(images, text1, text2, label)``.
        optimizer: optimizer updated once per batch.
        scheduler: learning-rate scheduler, stepped once at the end of the
            epoch; ``None`` disables scheduling.
        writer: unused; kept for interface compatibility.

    Returns:
        float: sum of the per-batch losses divided by the number of batches,
        or ``0.0`` when the dataloader yields no batch.
    """
    criterion = nn.MSELoss()
    model.train()
    print("Start training ... ")

    running_loss = 0.0
    n_batches = 0

    for images, text1, text2, label in tqdm(dataloader):
        optimizer.zero_grad()
        images = images.to(device)
        label = label.to(device)
        text = [text1, text2]
        out, out_v, out_t = model(images.float(), text)

        # Joint objective: one MSE term per output, all against the same label.
        loss_mm = criterion(out, label)
        loss_t = criterion(out_t, label)
        loss_v = criterion(out_v, label)
        loss = loss_mm + loss_t + loss_v

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        # Nothing was trained: avoid dividing by zero and do not advance the
        # schedule for an epoch in which the optimizer never stepped.
        return 0.0

    # The scheduler was previously never advanced, so the configured
    # learning-rate decay had no effect. Step it once per epoch, after the
    # optimizer updates.
    if scheduler is not None:
        scheduler.step()

    return running_loss / n_batches


def valid(args, model, device, dataloader):
    """Evaluate ``model`` on ``dataloader`` and log regression error metrics.

    The model is put in eval mode and run without gradient tracking. Each
    batch is fed as ``model(images.float(), [text1, text2])`` and the three
    returned predictions (fused, visual, text) are gathered together with the
    labels. MAE and MSE of the fused prediction and the MAE of the text and
    visual predictions are logged.

    Args:
        args: parsed command-line options (not read here).
        model: the network to evaluate.
        device: device that ``images`` and ``label`` are moved to.
        dataloader: iterable yielding ``(images, text1, text2, label)``.

    Returns:
        tuple: ``(mae, mae_t, mae_v)`` -- MAE of the fused, text and visual
        predictions respectively.
    """
    targets, fused_outputs, text_outputs, visual_outputs = [], [], [], []

    with torch.no_grad():
        model.eval()

        for _, (images, text1, text2, label) in tqdm(enumerate(dataloader)):
            images, label = images.to(device), label.to(device)

            fused, visual, textual = model(images.float(), [text1, text2])

            # Accumulate per-sample rows on the host for the sklearn metrics.
            targets.extend(label.cpu().numpy())
            fused_outputs.extend(fused.cpu().numpy())
            text_outputs.extend(textual.cpu().numpy())
            visual_outputs.extend(visual.cpu().numpy())

        mean_abs_err = metrics.mean_absolute_error

        mae = mean_abs_err(targets, fused_outputs)
        mse = metrics.mean_squared_error(targets, fused_outputs)
        mae_t = mean_abs_err(targets, text_outputs)
        mae_v = mean_abs_err(targets, visual_outputs)

        logging.info(f"Validation ==> MAE: {mae:.4f}, MSE: {mse:.4f}")
        logging.info(f"Text MAE: {mae_t:.4f}, Visual-only MAE: {mae_v:.4f}")

    return mae, mae_t, mae_v

def main():
    """Entry point: configure logging, build model and data, then train.

    Steps performed:
      1. Parse the command-line options and set up file + stdout logging
         under ``./logs``.
      2. Override ``args.n_classes`` for the datasets listed in the lookup
         table below (other dataset names keep the parsed value).
      3. Restrict the visible GPUs, seed the RNGs, build the model (wrapped in
         ``DataParallel``), the train/test dataloaders, the optimizer and a
         ``StepLR`` scheduler.
      4. When ``args.train`` is truthy, train for ``args.epochs`` epochs,
         validating after each one and saving a checkpoint whenever the fused
         MAE improves.
    """
    args = get_arguments()

    log_dir = './logs'
    os.makedirs(log_dir, exist_ok=True)

    current_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    args_str = f"{args.dataset}_{args.model}_lr{args.learning_rate}_bs{args.batch_size}"
    log_filename = os.path.join(log_dir, f"goalregress{current_time}_{args_str}.log")

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_filename),
            logging.StreamHandler(sys.stdout)
        ]
    )
    logging.info(f"{args}")

    # Datasets with a fixed class count; any other name keeps args.n_classes.
    n_classes_by_dataset = {
        'VGGSound': 309,
        'KineticSound': 31,
        'UCF101': 101,
        'CREMAD': 6,
        'AVE': 28,
    }
    if args.dataset in n_classes_by_dataset:
        args.n_classes = n_classes_by_dataset[args.dataset]

    # Restrict the visible devices before any helper gets a chance to touch
    # CUDA; the variable is only honoured if set before CUDA is initialised.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids
    setup_seed(args.random_seed)

    gpu_ids = list(range(torch.cuda.device_count()))

    # Fall back to the CPU instead of crashing when no GPU is available.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = AVClassifierregressgoal(args)
    model.to(device)

    # Always wrap in DataParallel so checkpoint keys keep the same layout
    # whatever hardware the run uses.
    model = torch.nn.DataParallel(model, device_ids=gpu_ids)

    train_dataset = AVDataset_mmimdb(mode='train')
    test_dataset = AVDataset_mmimdb(mode='test')

    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=0, pin_memory=False)

    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size,
                                 shuffle=False, num_workers=0)

    if args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=1e-4)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(
            model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-08)

    scheduler = optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step, args.lr_decay_ratio)

    print(len(train_dataloader))

    if args.train:
        best_error = float('inf')

        for epoch in range(args.epochs):
            logging.info('Epoch: {}: '.format(epoch))

            batch_loss = train_epoch(
                args, epoch, model, device, train_dataloader, optimizer, scheduler, writer=None)

            mae, mae_t, mae_v = valid(args, model, device, test_dataloader)

            if mae < best_error:
                best_error = mae

                # makedirs also handles nested checkpoint paths and an
                # already-existing directory.
                os.makedirs(args.ckpt_path, exist_ok=True)

                model_name = 'best_model_{}_of_{}_{}_epoch{}_mae{:.4f}.pth'.format(
                    args.model, args.optimizer, args.dataset, epoch, mae)

                saved_dict = {'saved_epoch': epoch,
                              'mae': mae,
                              'model': model.state_dict(),
                              'optimizer': optimizer.state_dict(),
                              'scheduler': scheduler.state_dict()}

                save_dir = os.path.join(args.ckpt_path, model_name)

                torch.save(saved_dict, save_dir)

                logging.info(f'The best model has been saved at {save_dir}.')
                logging.info(f"MAE: {mae:.4f}, MAE_t: {mae_t:.4f}, MAE_v: {mae_v:.4f}")
            else:
                logging.info(f"Loss: {batch_loss:.4f}, MAE: {mae:.4f}, Best MAE: {best_error:.4f}")

# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
