import argparse
import os
from dataset.datasets import MultimodalDataset
from dataset.helper import get_num_classes
from utils import create_logger, freeze_encoders, load_uni_enc_ckpt
from models import MMmodel, ResNetMMmodel, MultiTaskMMmodel
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
import random


def _count_correct(logits, label):
    """Return how many rows of ``logits`` have their arg-max equal to ``label``."""
    _, pred = logits.max(1)
    return (pred == label).sum().item()


def _apply_modality_dropout(audios, imgs, drop_prob):
    """Zero out one modality per sample, in place.

    For every sample a single ``random.random()`` draw decides the outcome:
    below ``drop_prob`` the audio is zeroed, otherwise below ``2 * drop_prob``
    the image/video input is zeroed, otherwise the sample is left untouched.
    """
    for j in range(audios.shape[0]):
        rand_drop = random.random()
        if rand_drop < drop_prob:
            audios[j].zero_()
        elif rand_drop < drop_prob * 2:
            imgs[j].zero_()


def _evaluate(model, loader, drop_audio=False, drop_video=False):
    """Run ``model`` over ``loader`` in eval mode and count correct predictions.

    Args:
        model: callable returning ``(audio_logits, video_logits, mm_logits)``.
        loader: iterable of ``(audios, imgs, label)`` batches.
        drop_audio: if true, feed an all-zero tensor instead of the audio input.
        drop_video: if true, feed an all-zero tensor instead of the image input.

    Returns:
        Tuple ``(mm_correct, audio_correct, video_correct)`` of integer counts.
    """
    model.eval()
    mm_correct = 0
    audio_correct = 0
    video_correct = 0
    with torch.no_grad():
        for audios, imgs, label in loader:
            audios = audios.cuda()
            imgs = imgs.cuda()
            label = label.cuda()
            if drop_audio:
                audios = torch.zeros_like(audios)
            if drop_video:
                imgs = torch.zeros_like(imgs)

            a_output, v_output, mm_output = model(audios, imgs)
            mm_correct += _count_correct(mm_output, label)
            audio_correct += _count_correct(a_output, label)
            video_correct += _count_correct(v_output, label)
    return mm_correct, audio_correct, video_correct


def train(args):
    """Train a multimodal model and evaluate it after every epoch.

    The model is optimized with AdamW on the sum of three cross-entropy
    losses (multimodal, audio-only and video-only outputs). After each epoch
    the test set is evaluated three times: with both modalities, with the
    image input zeroed ("missing image") and with the audio input zeroed
    ("missing audio"). Results are written through the logger returned by
    ``create_logger(args)``; no checkpoint is saved (the save call is
    commented out). All tensors and the model are moved with ``.cuda()``.

    Args:
        args: namespace providing at least ``dataset_name``,
            ``train_batch_size``, ``test_batch_size``, ``alg``,
            ``loadUniEnc``, ``freeze_encoder``, ``lr``, ``total_epoch`` and
            ``modalDrop``; it is also passed to ``create_logger``, the model
            constructors and ``load_uni_enc_ckpt``.

    Raises:
        SystemExit: if ``args.alg`` is not ``'3ce'``, ``'baseline'`` or
            ``'ours'`` (exit status 1, message on stderr).
    """
    logger = create_logger(args)
    logger.info(" Load Dataset:{}".format(str(args.dataset_name)))
    train_dataset = MultimodalDataset(train=True, dataset=args.dataset_name)
    test_dataset = MultimodalDataset(train=False, dataset=args.dataset_name)
    train_data = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
    test_data = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False, drop_last=False)

    logger.info(" Loading Model")
    # NOTE(review): args.model ('vit' / 'resnet') is not consulted here;
    # presumably the model classes read it from args themselves -- confirm.
    uses_mm_model = args.alg in ('baseline', 'ours')
    if args.alg == '3ce':
        model = MultiTaskMMmodel(args).cuda()
    elif uses_mm_model:
        model = MMmodel(args).cuda()
    else:
        # Fail with a non-zero status instead of the site-provided exit(),
        # which would report success (status 0) for an invalid configuration.
        raise SystemExit('alg name error')
    if args.loadUniEnc:
        load_uni_enc_ckpt(args, model)
    if args.freeze_encoder:
        freeze_encoders(model)

    criterion = nn.CrossEntropyLoss()

    if uses_mm_model:
        # Encoders use the configured learning rate; the two classifier heads
        # use a fixed 1e-3.
        # NOTE(review): only these four sub-modules are handed to the
        # optimizer; any other parameters of MMmodel would never be updated
        # -- confirm this is intended.
        param_groups = [{'params': model.audio_model.parameters(), 'lr': args.lr},
                        {'params': model.video_model.parameters(), 'lr': args.lr},
                        {'params': model.a_clf.parameters(), 'lr': 1e-3},
                        {'params': model.v_clf.parameters(), 'lr': 1e-3}]
        optimizer = torch.optim.AdamW(param_groups)
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    num_train = len(train_dataset)
    num_test = len(test_dataset)

    for epoch in range(args.total_epoch):
        num_mm_correct = 0
        num_audio_correct = 0
        num_video_correct = 0
        model.train()
        for i, (audios, imgs, label) in enumerate(train_data):
            audios = audios.cuda()
            imgs = imgs.cuda()
            label = label.cuda()

            if args.modalDrop != 0:
                _apply_modality_dropout(audios, imgs, args.modalDrop)

            a_output, v_output, mm_output = model(audios, imgs)

            mmloss = criterion(mm_output, label)
            uniloss = criterion(a_output, label) + criterion(v_output, label)
            loss = mmloss + uniloss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Only the multimodal part of the loss is reported.
            if i % 10 == 0:
                logger.info("Epoch: {},step: {}, loss: {}".format(epoch, i, mmloss.item()))

            num_mm_correct += _count_correct(mm_output, label)
            num_audio_correct += _count_correct(a_output, label)
            num_video_correct += _count_correct(v_output, label)

        logger.info('MM Training ACC:{}, Audio Training ACC:{}, Video Training ACC:{}'.format(
            num_mm_correct / num_train,
            num_audio_correct / num_train,
            num_video_correct / num_train,
        ))

        # Evaluation with both modalities present.
        eval_mm_correct, eval_audio_correct, eval_video_correct = _evaluate(model, test_data)

        logger.info('**** One epoch has finished ****')
        logger.info('epoch:{}, Train MM ACC:{:.6f}, Eval MM Acc:{:.6f}, Eval Audio Acc:{:.6f}, Eval Video Acc:{:.6f}'.format(
            epoch, num_mm_correct / num_train, eval_mm_correct / num_test,
            eval_audio_correct / num_test, eval_video_correct / num_test
        ))

        # Evaluation with the image input replaced by zeros.
        missing_img_correct, _, _ = _evaluate(model, test_data, drop_video=True)
        logger.info('epoch:{}, missing image Eval MM Acc:{:.6f}'.format(
            epoch, missing_img_correct / num_test
        ))

        # Evaluation with the audio input replaced by zeros.
        missing_audio_correct, _, _ = _evaluate(model, test_data, drop_audio=True)
        logger.info('epoch:{}, missing audio Eval MM Acc:{:.6f}'.format(
            epoch, missing_audio_correct / num_test
        ))

        # Checkpoint saving is currently disabled:
        # torch.save(model.state_dict(), os.path.join(args.savedir, 'model_epoch{}_acc{:.6f}'.format(epoch, eval_mm_correct/len(test_dataset))))


