import argparse
import os
import sys

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from dataset.datasets import VideoDataset
from dataset.helper import get_num_classes
from utils import create_logger
from models import VideoModel, ResNet18_Video



def train(args):
    # dataset_name = 'ks' #[AVE, ks, CREMA-D]
    # num_classes = get_num_classes(dataset_name) #[28, 32, 6]
    # lr=1e-5 #1e-5

    # train_batch_size = 64
    # test_batch_size = 64

    # total_epoch = 10
    logger = create_logger(args)
    logger.info(" Load Dataset:{}".format(str(args.dataset_name)))
    train_dataset = VideoDataset(train=True, dataset=args.dataset_name)
    test_dataset = VideoDataset(train=False, dataset=args.dataset_name)
    train_data = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
    test_data = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False, drop_last=False)

    # logger.info(" Load Model, Model Size:{}, Pre-Training Data:{}".format( args.clip_model, args.clip_pretraining_data))
    logger.info(" Loading Model:{}".format(args.model))
    if args.model == 'vit':
        model = VideoModel(args).cuda()
    elif args.model == 'resnet':
        model = ResNet18_Video(args).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    for epoch in range(args.total_epoch):
        train_loss = 0
        num_correct = 0
        model.train()
        for i, (imgs, label) in enumerate(train_data):
        
            imgs = imgs.cuda()
            label = label.cuda()
            output = model(imgs)

            loss = criterion(output, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()

            if i%10==0:
                logger.info("Epoch: {},step: {}, loss: {}".format(epoch, i, loss.item()))
            _, pred = output.max(1)
            num_correct += (pred == label).sum().item()
        
        logger.info('Training ACC:{}'.format(str(num_correct/len(train_dataset))))

        eval_loss = 0
        eval_correct = 0
        model.eval()
        for i, (imgs, label) in enumerate(test_data):
            imgs = imgs.cuda()
            label = label.cuda()
            with torch.no_grad():
                output = model(imgs)

            loss = criterion(output, label)

            eval_loss += loss.item()
            _, pred = output.max(1)
            eval_correct += (pred==label).sum().item()

        logger.info('**** One epoch has finished ****')
        logger.info('epoch:{}, Train Loss:{:.6f}, Train Acc:{:.6f}, Eval Loss:{:.6f}, Eval Acc:{:.6f}'.format(
            epoch, train_loss/len(train_data), num_correct/len(train_dataset), eval_loss/len(test_data), eval_correct/len(test_dataset)
        ))
        torch.save(model.state_dict(), os.path.join(args.savedir, 'model_epoch{}_acc{:.6f}'.format(epoch, eval_correct/len(test_dataset))))


def get_args():
    parser = argparse.ArgumentParser(description='Training Uni-Video Model')
    parser.add_argument('--lr', type=float, default=1e-5 ,help='learning rate of the model')
    parser.add_argument('--model', type=str, default='vit' , choices=['vit','resnet'],help='which model to use')
    parser.add_argument('--FreezeEnc', action='store_true', help='freeze enc')

    parser.add_argument('--dataset_name', type=str, default='AVE', choices=['AVE', 'ks', 'CREMA-D'],help='using which dataset')
    parser.add_argument('--clip_model', type=str, default='ViT-B-16', choices=['ViT-L-14', 'ViT-B-16'],help='using which size of model')
    parser.add_argument('--clip_pretraining_data', type=str, default='laion2b_s34b_b88k', 
                        choices=['datacomp_xl_s13b_b90k', 'laion2b_s34b_b88k'],help='using which dataset for pre-training')
    parser.add_argument('--train_batch_size', type=int, default=64, help='training batch size')
    parser.add_argument('--test_batch_size', type=int, default=64, help='testing batch size')
    parser.add_argument('--total_epoch', type=int, default=10, help='The total num of training epoches')
    parser.add_argument('--savedir', type=str,default='outputs/visual' ,help='the localtion of saving checkpoints')

    args = parser.parse_args()
    args.num_classes = get_num_classes(args.dataset_name)
    if args.clip_model == 'ViT-L-14':
        args.clip_pretraining_data = 'datacomp_xl_s13b_b90k'
    args.model_name = args.dataset_name + '-' + args.clip_model + '-' + args.clip_pretraining_data + '-lr-'+ str(args.lr) + '-bs-' +  str(args.train_batch_size)
    if args.FreezeEnc:
        args.savedir = args.savedir + '_freeze'
        print('freeze enc')
    args.savedir = os.path.join(args.savedir, args.model_name)
    
    if os.path.exists(args.savedir):
        print('savedir already here.')
        exit()
    else:
        os.makedirs(args.savedir)

    argsDict = args.__dict__
    with open(os.path.join(args.savedir, 'setting.txt'), 'w') as f:
        f.writelines('------------------- start -------------------' + '\n')
        for arg, value in argsDict.items():
            f.writelines(arg + ' : ' + str(value) + '\n')
        f.writelines('------------------- end -------------------' + '\n')
    return args



if __name__ == "__main__":
    # Script entry point: parse CLI / prepare the output dir, then train.
    train(get_args())
