import os
import time
import numpy as np
import torch
from config import params
from torch import nn, optim
from torch.utils.data import DataLoader
if params["data"]==0:
    from lib.dataloader_uvsd import get_test_loader, get_train_loader, get_val_loader
elif params["data"]==1:
    from lib.dataloader_rsl import get_test_loader, get_train_loader, get_val_loader
from sklearn.metrics import precision_score, recall_score, f1_score,classification_report,confusion_matrix
import torch.backends.cudnn as cudnn
# from lib.dataset import VideoDataset
from lib import model
from tensorboardX import SummaryWriter

import os
# os.environ['CUDA_VISIBLE_DEVICES']='4,5,6,7'

class AverageMeter:
    """Keep the latest value together with a running weighted sum, count and mean.

    ``update(val, n)`` treats ``val`` as the mean over ``n`` samples, so
    ``avg`` is the sample-weighted average of everything recorded since the
    last ``reset()``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running average, running sum and sample count."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new samples and refresh ``avg``."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

def metrics(y_true, y_pred):
    """Compute accuracy and macro-averaged F1 / precision / recall.

    Args:
        y_true: ground-truth class labels (array-like).
        y_pred: predicted class labels, same length as ``y_true``.

    Returns:
        ``(acc, f1, precision, recall)``. ``acc`` is the fraction of positions
        where ``y_true == y_pred``. The other three are sklearn scores computed
        with ``average='macro'``.

    Side effect: prints ``confusion_matrix(y_true, y_pred)`` to stdout.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Vectorised mean of the element-wise match mask. The builtin sum() used
    # before iterated the array in Python, and for plain lists ``==`` gives
    # one bool instead of a mask, so the old form failed on list input.
    acc = np.mean(y_true == y_pred)
    f1 = f1_score(y_true, y_pred, average='macro')
    precision = precision_score(y_true, y_pred, average='macro')
    recall = recall_score(y_true, y_pred, average='macro')
    print(confusion_matrix(y_true, y_pred))
    return acc, f1, precision, recall

