import argparse
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from collections import Counter

from utils.utils import setup_seed, weight_init
from dataset.NvGesture import NvGestureDataset
from models.nv_model import exp_b1
from loss.uni_loss import CLLoss_tri
from loss.CA_loss import CA_tri

def get_arguments():
    """Build the command-line parser for training/evaluation and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    def _str2bool(value):
        # Without a converter argparse keeps the raw string, and any non-empty
        # string (including "False") is truthy, so ``--train False`` still trained.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='exp_b1', type=str, help='[exp_b1]')
    parser.add_argument('--dataset', default='nvGesture', type=str, help='nvGesture')
    parser.add_argument('--batch_size', default=4, type=int)
    parser.add_argument('--start_epoch', default=0, type=int)
    parser.add_argument('--epochs', default=150, type=int)
    parser.add_argument('--embed_dim', default=1024, type=int)
    parser.add_argument('--learning_rate', default=0.01, type=float, help='initial learning rate')
    parser.add_argument('--lr_decay_step', default=70, type=int,
                        help='number of epochs between learning-rate decays')
    parser.add_argument('--lr_decay_ratio', default=0.1, type=float, help='decay coefficient')
    # Accepts "--train", "--train True" or "--train False"; the default is train mode.
    parser.add_argument('--train', default=True, type=_str2bool, nargs='?', const=True,
                        help='turn on train mode (pass False to evaluate the best checkpoint)')
    parser.add_argument('--w_uni', default=0.7, type=float, help='weight of uni loss')
    parser.add_argument('--w_fus', default=3.0, type=float, help='weight of fus loss')
    parser.add_argument('--w_ca', default=1.0, type=float, help='weight of CA loss')

    parser.add_argument('--ckpt_path', default='ckpt', type=str, help='path to save trained models')
    parser.add_argument('--logs_path', default='logs', type=str, help='path to save tensorboard logs')
    parser.add_argument('--random_seed', default=0, type=int)
    parser.add_argument('--temperature', default=50, type=float, help='loss temperature')
    parser.add_argument('--gpu', type=int, default=0, help='CUDA device index')
    parser.add_argument('--no_cuda', action='store_true', help='Disable CUDA')

    parser.add_argument('--var', default=0.01, type=float)

    return parser.parse_args()

def get_cls_num_list(dataset, num_classes):
    """Return a tensor with the number of samples per class.

    Args:
        dataset: object whose ``label`` attribute is an iterable of class ids.
        num_classes: number of classes; ids ``0 .. num_classes - 1`` are counted.

    Returns:
        torch.Tensor of length ``num_classes``; entry ``k`` is how many labels
        equal ``k`` (0 for classes that never occur).
    """
    occurrences = Counter(iter(dataset.label))
    counts = []
    for cls_idx in range(num_classes):
        counts.append(occurrences[cls_idx])
    return torch.tensor(counts)

def main():
    """Entry point: build data, model and optimizer, then train or evaluate.

    In train mode (``args.train``), each epoch in ``[start_epoch, epochs)``
    runs ``train_epoch`` and then ``valid`` on the validation split. It logs
    validation accuracy to TensorBoard and saves the checkpoint with the best
    validation accuracy to ``<ckpt_path>/<store_name>/best.pt``.

    Otherwise it loads that checkpoint and prints accuracy on the test split.
    """
    args = get_arguments()
    args.use_cuda = torch.cuda.is_available() and not args.no_cuda
    print(args)

    # NOTE: the 'wd' tag carries lr_decay_step (not weight decay); it is kept
    # unchanged so existing checkpoint/log directory names still resolve.
    args.store_name = '_'.join(
        [args.dataset, 'bs', str(args.batch_size), 'epochs', str(args.epochs),
         'lr', str(args.learning_rate), 'wd', str(args.lr_decay_step),
         'temp', str(args.temperature)])

    print(args.store_name)
    setup_seed(args.random_seed)

    if args.dataset != 'nvGesture':
        raise NotImplementedError('Incorrect dataset name {}'.format(args.dataset))
    args.num_classes = 25

    device = torch.device('cuda:' + str(args.gpu) if args.use_cuda else 'cpu')

    train_dataset = NvGestureDataset(mode='train')
    val_dataset = NvGestureDataset(mode='val')
    test_dataset = NvGestureDataset(mode='test')

    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4,
                                  shuffle=True, pin_memory=False)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=4,
                                shuffle=False, pin_memory=False)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=4,
                                 shuffle=False, pin_memory=False)

    args.loader_count = len(train_dataloader)

    # Normalized class frequencies of the training split, exposed to the model via args.
    cls_num_list = get_cls_num_list(train_dataset, args.num_classes)
    args.cls_num_list = cls_num_list / torch.sum(cls_num_list)

    if args.model == 'exp_b1':
        model = exp_b1(args)
    else:
        raise NotImplementedError('Incorrect model name {}'.format(args.model))

    # Use the selected device rather than .cuda() so --no_cuda / CPU-only hosts work.
    model = model.to(device)

    parameters = list(model.parameters())
    optimizer = optim.SGD(parameters, lr=args.learning_rate, momentum=0.9, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step, args.lr_decay_ratio)

    if args.train:
        best_acc = 0.0

        args.save_path = os.path.join(args.ckpt_path, args.store_name)
        writer_path = os.path.join(args.logs_path, args.store_name)

        # makedirs also creates the parent ckpt/ and logs/ directories on first run.
        os.makedirs(writer_path, exist_ok=True)
        os.makedirs(args.save_path, exist_ok=True)
        writer = SummaryWriter(log_dir=writer_path)

        try:
            for epoch in range(args.start_epoch, args.epochs):
                print('Epoch: {}: '.format(epoch))

                acc, batch_loss = train_epoch(args, epoch, model, device, train_dataloader,
                                              optimizer, scheduler, args.temperature)
                val_acc, batch_loss_val = valid(args, model, device, val_dataloader, epoch)

                print('epoch: ', epoch, 'acc: ', acc)

                writer.add_scalar('Evaluation/ Total accuracy', val_acc, epoch)

                if val_acc > best_acc:
                    best_acc = val_acc
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'epoch': epoch
                        },
                        os.path.join(args.save_path, 'best.pt')
                    )
        finally:
            # Flush pending TensorBoard events even if training is interrupted.
            writer.close()

    else:
        load_path = os.path.join(args.ckpt_path, args.store_name, 'best.pt')
        # map_location lets a checkpoint saved on GPU be evaluated on CPU (and vice versa).
        load_dict = torch.load(load_path, map_location=device)

        model.load_state_dict(load_dict['model'])
        optimizer.load_state_dict(load_dict['optimizer'])
        scheduler.load_state_dict(load_dict['scheduler'])

        # `epoch` was previously undefined here; forward the epoch the checkpoint was saved at.
        ckpt_epoch = load_dict.get('epoch', args.start_epoch)
        test_acc, _ = valid(args, model, device, test_dataloader, ckpt_epoch)
        print('acc: ', test_acc)
            
        

def train_epoch(args, epoch, model, device,
                dataloader, optimizer, scheduler, temperature):
    """Train ``model`` for one epoch, then step the LR scheduler once.

    For each batch the model is called with ``train=True``. It must return
    ``(r, o, d, proto, out_r, out_o, out_d, out_f, loss_fus)``.

    The optimized loss is::

        (CE(out_r) + CE(out_o) + CE(out_d)) / 3
        + w_uni * CLLoss_tri(r, o, d, proto, label)
        + w_ca  * CA_tri(proto, classifier_{r,o,d}.weight)
        + w_fus * loss_fus

    Args:
        args: parsed arguments. Reads ``w_uni``, ``w_ca`` and ``w_fus``, and is
            passed to ``CLLoss_tri``.
        epoch: current epoch index, forwarded to the model.
        model: module exposing ``classifier_r``, ``classifier_o`` and
            ``classifier_d``, each with a ``weight`` attribute.
        device: device the batch tensors are moved to.
        dataloader: iterable of ``(rgb, of, depth, label)`` batches.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler stepped once at the end of the epoch.
        temperature: temperature passed to ``CLLoss_tri``.

    Returns:
        ``(acc, loss)``. ``acc`` is the accuracy of the fused logits ``out_f``
        over all samples seen this epoch. ``loss`` is the mean total loss per
        batch. Both are 0.0 when ``dataloader`` yields no batches.
    """
    criterion = nn.CrossEntropyLoss()
    uni_loss = CLLoss_tri(args, temperature)
    ca_loss = CA_tri()

    model.train()

    print("Start training ... ")

    correct = 0
    total = 0
    running_loss = 0.0
    num_batches = 0

    for rgb, of, depth, label in dataloader:
        rgb = rgb.to(device)
        of = of.to(device)
        depth = depth.to(device)
        label = label.to(device)
        B = label.shape[0]
        optimizer.zero_grad()

        r, o, d, proto, out_r, out_o, out_d, out_f, loss_fus = model(
            rgb, of, depth, label, B, epoch, train=True)

        loss_uni = uni_loss(r, o, d, proto, label)
        loss_r = criterion(out_r, label)
        loss_o = criterion(out_o, label)
        loss_d = criterion(out_d, label)

        loss_ca = ca_loss(proto, model.classifier_r.weight,
                          model.classifier_o.weight, model.classifier_d.weight)

        loss_ce = (loss_r + loss_o + loss_d) / 3

        loss = loss_ce + args.w_uni * loss_uni + args.w_ca * loss_ca + args.w_fus * loss_fus

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

        # Softmax is monotonic, so argmax over raw logits gives the same class.
        # A single vectorized comparison replaces per-sample GPU->CPU copies.
        correct += (out_f.argmax(dim=1) == label).sum().item()
        total += B

    scheduler.step()

    acc = correct / total if total else 0.0
    avg_loss = running_loss / num_batches if num_batches else 0.0
    return acc, avg_loss


def valid(args, model, device, dataloader, epoch):
    """Evaluate ``model`` on ``dataloader`` in eval mode without gradients.

    For each batch the model is called with ``label=None`` and ``train=False``.
    It must return 8 values, the last four being the logits
    ``(out_r, out_o, out_d, out_f)``.

    Args:
        args: parsed arguments. Not read directly here; kept for interface
            compatibility.
        model: model to evaluate; it is switched to ``eval()`` mode.
        device: device the batch tensors are moved to.
        dataloader: iterable of ``(rgb, of, depth, label)`` batches.
        epoch: epoch index forwarded to the model.

    Returns:
        ``(acc, loss)``.

        * ``acc`` is the accuracy of the fused logits ``out_f``.
        * ``loss`` is the mean per batch of ``CE(out_r) + CE(out_o) + CE(out_d)``.
          It is summed, unlike the /3 average used in ``train_epoch``.

        Both are 0.0 when ``dataloader`` yields no batches.
    """
    criterion = nn.CrossEntropyLoss()
    model.eval()

    correct = 0
    total = 0
    running_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for rgb, of, depth, label in dataloader:
            rgb = rgb.to(device)
            of = of.to(device)
            depth = depth.to(device)
            label = label.to(device)
            B = label.shape[0]

            _, _, _, _, out_r, out_o, out_d, out_f = model(
                rgb, of, depth, None, B, epoch, train=False)

            loss_val = (criterion(out_r, label) + criterion(out_o, label)
                        + criterion(out_d, label))
            running_loss += loss_val.item()
            num_batches += 1

            # Softmax is monotonic, so argmax over raw logits gives the same class.
            correct += (out_f.argmax(dim=1) == label).sum().item()
            total += B

    acc = correct / total if total else 0.0
    avg_loss = running_loss / num_batches if num_batches else 0.0
    return acc, avg_loss




# Run the training/evaluation entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()