import argparse
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from collections import Counter

from dataset.Dataset import CramedDataset
from dataset.Dataset import KineticDataset
from utils.utils import setup_seed, weight_init
from models.model import  exp_b1
from loss.uni_loss import CLLoss
from loss.CA_loss import CA


def _str2bool(value):
    """Convert a command-line token such as ``true``/``False``/``0`` to a bool.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def get_arguments():
    """Build the command-line parser and return the parsed arguments.

    Returns:
        argparse.Namespace: one attribute per flag declared below, parsed
        from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()

    # Model / data selection.
    parser.add_argument('--model', default='exp_b1', type=str)
    parser.add_argument('--audio_path', default='C:/Users/user/Downloads/CREMA-D/AudioWAV', type=str)
    parser.add_argument('--visual_path', default='C:/Users/user/Downloads/CREMA-D', type=str)
    parser.add_argument('--dataset', default='KineticSound', type=str,
                        help='CREMAD, KineticSound')

    # Training schedule and model size.
    parser.add_argument('--batch_size', default=16, type=int)
    parser.add_argument('--start_epoch', default=0, type=int)
    parser.add_argument('--epochs', default=150, type=int)
    parser.add_argument('--embed_dim', default=512, type=int)
    parser.add_argument('--fps', default=1, type=int, help='Extract how many frames in a second')
    parser.add_argument('--num_frame', default=3, type=int, help='use how many frames for train')

    # Loss weights.
    parser.add_argument('--w_uni', default=0.7, type=float, help='weight of uni loss')
    parser.add_argument('--w_fus', default=3, type=float, help='weight of fus loss')
    parser.add_argument('--w_ca', default=1, type=float, help='weight of CA loss')

    # Optimisation.
    parser.add_argument('--learning_rate', default=0.01, type=float, help='initial learning rate')
    parser.add_argument('--lr_decay_step', default=70, type=int, help='where learning rate decays')
    parser.add_argument('--lr_decay_ratio', default=0.1, type=float, help='decay coefficient')

    # Output locations.
    parser.add_argument('--ckpt_path', default='ckpt', type=str, help='path to save trained models')
    parser.add_argument('--logs_path', default='logs', type=str, help='path to save tensorboard logs')

    # Without an explicit converter argparse stores the raw string, and any
    # non-empty string (including "False") is truthy, so evaluation mode could
    # never be selected from the command line.
    parser.add_argument('--train', default=True, type=_str2bool, help='turn on train mode')

    parser.add_argument('--random_seed', default=0, type=int)
    parser.add_argument('--temperature', default=50, type=float, help='loss temperature')  # CremaD & Kinetic: 50
    parser.add_argument('--gpu', type=int, default=0)  # gpu
    parser.add_argument('--no_cuda', action='store_true', help='Disable CUDA')
    parser.add_argument('--var', default=0.01, type=float)

    return parser.parse_args()

def get_cls_num_list(dataset, num_classes):
    """Count how many samples of each class index appear in ``dataset.label``.

    Args:
        dataset: object exposing an iterable ``label`` attribute.
        num_classes: number of class indices (``0 .. num_classes - 1``) to report.

    Returns:
        torch.Tensor: entry ``i`` is the number of labels equal to ``i``;
        labels outside ``range(num_classes)`` are ignored.
    """
    frequencies = Counter()
    frequencies.update(iter(dataset.label))

    per_class = []
    for class_idx in range(num_classes):
        # Counter returns 0 for class indices that never occur.
        per_class.append(frequencies[class_idx])
    return torch.tensor(per_class)

def _load_checkpoint(load_path, device, model, optimizer=None, scheduler=None):
    """Restore model (and optionally optimizer/scheduler) state from ``load_path``.

    Tensors are mapped onto ``device`` so a checkpoint written on a GPU can be
    read on a CPU-only machine. Returns the raw checkpoint dictionary.
    """
    load_dict = torch.load(load_path, map_location=device)
    model.load_state_dict(load_dict['model'])
    if optimizer is not None:
        optimizer.load_state_dict(load_dict['optimizer'])
    if scheduler is not None:
        scheduler.load_state_dict(load_dict['scheduler'])
    return load_dict


