import argparse
import os
from dataset.datasets import MultimodalDataset
from dataset.helper import get_num_classes
from utils import create_logger, freeze_encoders, load_uni_enc_ckpt
from models import UMEMMLoRAModel, ResNetUMEMMLoRAModel
from torch.utils.data import DataLoader
import torch.nn as nn
import torch



def _count_correct(logits, label):
    """Count the samples whose highest-scoring class equals ``label``.

    The predicted class is the index returned by ``logits.max(1)``; the
    count is converted to a plain Python number with ``.item()``.
    """
    _, pred = logits.max(1)
    return (pred == label).sum().item()


def train(args):
    """Train the LoRA multimodal model and evaluate it after every epoch.

    Args:
        args: Parsed options. This function reads ``dataset_name``,
            ``train_batch_size``, ``test_batch_size``, ``model``,
            ``loadUniEnc``, ``freeze_encoder``, ``lr``, ``total_epoch`` and
            ``savedir``. ``args`` is also handed to ``create_logger``,
            ``load_uni_enc_ckpt`` and the model constructors, which may read
            further fields.

    Raises:
        SystemExit: if ``args.model`` is neither ``'vit'`` nor ``'resnet'``.

    Side effects:
        Moves the model and every batch to CUDA, logs progress through the
        logger returned by ``create_logger`` and writes one ``state_dict``
        checkpoint per epoch into ``args.savedir``.
    """
    logger = create_logger(args)
    logger.info(" Load Dataset:{}".format(str(args.dataset_name)))
    train_dataset = MultimodalDataset(train=True, dataset=args.dataset_name)
    test_dataset = MultimodalDataset(train=False, dataset=args.dataset_name)
    train_data = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
    test_data = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False, drop_last=False)
    # Accuracies are reported per sample, so the dataset sizes are the denominators.
    num_train = len(train_dataset)
    num_test = len(test_dataset)

    logger.info(" Loading Model")
    if args.model == 'vit':
        model = UMEMMLoRAModel(args).cuda()
    elif args.model == 'resnet':
        model = ResNetUMEMMLoRAModel(args).cuda()
    else:
        # Abort with a non-zero exit status instead of the interactive
        # ``exit()`` helper, which would report success to the caller.
        raise SystemExit("model name error: {!r}".format(args.model))
    if args.loadUniEnc:
        load_uni_enc_ckpt(args, model)
    if args.freeze_encoder:
        freeze_encoders(model)
    model.lora_the_model()
    if not args.loadUniEnc:
        # Without loaded uni-modal checkpoints, every parameter whose name
        # contains 'clf' is (re-)enabled for training.
        for n, p in model.named_parameters():
            if 'clf' in n:
                p.requires_grad = True

    # NOTE(review): presumably lora_the_model() adds modules that are not yet
    # on the GPU, hence the second .cuda() call - confirm.
    model.cuda()
    for n, p in model.named_parameters():
        if p.requires_grad:
            logger.info("trainable parameters:{}".format(n))

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    for epoch in range(args.total_epoch):
        train_loss = 0
        num_mm_correct = 0
        num_audio_correct = 0
        num_video_correct = 0
        model.train()
        for i, (audios, imgs, label) in enumerate(train_data):
            audios = audios.cuda()
            imgs = imgs.cuda()
            label = label.cuda()
            a_output, v_output, mm_output = model(audios, imgs)

            # Only the multimodal output contributes to the loss; a_output
            # and v_output are used purely for accuracy reporting.
            mmloss = criterion(mm_output, label)
            loss = mmloss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += mmloss.item()

            if i % 10 == 0:
                logger.info("Epoch: {},step: {}, loss: {}".format(epoch, i, mmloss.item()))

            num_mm_correct += _count_correct(mm_output, label)
            num_audio_correct += _count_correct(a_output, label)
            num_video_correct += _count_correct(v_output, label)

        logger.info('MM Training ACC:{}, Audio Training ACC:{}, Video Training ACC:{}'.format(
            str(num_mm_correct / num_train),
            str(num_audio_correct / num_train),
            str(num_video_correct / num_train),
        ))

        eval_mm_correct = 0
        eval_audio_correct = 0
        eval_video_correct = 0
        model.eval()
        # No gradients are needed anywhere in the evaluation pass.
        with torch.no_grad():
            for audios, imgs, label in test_data:
                audios = audios.cuda()
                imgs = imgs.cuda()
                label = label.cuda()
                a_output, v_output, mm_output = model(audios, imgs)

                eval_mm_correct += _count_correct(mm_output, label)
                eval_audio_correct += _count_correct(a_output, label)
                eval_video_correct += _count_correct(v_output, label)

        eval_mm_acc = eval_mm_correct / num_test

        logger.info('**** One epoch has finished ****')
        logger.info('epoch:{}, Train MM ACC:{:.6f}, Eval MM Acc:{:.6f}, Eval Audio Acc:{:.6f}, Eval Video Acc:{:.6f}'.format(
            epoch, num_mm_correct / num_train, eval_mm_acc, eval_audio_correct / num_test, eval_video_correct / num_test
        ))
        # One checkpoint per epoch; the file name carries the multimodal eval accuracy.
        torch.save(model.state_dict(), os.path.join(args.savedir, 'model_epoch{}_acc{:.6f}'.format(epoch, eval_mm_acc)))


