import os
import numpy as np
import time
from tqdm.auto import tqdm
import argparse
import torch.nn as nn
import torch
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils import *
from network import ResNet, FusionModel, Encoder, LinearClassifier
import dataloader

#os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1, 2, 3'


def str2bool(value):
    """Parse a command-line string into a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``--pretrain False`` would silently enable the flag.
    This parser accepts the usual spellings explicitly and rejects anything else.

    Args:
        value: The raw argument string (a ``bool`` is passed through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


if __name__ == '__main__':
    # Jointly train a flow encoder and an RGB encoder with three cross-entropy
    # heads: flow-only, RGB-only, and a fusion head on the concatenated features.

    parser = argparse.ArgumentParser(description='UCF101 two stream fusion training')
    parser.add_argument('--name', type=str, required=True, help='name of the experiment')
    parser.add_argument('--ckpt', type=str, default='two_stream/ckpt/3CE_loss', help='dir to save ckpt')
    parser.add_argument('--num_epochs', default=500, type=int, metavar='N', help='number of total epochs')
    parser.add_argument('--batch_size', default=64, type=int, metavar='N', help='mini-batch size (default: 64)')
    parser.add_argument('--lr', default=1e-2, type=float, metavar='LR', help='initial learning rate')
    parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
    parser.add_argument('--split', default='01', type=str, help='01 02 or 03 split of UCF101')
    parser.add_argument('--rgb_in_channel', default=3, type=int, help='num of input rgb channels')
    parser.add_argument('--flow_in_channel', default=20, type=int, help='num of input flow channels')
    parser.add_argument('--optimizer', default='sgd', type=str, help='sgd or adam')
    # type=bool would turn the string 'False' into True; parse the value explicitly.
    parser.add_argument('--pretrain', default=False, type=str2bool, help='whether using ImageNet pretrained weight or not')
    parser.add_argument('--extra_loss_weight', default=1, type=float, help='weight on the uni-modal CE losses')
    parser.add_argument('--net_idx', default='18', type=str, help='ResNet index: 18, 34, 50, 101, 152')
    args = parser.parse_args()
    print(args)

    num_classes = 101  # UCF101

    # dirs for saving ckpts and csv training records
    ckpt_dir = os.path.join(args.ckpt, args.name)
    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs('record/3CE_loss', exist_ok=True)

    # define network: one encoder + linear head per modality, plus a fusion head
    # on the concatenated (flow, rgb) features.
    # NOTE(review): the 512 / 1024 widths assume Encoder emits 512-d features for
    # every --net_idx — TODO confirm for the ResNet-50/101/152 variants.
    flow_encoder = Encoder(in_channel=args.flow_in_channel, pretrain=args.pretrain, net_idx=args.net_idx)
    rgb_encoder = Encoder(in_channel=args.rgb_in_channel, pretrain=args.pretrain, net_idx=args.net_idx)
    flow_linear = LinearClassifier(512, num_classes)
    rgb_linear = LinearClassifier(512, num_classes)
    linear = LinearClassifier(1024, num_classes)
    flow_encoder = nn.DataParallel(flow_encoder).cuda()
    rgb_encoder = nn.DataParallel(rgb_encoder).cuda()
    flow_linear = nn.DataParallel(flow_linear).cuda()
    rgb_linear = nn.DataParallel(rgb_linear).cuda()
    linear = nn.DataParallel(linear).cuda()
    all_modules = (flow_encoder, rgb_encoder, flow_linear, rgb_linear, linear)

    # define loss function and optimizer (one optimizer over every module)
    criterion = nn.CrossEntropyLoss()

    params = [p for module in all_modules for p in module.parameters()]
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, args.lr, momentum=0.9)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, args.lr)
    else:
        raise KeyError('optimizer not implemented')

    # lr is reduced when the video-level validation loss stops decreasing
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=1, verbose=True)

    # prepare data_loader
    data_loader = dataloader.JointDataloader(batch_size=args.batch_size, num_workers=16, num_frames=10,
                                             flow_path='datasets/UCF101/tvl1_flow/',
                                             rgb_path='datasets/UCF101/frame/',
                                             ucf_list='two_stream/UCF_list/',
                                             ucf_split=args.split)
    train_loader, test_loader, test_video = data_loader.run()

    # counters
    start_epoch = 0
    best_prec1 = 0

    # resume training
    if args.resume:
        if os.path.isfile(args.resume):
            print("==> loading checkpoint '{}'".format(args.resume))
            ckpt = torch.load(args.resume)
            # 'epoch' holds the last *completed* epoch, so continue with the next one
            start_epoch = ckpt['epoch'] + 1
            best_prec1 = ckpt['best_prec1']
            flow_encoder.load_state_dict(ckpt['flow_encoder_state_dict'])
            rgb_encoder.load_state_dict(ckpt['rgb_encoder_state_dict'])
            flow_linear.load_state_dict(ckpt['flow_linear_state_dict'])
            rgb_linear.load_state_dict(ckpt['rgb_linear_state_dict'])
            linear.load_state_dict(ckpt['linear_state_dict'])
            optimizer.load_state_dict(ckpt['optimizer'])
            # older checkpoints were written without scheduler state
            if 'scheduler' in ckpt:
                scheduler.load_state_dict(ckpt['scheduler'])
            print("==> loaded checkpoint '{}' (epoch {}) (best_prec1 {})".format(args.resume, ckpt['epoch'],
                                                                                 best_prec1))
        else:
            print("==> no checkpoint found at '{}'".format(args.resume))

    # launch training
    cudnn.benchmark = True  # choose the fastest convolution benchmark
    for epoch in range(start_epoch, args.num_epochs):
        print('==> Epoch:[{0}/{1}][training stage]'.format(epoch, args.num_epochs))

        # switch to training mode
        for module in all_modules:
            module.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        flow_losses = AverageMeter()
        rgb_losses = AverageMeter()
        joint_losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        end = time.time()
        progress = tqdm(train_loader)

        for flow, rgb_dict, label in progress:
            # measure data loading time
            data_time.update(time.time() - end)

            # forward flow
            flow_input = flow.cuda()
            flow_feature = flow_encoder(flow_input)
            flow_output = flow_linear(flow_feature)

            # Average the RGB-head and fusion-head scores over every RGB frame in
            # rgb_dict (keys 'img0', 'img1', ...). Accumulators are sized from the
            # actual batch, which can be smaller than --batch_size on the last batch.
            num_rgb = len(rgb_dict)
            rgb_output = torch.zeros_like(flow_output)
            output = torch.zeros_like(flow_output)
            for i in range(num_rgb):
                rgb_input = rgb_dict['img' + str(i)].cuda()
                rgb_feature = rgb_encoder(rgb_input)
                rgb_output += rgb_linear(rgb_feature)
                output += linear(torch.cat((flow_feature, rgb_feature), 1))

            # compute average scores
            rgb_output /= num_rgb
            output /= num_rgb  # (batch, num_classes)
            label = label.long().cuda()

            flow_loss = criterion(flow_output, label)
            rgb_loss = criterion(rgb_output, label)
            joint_loss = criterion(output, label)
            # the uni-modal losses are scaled by --extra_loss_weight (default 1)
            loss = args.extra_loss_weight * (flow_loss + rgb_loss) + joint_loss

            # measure accuracy (of the fusion head) and record unweighted losses
            prec1, prec5 = accuracy(output, label, topk=(1, 5))
            flow_losses.update(flow_loss.item())
            rgb_losses.update(rgb_loss.item())
            joint_losses.update(joint_loss.item())
            top1.update(prec1.item())
            top5.update(prec5.item())

            # compute gradient and back-propagate
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time for the whole iteration
            batch_time.update(time.time() - end)
            end = time.time()

        # record training info
        info = {'Epoch': [epoch],
                'Batch Time': [round(batch_time.avg, 3)],
                'Data Time': [round(data_time.avg, 3)],
                'Joint Loss': [round(joint_losses.avg, 5)],
                'Flow Loss': [round(flow_losses.avg, 5)],
                'RGB Loss': [round(rgb_losses.avg, 5)],
                'Prec@1': [round(top1.avg, 4)],
                'Prec@5': [round(top5.avg, 4)],
                'lr': optimizer.param_groups[0]['lr']}
        record_info(info, 'record/3CE_loss/{}_train.csv'.format(args.name), 'train', triple_ce=True)

        # validation stage
        print('==> Epoch:[{0}/{1}][validation stage]'.format(epoch, args.num_epochs))

        val_batch_time = AverageMeter()
        val_losses = AverageMeter()
        val_top1 = AverageMeter()
        val_top5 = AverageMeter()
        dict_video_level_pred = {}

        # switch to validation mode
        for module in all_modules:
            module.eval()

        val_end = time.time()
        val_progress = tqdm(test_loader)

        # the per-clip labels are not needed: video labels come from test_video
        for keys, val_flow, val_rgb, _ in val_progress:
            val_flow = val_flow.cuda()
            val_rgb = val_rgb.cuda()

            # compute fusion-head output without storing gradient
            with torch.no_grad():
                val_output = linear(torch.cat((flow_encoder(val_flow),
                                               rgb_encoder(val_rgb)), 1))

            # measure elapsed time
            val_batch_time.update(time.time() - val_end)
            val_end = time.time()

            # accumulate clip scores per video; the video name is the key prefix
            # before the first '-' (e.g. ApplyMakeup_g01_c01)
            pred = val_output.cpu().numpy()
            for i in range(pred.shape[0]):
                video_name = keys[i].split('-', 1)[0]
                if video_name in dict_video_level_pred:
                    dict_video_level_pred[video_name] += pred[i, :]
                else:
                    # copy so later in-place sums never alias a batch array
                    dict_video_level_pred[video_name] = pred[i, :].copy()

        video_names = sorted(dict_video_level_pred)
        video_level_preds = np.zeros((len(video_names), num_classes))
        video_level_labels = np.zeros(len(video_names))
        for ii, name in enumerate(video_names):
            video_level_preds[ii, :] = dict_video_level_pred[name]
            # test_video labels appear to be 1-based, hence the -1
            video_level_labels[ii] = int(test_video[name]) - 1

        # top1 top5
        video_level_labels = torch.from_numpy(video_level_labels).long()
        video_level_preds = torch.from_numpy(video_level_preds).float()

        # measure accuracy and record loss
        # NOTE(review): the loss is computed on *summed* clip scores, so its scale
        # depends on the number of clips per video; it also drives the scheduler.
        val_loss = criterion(video_level_preds.cuda(), video_level_labels.cuda())
        val_prec1, val_prec5 = accuracy(video_level_preds, video_level_labels, topk=(1, 5))
        val_losses.update(val_loss.item())
        val_top1.update(val_prec1.item())
        val_top5.update(val_prec5.item())

        # record validation info
        info = {'Epoch': [epoch],
                'Batch Time': [round(val_batch_time.avg, 3)],
                'Loss': [round(val_losses.avg, 5)],
                'Prec@1': [round(val_top1.avg, 3)],
                'Prec@5': [round(val_top5.avg, 3)]
                }
        record_info(info, 'record/3CE_loss/{}_test.csv'.format(args.name), 'test')

        # save best top1 acc ckpt
        is_best = val_top1.avg > best_prec1
        if is_best:
            best_prec1 = val_top1.avg

        # update lr scheduler
        scheduler.step(val_loss)

        # save current ckpt and best acc ckpt ('epoch' = the epoch just completed)
        save_checkpoint({'epoch': epoch,
                         'flow_encoder_state_dict': flow_encoder.state_dict(),
                         'rgb_encoder_state_dict': rgb_encoder.state_dict(),
                         'flow_linear_state_dict': flow_linear.state_dict(),
                         'rgb_linear_state_dict': rgb_linear.state_dict(),
                         'linear_state_dict': linear.state_dict(),
                         'best_prec1': best_prec1,
                         'optimizer': optimizer.state_dict(),
                         'scheduler': scheduler.state_dict()},
                        is_best,
                        os.path.join(ckpt_dir, '{}.pth'.format(args.name)),
                        os.path.join(ckpt_dir, '{}_best.pth'.format(args.name)))