def train(model, train_dataloader, epoch, criterion, optimizer, writer, best_train_accuracy, best_train_f1):
    """Run one training epoch over ``train_dataloader`` and log epoch metrics.

    Each batch ``res`` is indexed positionally. ``res[0]`` is the label tensor.
    ``res[1]``, ``res[2]`` and ``res[3]`` are passed to the model as
    ``model(videos, emo, gpt)``. All four are moved to CUDA.

    Args:
        model: network being trained; switched to train mode and moved to CUDA.
        train_dataloader: iterable of batches laid out as above.
        epoch: epoch index, used in printouts and as the TensorBoard step.
        criterion: loss called as ``criterion(outputs, label)``.
        optimizer: optimizer stepped once per batch.
        writer: TensorBoard writer receiving ``train_loss_epoch``,
            ``train_acc_epoch`` and ``train_f1_epoch``.
        best_train_accuracy: best epoch accuracy seen so far.
        best_train_f1: macro-F1 recorded alongside ``best_train_accuracy``.

    Returns:
        ``(best_train_accuracy, best_train_f1)``. Both are replaced by this
        epoch's values when either this epoch's accuracy or its F1 is higher.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    # Collect per-batch arrays and concatenate only when metrics are needed,
    # instead of re-allocating the full accumulator on every step.
    pred_chunks, target_chunks = [], []

    model.train()
    # Moving the model is loop-invariant, so do it once rather than per batch.
    model.cuda()
    end = time.time()
    for step, res in enumerate(train_dataloader):
        data_time.update(time.time() - end)
        label = res[0].cuda()
        videos = res[1].cuda()
        emo = res[2].cuda()
        gpt = res[3].cuda()
        outputs = model(videos, emo, gpt)
        loss = criterion(outputs, label)

        pred_chunks.append(torch.argmax(outputs, dim=1).cpu().numpy())
        target_chunks.append(label.cpu().numpy())
        losses.update(loss.item(), videos.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        end = time.time()

        if (step + 1) % params['display'] == 0:
            print('-------------------------------------------------------')
            for param in optimizer.param_groups:
                print('lr: ', param['lr'])
            print('Epoch: [{0}][{1}/{2}]'.format(epoch, step + 1, len(train_dataloader)))
            print('data_time: {data_time:.3f}, batch time: {batch_time:.3f}'.format(
                data_time=data_time.val,
                batch_time=batch_time.val))
            print('loss: {loss:.5f}'.format(loss=losses.avg))
            # Running metrics over all batches seen so far in this epoch.
            acc, f1, prec, recall = metrics(np.concatenate(target_chunks), np.concatenate(pred_chunks))
            print('acc:{acc:.5f}, f1:{f1:.5f},precision:{prec:.5f},recall:{recall:.5f}'.format(
                acc=acc, f1=f1, prec=prec, recall=recall))

    if not pred_chunks:
        # With no batches there is nothing to score (the old code crashed here).
        print('Epoch {0}: training dataloader yielded no batches'.format(epoch))
        return best_train_accuracy, best_train_f1

    # "Best" is judged on the full epoch, not on running partial-epoch numbers.
    acc, f1, prec, recall = metrics(np.concatenate(target_chunks), np.concatenate(pred_chunks))
    if acc > best_train_accuracy or f1 > best_train_f1:
        print("Best result")
        best_train_accuracy = acc
        best_train_f1 = f1
    writer.add_scalar('train_loss_epoch', losses.avg, epoch)
    writer.add_scalar('train_acc_epoch', acc, epoch)
    writer.add_scalar('train_f1_epoch', f1, epoch)
    return best_train_accuracy, best_train_f1

def validation(model, val_dataloader, epoch, criterion, optimizer, writer, best_test_accuracy, best_test_f1):
    """Evaluate ``model`` on ``val_dataloader`` under ``torch.no_grad()`` and log epoch metrics.

    Batches are indexed the same way as in training. ``res[0]`` is the label;
    ``res[1]``, ``res[2]`` and ``res[3]`` are passed as ``model(videos, emo, gpt)``.

    Args:
        model: network to evaluate; switched to eval mode and moved to CUDA.
        val_dataloader: iterable of batches laid out as above.
        epoch: epoch index, used in printouts and as the TensorBoard step.
        criterion: loss called as ``criterion(outputs, label)``.
        optimizer: unused here; kept so the signature mirrors ``train``.
        writer: TensorBoard writer receiving ``val_loss_epoch``,
            ``val_acc_epoch`` and ``val_f1_epoch``.
        best_test_accuracy: best validation accuracy seen so far.
        best_test_f1: macro-F1 recorded alongside ``best_test_accuracy``.

    Returns:
        ``(best_test_accuracy, best_test_f1)``. Both are replaced by this
        epoch's values when either this epoch's accuracy or its F1 is higher.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    pred_chunks, target_chunks = [], []
    model.eval()
    # Loop-invariant, so do it once rather than per batch.
    model.cuda()

    end = time.time()
    with torch.no_grad():
        for step, res in enumerate(val_dataloader):
            data_time.update(time.time() - end)
            label = res[0].cuda()
            videos = res[1].cuda()
            emo = res[2].cuda()
            gpt = res[3].cuda()
            outputs = model(videos, emo, gpt)
            loss = criterion(outputs, label)

            pred_chunks.append(torch.argmax(outputs, dim=1).cpu().numpy())
            target_chunks.append(label.cpu().numpy())
            losses.update(loss.item(), videos.size(0))
            batch_time.update(time.time() - end)
            end = time.time()
            if (step + 1) % params['display'] == 0:
                print('----validation----')
                print('Epoch: [{0}][{1}/{2}]'.format(epoch, step + 1, len(val_dataloader)))
                print('data_time: {data_time:.3f}, batch time: {batch_time:.3f}'.format(
                    data_time=data_time.val,
                    batch_time=batch_time.val))
                print('loss: {loss:.5f}'.format(loss=losses.avg))
                acc, f1, prec, recall = metrics(np.concatenate(target_chunks), np.concatenate(pred_chunks))
                print('acc:{acc:.5f}, f1:{f1:.5f},precision:{prec:.5f},recall:{recall:.5f}'.format(
                    acc=acc, f1=f1, prec=prec, recall=recall))

    if not pred_chunks:
        # With no batches, acc/f1 were never defined (the old code raised NameError).
        print('Epoch {0}: validation dataloader yielded no batches'.format(epoch))
        return best_test_accuracy, best_test_f1

    acc, f1, prec, recall = metrics(np.concatenate(target_chunks), np.concatenate(pred_chunks))
    if acc > best_test_accuracy or f1 > best_test_f1:
        print("Best result")
        best_test_accuracy = acc
        best_test_f1 = f1
    writer.add_scalar('val_loss_epoch', losses.avg, epoch)
    # Log under val_* tags; these previously wrote to train_* and clobbered
    # the training curves.
    writer.add_scalar('val_acc_epoch', acc, epoch)
    writer.add_scalar('val_f1_epoch', f1, epoch)
    return best_test_accuracy, best_test_f1