def main():
    """Parse arguments, build data/model/optimizer, then train or evaluate.

    In train mode (``args.train`` truthy) the model is trained from
    ``args.start_epoch`` to ``args.epochs``; after every epoch it is validated
    and the checkpoint with the best validation accuracy is written to
    ``<ckpt_path>/<store_name>/best.pt``. Otherwise ``best.pt`` is loaded and
    the accuracy on the test split is printed.

    Raises:
        NotImplementedError: for an unknown ``--dataset`` or ``--model``.
    """
    args = get_arguments()
    args.use_cuda = torch.cuda.is_available() and not args.no_cuda
    print(args)

    # NOTE: the 'wd' tag is followed by lr_decay_step; it is kept as-is so that
    # existing checkpoint/log directories keep resolving to the same name.
    args.store_name = '_'.join(
        [args.dataset, 'bs', str(args.batch_size), 'epochs', str(args.epochs),
         'lr', str(args.learning_rate), 'wd', str(args.lr_decay_step),
         'temp', str(args.temperature)])

    print(args.store_name)
    setup_seed(args.random_seed)

    # Single source of truth for class count and dataset class per dataset name.
    dataset_registry = {
        'CREMAD': (6, CramedDataset),
        'KineticSound': (31, KineticDataset),
    }
    if args.dataset not in dataset_registry:
        raise NotImplementedError('Incorrect dataset name {}'.format(args.dataset))
    n_classes, dataset_cls = dataset_registry[args.dataset]

    args.num_classes = n_classes
    device = torch.device('cuda:' + str(args.gpu) if args.use_cuda else 'cpu')
    args.device = device

    train_dataset = dataset_cls(args, mode='train')
    val_dataset = dataset_cls(args, mode='val')
    test_dataset = dataset_cls(args, mode='test')

    # Normalised class frequencies of the training split, exposed on args.
    cls_num_list = get_cls_num_list(train_dataset, args.num_classes)
    args.cls_num_list = cls_num_list / torch.sum(cls_num_list)

    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4,
                                  shuffle=True, pin_memory=False)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=4,
                                shuffle=False, pin_memory=False)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=4,
                                 shuffle=False, pin_memory=False)

    if args.model == 'exp_b1':
        model = exp_b1(args)
    else:
        raise NotImplementedError('Incorrect model name {}'.format(args.model))

    # Honour --no_cuda / --gpu instead of unconditionally calling .cuda().
    model = model.to(device)

    parameters = list(model.parameters())
    optimizer = optim.SGD(parameters, lr=args.learning_rate, momentum=0.9, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step, args.lr_decay_ratio)

    best_ckpt_path = os.path.join(args.ckpt_path, args.store_name, 'best.pt')

    if args.train:
        best_acc = 0.0
        if args.start_epoch != 0:
            # Resume: restore all training state from the saved best checkpoint.
            resumed = _load_checkpoint(best_ckpt_path, device, model, optimizer, scheduler)
            # Older checkpoints have no 'best_acc'; fall back to 0.0 for them.
            best_acc = resumed.get('best_acc', 0.0)

        args.save_path = os.path.join(args.ckpt_path, args.store_name)
        writer_path = os.path.join(args.logs_path, args.store_name)

        # makedirs also creates the missing parent ('logs' / 'ckpt') directories.
        os.makedirs(writer_path, exist_ok=True)
        os.makedirs(args.save_path, exist_ok=True)
        writer = SummaryWriter(log_dir=writer_path)

        try:
            for epoch in range(args.start_epoch, args.epochs):
                print('Epoch: {}: '.format(epoch))

                acc, _acc_a, _acc_v, _batch_loss = train_epoch(
                    args, epoch, model, device, train_dataloader,
                    optimizer, scheduler, args.temperature)
                val_acc, _val_acc_a, _val_acc_v, _batch_loss_val = valid(
                    args, model, device, val_dataloader, epoch)

                print('epoch: ', epoch, 'acc: ', acc)

                writer.add_scalar('Evaluation/ Total accuracy', val_acc, epoch)

                if val_acc > best_acc:
                    best_acc = val_acc
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'epoch': epoch,
                            # Stored so a resumed run does not overwrite
                            # best.pt with a worse model.
                            'best_acc': best_acc,
                        },
                        os.path.join(args.save_path, 'best.pt')
                    )
        finally:
            # Flush pending TensorBoard events even if training is interrupted.
            writer.close()
    else:
        # Evaluation only needs the model weights.
        load_dict = _load_checkpoint(best_ckpt_path, device, model)

        # valid() forwards an epoch index to the model; use the epoch at which
        # the checkpoint was written (falls back to --start_epoch if absent).
        eval_epoch = load_dict.get('epoch', args.start_epoch)

        test_acc, _, _, _ = valid(args, model, device, test_dataloader, eval_epoch)
        print('acc: ', test_acc)
            

