from dataset.datasets import AudioDataset
from models import AudioModel, ResNet18_Audio
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
import os
import argparse
from dataset.helper import get_num_classes
from utils import create_logger

def train(args):
    """Train an audio classifier and evaluate it on the test split every epoch.

    Args:
        args: Parsed command-line namespace. Fields read directly here are
            ``dataset_name``, ``train_batch_size``, ``test_batch_size``,
            ``model`` ('vit' or 'resnet'), ``lr``, ``total_epoch`` and
            ``savedir``. The whole namespace is also handed to
            ``create_logger`` and to the model constructor.

    Raises:
        ValueError: If ``args.model`` is neither 'vit' nor 'resnet'.

    Side effects:
        Moves the model and every batch to the CUDA device, logs progress
        through the logger returned by ``create_logger``, and writes one
        ``state_dict`` checkpoint into ``args.savedir`` after each epoch.
    """
    logger = create_logger(args)

    logger.info(" Load Dataset")
    train_dataset = AudioDataset(train=True, dataset=args.dataset_name)
    test_dataset = AudioDataset(train=False, dataset=args.dataset_name)
    train_data = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
    test_data = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False, drop_last=False)

    if args.model == 'vit':
        model = AudioModel(args).cuda()
    elif args.model == 'resnet':
        model = ResNet18_Audio(args).cuda()
    else:
        # Fail with a clear message instead of a NameError on `model` below.
        raise ValueError("Unsupported model: {!r}".format(args.model))

    criterion = nn.CrossEntropyLoss()
    # Honour the --lr option: the run directory name is derived from args.lr,
    # so training must use the same value. 1e-4 (the previously hard-coded
    # rate) is kept only as a fallback for namespaces that carry no `lr`.
    optimizer = torch.optim.AdamW(model.parameters(), lr=getattr(args, 'lr', 1e-4))

    for epoch in range(args.total_epoch):
        # ---- training pass ----
        train_loss = 0
        num_correct = 0
        model.train()
        for i, (audio, label) in enumerate(train_data):
            audio = audio.cuda()
            label = label.cuda()
            output = model(audio)

            loss = criterion(output, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            if i % 10 == 0:
                logger.info("Epoch: {},step: {}, loss: {}".format(epoch, i, loss.item()))
            # Index of the largest logit along dim 1 is the predicted class.
            _, pred = output.max(1)
            num_correct += (pred == label).sum().item()

        train_acc = num_correct / len(train_dataset)
        # The original format string had no placeholder, so the value was dropped.
        logger.info('Training ACC:{:.6f}'.format(train_acc))

        # ---- evaluation pass ----
        eval_loss = 0
        eval_correct = 0
        model.eval()
        # No gradients are needed for evaluation; skipping autograd bookkeeping
        # saves memory and time without changing the computed outputs.
        with torch.no_grad():
            for audio, label in test_data:
                audio = audio.cuda()
                label = label.cuda()
                output = model(audio)

                loss = criterion(output, label)

                eval_loss += loss.item()
                _, pred = output.max(1)
                eval_correct += (pred == label).sum().item()

        eval_acc = eval_correct / len(test_dataset)

        logger.info('**** One epoch has finished ****')
        # Losses are averaged per batch (len of the loader); accuracies are
        # averaged per sample (len of the dataset).
        logger.info('epoch:{}, Train Loss:{:.6f}, Train Acc:{:.6f}, Eval Loss:{:.6f}, Eval Acc:{:.6f}'.format(
            epoch, train_loss / len(train_data), train_acc, eval_loss / len(test_data), eval_acc
        ))
        checkpoint_name = 'model_epoch{}_acc{:.6f}'.format(epoch, eval_acc)
        torch.save(model.state_dict(), os.path.join(args.savedir, checkpoint_name))


def get_args(argv=None):
    """Parse the CLI options, create the run directory and record the settings.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            original behaviour.

    Returns:
        The parsed ``argparse.Namespace`` extended with ``num_classes``
        (from ``get_num_classes(dataset_name)``) and ``model_name``, and with
        ``savedir`` rewritten to the final per-run directory.

    Raises:
        SystemExit: If the run directory already exists, after printing a
            notice, so that an earlier run is never overwritten.

    Side effects:
        Creates ``args.savedir`` and writes every namespace entry to
        ``setting.txt`` inside it.
    """
    parser = argparse.ArgumentParser(description='Training Uni-Audio Model')

    parser.add_argument('--model', type=str, default='vit' , choices=['vit','resnet'],help='which model to use')
    parser.add_argument('--lr', type=float, default=1e-5 ,help='learning rate of the model')
    parser.add_argument('--FreezeEnc', action='store_true', help='freeze enc')
    parser.add_argument('--dataset_name', type=str, default='AVE', choices=['AVE', 'ks', 'CREMA-D'],help='using which dataset')
    parser.add_argument('--train_batch_size', type=int, default=64, help='training batch size')
    parser.add_argument('--test_batch_size', type=int, default=64, help='testing batch size')
    parser.add_argument('--total_epoch', type=int, default=10, help='The total num of training epoches')
    parser.add_argument('--savedir', type=str,default='outputs/audio' ,help='the localtion of saving checkpoints')

    args = parser.parse_args(argv)
    args.num_classes = get_num_classes(args.dataset_name)
    # Run name: <dataset><model>-lr-<lr>; it becomes the last path component.
    args.model_name = args.dataset_name + args.model + '-lr-' + str(args.lr)
    if args.FreezeEnc:
        args.savedir = args.savedir + '_freeze'
        print('freeze enc')
    args.savedir = os.path.join(args.savedir, args.model_name)

    # Create the directory and detect a pre-existing one in a single step, so
    # there is no window between an existence check and the creation.
    try:
        os.makedirs(args.savedir)
    except FileExistsError:
        print('savedir already here.')
        # Raised directly rather than via the site-provided exit() helper,
        # which is not guaranteed to exist in every interpreter setup.
        raise SystemExit from None

    with open(os.path.join(args.savedir, 'setting.txt'), 'w') as f:
        f.write('------------------- start -------------------' + '\n')
        for arg, value in vars(args).items():
            f.write(arg + ' : ' + str(value) + '\n')
        f.write('------------------- end -------------------' + '\n')
    return args



# Script entry point: parse the command-line options, then start training.
if __name__ == "__main__":
    train(get_args())