def main():
    """Train ``model.Ours`` with SGD, validating every other epoch and saving checkpoints.

    All settings come from the module-level ``params`` config. TensorBoard logs
    go to ``<params['log']><suffix>/<timestamp>``. Checkpoints go to
    ``<params['save_path']><suffix>/<timestamp>``. In both, ``<suffix>`` is
    ``_<data>_<learning_rate>_<batch_size>_face``.
    """
    cudnn.benchmark = False
    cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    run_suffix = "_" + str(params["data"]) + "_" + str(params['learning_rate']) + "_" + str(params['batch_size']) + "_face"

    logdir = os.path.join(params['log'] + run_suffix, cur_time)
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(log_dir=logdir)

    best_train_accuracy = 0
    best_train_f1 = 0
    best_test_accuracy = 0
    best_test_f1 = 0

    def _make_loader(factory):
        # The train and val loaders take identical positional arguments.
        # A fresh [224, 224] list is built on each call, as in the original
        # two literal calls.
        return factory(
            params,
            params['dataset'],
            [224, 224],
            params['batch_size'],
            params['num_workers'],
            params['clip_len'],
            64,
            False,
            False,
            True,
            False,
        )

    train_dataloader = _make_loader(get_train_loader)
    val_dataloader = _make_loader(get_val_loader)

    print("load model")
    # The local must not be named ``model``. Assigning that name anywhere in
    # main() makes it local for the whole function, so ``model.Ours()`` would
    # raise UnboundLocalError instead of reaching the imported lib.model.
    net = model.Ours()

    if params['pretrained'] is not None:
        pretrained_dict = torch.load(params['pretrained'], map_location='cpu')
        try:
            model_dict = net.module.state_dict()
        except AttributeError:
            model_dict = net.state_dict()
        # Partial load: keep only checkpoint entries whose names exist in this model.
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        print("load pretrain model")
        model_dict.update(pretrained_dict)
        net.load_state_dict(model_dict)

    # NOTE(review): hard-coded GPU selection that overrides the caller's
    # environment. Presumably it only takes effect if CUDA has not yet been
    # initialised in this process. The ids in params['gpu'] would then index
    # into this list (0 -> GPU 4, 1 -> GPU 5). Confirm this is intended.
    os.environ["CUDA_VISIBLE_DEVICES"] = "4,5"

    net = nn.DataParallel(net, device_ids=params['gpu'])

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(net.parameters(), lr=params['learning_rate'], momentum=params['momentum'], weight_decay=params['weight_decay'])

    model_save_dir = os.path.join(params['save_path'] + run_suffix, cur_time)
    os.makedirs(model_save_dir, exist_ok=True)

    last_checkpoint_epoch = 130
    for epoch in range(params['epoch_num']):
        best_train_accuracy, best_train_f1 = train(net, train_dataloader, epoch, criterion, optimizer, writer, best_train_accuracy, best_train_f1)
        if epoch % 2 == 0:
            best_test_accuracy, best_test_f1 = validation(net, val_dataloader, epoch, criterion, optimizer, writer, best_test_accuracy, best_test_f1)
        # Save every epoch in (epoch_savecheckpoint, 130]. The old
        # ``epoch % 1 == 0`` guard was always true.
        if params['epoch_savecheckpoint'] < epoch <= last_checkpoint_epoch:
            checkpoint = os.path.join(model_save_dir,
                                      "clip_len_" + str(params['clip_len']) + "frame_sample_rate_" + str(params['frame_sample_rate']) + "_checkpoint_" + str(epoch) + ".pth.tar")
            # Save the unwrapped network so keys carry no DataParallel ``module.`` prefix.
            torch.save(net.module.state_dict(), checkpoint)

    # Actually call close(); the original referenced ``writer.close`` without
    # parentheses, so the writer was never closed.
    writer.close()

if __name__ == '__main__':
    # Start training only when run as a script, not when this module is imported.
    main()