def get_args():
    """Parse command-line options and prepare the run's output directory.

    Besides the parsed options, the returned namespace carries three derived
    attributes: ``num_classes`` (from ``get_num_classes(dataset_name)``),
    ``model_name`` (a string built from the main hyper-parameters) and
    ``savedir`` (the ``--savedir`` value joined with ``model_name``).

    Side effects: creates ``args.savedir`` and writes every attribute of the
    namespace to ``setting.txt`` inside it.

    Returns:
        argparse.Namespace: the parsed and augmented arguments.

    Raises:
        SystemExit: if ``args.savedir`` already exists, so a previous run is
            never overwritten (exit status 1, message on stderr); also raised
            by argparse on invalid command-line input.
    """
    parser = argparse.ArgumentParser(description='Training Multi-modal Model')
    parser.add_argument('--alg', type=str, default='baseline', choices=['3ce', 'ours','baseline'],help='using which model')
    parser.add_argument('--model', type=str, default='vit', choices=['vit', 'resnet'],help='using which model')

    parser.add_argument('--lr', type=float, default=1e-5 ,help='learning rate of the model')
    parser.add_argument('--modalDrop', type=float, default=0.0 ,help='the prob of modality drop')

    parser.add_argument('--loadUniEnc', action='store_true', help='useing uni-modal finetuned models or not')
    parser.add_argument('--freeze_encoder', action='store_true', help='freeze_encoder or not')

    parser.add_argument('--dataset_name', type=str, default='AVE', choices=['AVE', 'ks', 'CREMA-D'],help='using which dataset')
    parser.add_argument('--clip_model', type=str, default='ViT-B-16', choices=['ViT-L-14', 'ViT-B-16'],help='using which size of model')
    parser.add_argument('--clip_pretraining_data', type=str, default='laion2b_s34b_b88k',
                        choices=['datacomp_xl_s13b_b90k', 'laion2b_s34b_b88k'],help='using which dataset for pre-training')
    parser.add_argument('--train_batch_size', type=int, default=64, help='training batch size')
    parser.add_argument('--test_batch_size', type=int, default=16, help='testing batch size')
    parser.add_argument('--total_epoch', type=int, default=30, help='The total num of training epoches')
    parser.add_argument('--savedir', type=str,default='outputs/missing' ,help='the localtion of saving checkpoints')

    args = parser.parse_args()
    args.num_classes = get_num_classes(args.dataset_name)
    # The run name doubles as the output directory name; the spelling
    # "freezeeEnc" is kept so existing output paths stay valid.
    args.model_name = (
        '{}-{}-{}-loadUniEnc-{}-freezeeEnc-{}-bs{}-lr{}-totalepoch{}-modalDrop{}'
    ).format(
        args.alg, args.model, args.dataset_name,
        args.loadUniEnc, args.freeze_encoder,
        args.train_batch_size, args.lr, args.total_epoch,
        args.modalDrop,
    )
    args.savedir = os.path.join(args.savedir, args.model_name)

    # Create the directory directly instead of checking first: this avoids
    # the window between an exists() test and makedirs(), and an existing
    # path aborts the run with a non-zero exit status.
    try:
        os.makedirs(args.savedir)
    except FileExistsError:
        raise SystemExit('savedir already here. {}'.format(args.savedir)) from None

    # Record the full configuration next to the run's outputs.
    with open(os.path.join(args.savedir, 'setting.txt'), 'w', encoding='utf-8') as f:
        f.write('------------------- start -------------------' + '\n')
        for arg, value in vars(args).items():
            f.write(arg + ' : ' + str(value) + '\n')
        f.write('------------------- end -------------------' + '\n')
    return args



if __name__ == "__main__":
    # Script entry point: parse the command line, then run training with it.
    train(get_args())