def get_args(argv=None):
    """Parse the command line, derive the run name and create the run directory.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before this parameter
            existed.

    Returns:
        The parsed ``argparse.Namespace`` extended with ``num_classes``
        (from ``get_num_classes``) and ``model_name``; ``savedir`` is
        rewritten to ``<savedir>/<model_name>``.

    Raises:
        SystemExit: on an argparse error, or when the run directory already
            exists (so an earlier run is never overwritten).

    Side effects:
        Prints the namespace, creates the run directory and writes every
        option to ``setting.txt`` inside it.
    """
    # Run-name prefix for each LoRA mode; the keys double as the CLI choices,
    # so argparse rejects anything that has no prefix.
    lora_prefixes = {
        'a': 'NewAudiolora_',
        'v': 'NewVideolora_',
        'mm': 'MultiModallora_',
    }

    parser = argparse.ArgumentParser(description='Training Multi-modal Model')
    parser.add_argument('--lora_r', type=int, default=1, help='lora rank')
    parser.add_argument('--lora_alpha', type=int, default=4, help='lora alpha')
    parser.add_argument('--lora_modal', type=str, default='a', choices=list(lora_prefixes), help='using which mode for lora')

    parser.add_argument('--model', type=str, default='vit', choices=['vit', 'resnet'], help='using which model')

    parser.add_argument('--lr', type=float, default=1e-5, help='learning rate of the model')
    parser.add_argument('--modalDrop', type=float, default=0.0, help='the prob of modality drop')

    parser.add_argument('--loadUniEnc', action='store_true', help='useing uni-modal finetuned models or not')
    parser.add_argument('--freeze_encoder', action='store_true', help='freeze_encoder or not')

    parser.add_argument('--dataset_name', type=str, default='AVE', choices=['AVE', 'ks', 'CREMA-D'], help='using which dataset')
    parser.add_argument('--clip_model', type=str, default='ViT-B-16', choices=['ViT-L-14', 'ViT-B-16'], help='using which size of model')
    parser.add_argument('--clip_pretraining_data', type=str, default='laion2b_s34b_b88k',
                        choices=['datacomp_xl_s13b_b90k', 'laion2b_s34b_b88k'], help='using which dataset for pre-training')
    parser.add_argument('--train_batch_size', type=int, default=64, help='training batch size')
    parser.add_argument('--test_batch_size', type=int, default=16, help='testing batch size')
    parser.add_argument('--total_epoch', type=int, default=30, help='The total num of training epoches')
    parser.add_argument('--savedir', type=str, default='outputs/mutimodal_ume_lora', help='the localtion of saving checkpoints')

    args = parser.parse_args(argv)
    args.num_classes = get_num_classes(args.dataset_name)

    # The run name encodes the main hyperparameters. Its spelling (including
    # 'freezeeEnc') is kept as-is so names match directories of earlier runs.
    args.model_name = (
        "0921-{prefix}{model}-{dataset}-loadUniEnc-{load}-freezeeEnc-{freeze}"
        "-bs{bs}-lr{lr}-totalepoch{epochs}"
        "-modalDrop{drop}-lora-rank-{rank}-alpha-{alpha}"
    ).format(
        prefix=lora_prefixes[args.lora_modal],
        model=args.model,
        dataset=args.dataset_name,
        load=args.loadUniEnc,
        freeze=args.freeze_encoder,
        bs=args.train_batch_size,
        lr=args.lr,
        epochs=args.total_epoch,
        drop=args.modalDrop,
        rank=args.lora_r,
        alpha=args.lora_alpha,
    )
    args.savedir = os.path.join(args.savedir, args.model_name)
    print(args)

    # Create the directory and detect a previous run in one step, so there is
    # no window between an existence check and the creation. Exiting through
    # SystemExit with a message gives a non-zero status.
    try:
        os.makedirs(args.savedir)
    except FileExistsError:
        raise SystemExit('savedir already here. {}'.format(args.savedir)) from None

    # Record every option next to the checkpoints.
    with open(os.path.join(args.savedir, 'setting.txt'), 'w') as f:
        f.write('------------------- start -------------------\n')
        for arg, value in vars(args).items():
            f.write('{} : {}\n'.format(arg, value))
        f.write('------------------- end -------------------\n')
    return args



if __name__ == "__main__":
    import warnings

    # Script entry point: silence all warnings, then parse the command line
    # and hand the resulting options straight to the training routine.
    warnings.filterwarnings(action="ignore")
    train(get_args())