def train_epoch(args, epoch, model, device,
                dataloader, optimizer, scheduler, temperature):
    """Train ``model`` for one pass over ``dataloader`` and step the scheduler once.

    Args:
        args: parsed arguments; reads ``dataset``, ``w_uni``, ``w_ca`` and
            ``w_fus`` and is forwarded to ``CLLoss``.
        epoch: current epoch index, forwarded to the model.
        model: network called as
            ``model(spec, image, label, B, epoch, train=True)`` and exposing
            ``classifier_a.weight`` / ``classifier_v.weight``.
        device: device the batch tensors are moved to.
        dataloader: iterable of ``(spec, image, label)`` batches.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler stepped once at the end of the epoch.
        temperature: temperature forwarded to ``CLLoss``.

    Returns:
        tuple: ``(acc, acc_a, acc_v, mean_loss)`` - accuracy of the ``out_f``,
        ``out_a`` and ``out_v`` heads over all samples, and the total loss
        averaged over batches. All zeros when the dataloader yields no samples.

    Raises:
        NotImplementedError: if ``args.dataset`` is not a supported dataset.
    """
    criterion = nn.CrossEntropyLoss()
    uni_loss = CLLoss(args, temperature)
    ca_loss = CA()

    if args.dataset not in ('CREMAD', 'KineticSound'):
        raise NotImplementedError('Incorrect dataset name {}'.format(args.dataset))

    model.train()

    print("Start training ... ")

    n_seen = 0
    # Correct-prediction counters; they become 0-dim tensors on `device` after
    # the first batch so no host synchronisation is needed inside the loop.
    hits_f = hits_a = hits_v = 0
    loss_sum = 0.0

    for spec, image, label in dataloader:
        spec = spec.to(device)
        image = image.to(device)
        label = label.to(device)
        B = label.shape[0]
        optimizer.zero_grad()

        a, v, proto, out_a, out_v, out_f, loss_fus = model(
            spec.unsqueeze(1).float(), image.float(), label, B, epoch, train=True)

        loss_uni = uni_loss(a, v, proto, label)

        loss_a = criterion(out_a, label)
        loss_v = criterion(out_v, label)
        loss_ca = ca_loss(proto, model.classifier_a.weight, model.classifier_v.weight)
        loss_ce = (loss_a + loss_v) / 2

        loss = loss_ce + args.w_uni * loss_uni + args.w_ca * loss_ca + args.w_fus * loss_fus

        loss.backward()
        optimizer.step()

        loss_sum += loss.item()

        # argmax over logits equals argmax over softmax(logits), so the
        # softmax and the per-sample CPU round-trips are not needed.
        with torch.no_grad():
            hits_f = hits_f + (out_f.argmax(dim=1) == label).sum()
            hits_a = hits_a + (out_a.argmax(dim=1) == label).sum()
            hits_v = hits_v + (out_v.argmax(dim=1) == label).sum()
        n_seen += B

    scheduler.step()

    if n_seen == 0:
        # Empty dataloader: avoid dividing by zero.
        return 0.0, 0.0, 0.0, 0.0

    return (int(hits_f) / n_seen, int(hits_a) / n_seen, int(hits_v) / n_seen,
            loss_sum / len(dataloader))


def valid(args, model, device, dataloader, epoch):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        args: parsed arguments; only ``args.dataset`` is read here.
        model: network called as
            ``model(spec, image, None, B, epoch, train=False)`` and returning
            six values, the last three being the logits ``out1``, ``out2``
            and ``out``.
        device: device the batch tensors are moved to.
        dataloader: sized iterable of ``(spec, image, label)`` batches.
        epoch: epoch index forwarded to the model.

    Returns:
        tuple: ``(acc, acc1, acc2, mean_loss)`` - accuracy of ``out``, ``out1``
        and ``out2`` over all samples, and the per-batch mean of
        ``CE(out1) + CE(out2)``. All zeros when the dataloader yields no
        samples.

    Raises:
        NotImplementedError: if ``args.dataset`` is not a supported dataset.
    """
    criterion = nn.CrossEntropyLoss()

    if args.dataset not in ('CREMAD', 'KineticSound'):
        raise NotImplementedError('Incorrect dataset name {}'.format(args.dataset))

    n_seen = 0
    # Correct-prediction counters; they become 0-dim tensors on `device` after
    # the first batch so no host synchronisation is needed inside the loop.
    hits = hits1 = hits2 = 0
    loss_sum = 0.0

    with torch.no_grad():
        model.eval()

        for spec, image, label in dataloader:
            spec = spec.to(device)
            image = image.to(device)
            label = label.to(device)
            B = label.shape[0]

            _, _, _, out1, out2, out = model(
                spec.unsqueeze(1).float(), image.float(), None, B, epoch, train=False)

            loss_sum += (criterion(out1, label) + criterion(out2, label)).item()

            # argmax over logits equals argmax over softmax(logits), so the
            # softmax and the per-sample CPU round-trips are not needed.
            hits = hits + (out.argmax(dim=1) == label).sum()
            hits1 = hits1 + (out1.argmax(dim=1) == label).sum()
            hits2 = hits2 + (out2.argmax(dim=1) == label).sum()
            n_seen += B

    if n_seen == 0:
        # Empty dataloader: avoid dividing by zero.
        return 0.0, 0.0, 0.0, 0.0

    return (int(hits) / n_seen, int(hits1) / n_seen, int(hits2) / n_seen,
            loss_sum / len(dataloader))




# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()