import os
import numpy as np
import time
from tqdm.auto import tqdm
import argparse
import torch.nn as nn
import torch
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils import *
from network import ResNet, FusionModel, Encoder, LinearClassifier
import dataloader

os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1, 2, 3'

# Number of action classes in UCF101.
NUM_CLASSES = 101


def str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``'False'``) into True, so boolean flags could never be disabled from the
    command line. This converter accepts the usual spellings instead.

    Args:
        value: Raw string from the command line (bools pass through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    # '' is accepted as False to stay compatible with the old bool('') behaviour
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def build_optimizer(params, name, lr):
    """Create the optimizer selected on the command line.

    Args:
        params: Iterable of parameters to optimise.
        name: ``'sgd'`` (SGD with momentum 0.9) or ``'adam'``.
        lr: Initial learning rate.

    Returns:
        A ``torch.optim.Optimizer``.

    Raises:
        KeyError: If ``name`` is not a supported optimizer.
    """
    if name == 'sgd':
        return torch.optim.SGD(params, lr, momentum=0.9)
    if name == 'adam':
        return torch.optim.Adam(params, lr)
    raise KeyError('optimizer not implemented')


def accumulate_video_preds(pred_dict, keys, scores):
    """Sum clip-level score rows into per-video totals.

    Args:
        pred_dict: Mapping video name -> summed score vector; updated in place.
        keys: Clip keys for one batch. The video name is the part before the
            first '-' (e.g. ``ApplyMakeup_g01_c01``).
        scores: 2-D array whose row ``i`` holds the scores for ``keys[i]``.
    """
    for i, clip_key in enumerate(keys):
        video_name = clip_key.split('-', 1)[0]
        if video_name in pred_dict:
            pred_dict[video_name] += scores[i]
        else:
            # Copy so the stored vector does not alias (and later mutate) the batch array.
            pred_dict[video_name] = scores[i].copy()


def stack_video_preds(pred_dict, video_names, num_classes=NUM_CLASSES):
    """Stack per-video score vectors into a ``(len(video_names), num_classes)`` array.

    Rows follow the order of ``video_names``.
    """
    stacked = np.zeros((len(video_names), num_classes))
    for row, name in enumerate(video_names):
        stacked[row, :] = pred_dict[name]
    return stacked


# Script entry point: naive two-stream (flow + RGB) fusion training on UCF101,
# optionally with linear probes that track how good each uni-modal encoder is.
if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='UCF101 two stream fusion training')
    parser.add_argument('--name', type=str, required=True, help='name of the experiment')
    parser.add_argument('--ckpt', type=str, default='ckpt/naive_fusion', help='dir to save ckpt')
    parser.add_argument('--num_epochs', default=500, type=int, metavar='N', help='number of total epochs')
    parser.add_argument('--batch_size', default=64, type=int, metavar='N', help='mini-batch size (default: 64)')
    parser.add_argument('--lr', default=1e-2, type=float, metavar='LR', help='initial learning rate')
    parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
    parser.add_argument('--split', default='01', type=str, help='01 02 or 03 split of UCF101')
    parser.add_argument('--rgb_in_channel', default=3, type=int, help='num of input rgb channels')
    parser.add_argument('--flow_in_channel', default=20, type=int, help='num of input flow channels')
    parser.add_argument('--optimizer', default='sgd', type=str, choices=['sgd', 'adam'], help='sgd or adam')
    parser.add_argument('--pretrain', default=True, type=str2bool, help='whether using ImageNet pretrained weight or not')
    parser.add_argument('--eval_encoders', default=True, type=str2bool, help='evaluate uni-modal encoders in naive fusion training')
    parser.add_argument('--net_idx', default='18', type=str, help='ResNet index: 18, 34, 50, 101, 152')
    args = parser.parse_args()
    print(args)

    # dirs and paths for checkpoints and csv records
    ckpt_dir = os.path.join(args.ckpt, args.name)
    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs('record/naive_fusion', exist_ok=True)
    train_record_path = 'record/naive_fusion/{}_train.csv'.format(args.name)
    test_record_path = 'record/naive_fusion/{}_test.csv'.format(args.name)
    last_ckpt_path = os.path.join(ckpt_dir, '{}.pth'.format(args.name))
    best_ckpt_path = os.path.join(ckpt_dir, '{}_best.pth'.format(args.name))

    # define network
    # NOTE(review): the 1024/512 classifier widths assume each Encoder emits 512-d
    # features (presumably true for ResNet-18/34). Deeper ResNets usually emit
    # larger features and would not match. TODO confirm Encoder's output dim.
    flow_encoder = Encoder(in_channel=args.flow_in_channel, pretrain=args.pretrain, net_idx=args.net_idx)
    rgb_encoder = Encoder(in_channel=args.rgb_in_channel, pretrain=args.pretrain, net_idx=args.net_idx)
    linear = LinearClassifier(1024, NUM_CLASSES)

    flow_encoder = nn.DataParallel(flow_encoder).cuda()
    rgb_encoder = nn.DataParallel(rgb_encoder).cuda()
    linear = nn.DataParallel(linear).cuda()

    # define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    params = list(flow_encoder.parameters()) + list(rgb_encoder.parameters()) + list(linear.parameters())
    optimizer = build_optimizer(params, args.optimizer, args.lr)
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=1, verbose=True)

    # Linear probes on detached encoder features: they measure each uni-modal
    # encoder without influencing the fusion training.
    if args.eval_encoders:
        flow_linear = LinearClassifier(512, NUM_CLASSES)
        rgb_linear = LinearClassifier(512, NUM_CLASSES)
        flow_linear = nn.DataParallel(flow_linear).cuda()
        rgb_linear = nn.DataParallel(rgb_linear).cuda()
        eval_encoders_params = list(flow_linear.parameters()) + list(rgb_linear.parameters())
        eval_encoders_optimizer = build_optimizer(eval_encoders_params, args.optimizer, args.lr)
        eval_encoders_scheduler = ReduceLROnPlateau(eval_encoders_optimizer, 'min', patience=1, verbose=True)

    # prepare data_loader
    data_loader = dataloader.JointDataloader(batch_size=args.batch_size, num_workers=16, num_frames=10,
                                             flow_path='datasets/UCF101/tvl1_flow/',
                                             rgb_path='datasets/UCF101/frame/',
                                             ucf_list='UCF_list/',
                                             ucf_split=args.split)
    train_loader, test_loader, test_video = data_loader.run()

    # counters
    start_epoch = 0
    best_prec1 = 0

    # resume training
    if args.resume:
        if os.path.isfile(args.resume):
            print("==> loading checkpoint '{}'".format(args.resume))
            ckpt = torch.load(args.resume)
            # 'epoch' stores the last *completed* epoch, so continue with the next one.
            start_epoch = ckpt['epoch'] + 1
            best_prec1 = ckpt['best_prec1']
            flow_encoder.load_state_dict(ckpt['flow_encoder_state_dict'])
            rgb_encoder.load_state_dict(ckpt['rgb_encoder_state_dict'])
            linear.load_state_dict(ckpt['linear_state_dict'])
            optimizer.load_state_dict(ckpt['optimizer'])
            # Older checkpoints predate the scheduler / probe-optimizer keys.
            if 'scheduler' in ckpt:
                scheduler.load_state_dict(ckpt['scheduler'])
            if args.eval_encoders:
                flow_linear.load_state_dict(ckpt['flow_linear_state_dict'])
                rgb_linear.load_state_dict(ckpt['rgb_linear_state_dict'])
                if 'eval_encoders_optimizer' in ckpt:
                    eval_encoders_optimizer.load_state_dict(ckpt['eval_encoders_optimizer'])
                if 'eval_encoders_scheduler' in ckpt:
                    eval_encoders_scheduler.load_state_dict(ckpt['eval_encoders_scheduler'])
            print("==> loaded checkpoint '{}' (epoch {}) (best_prec1 {})".format(args.resume, ckpt['epoch'],
                                                                                 best_prec1))
        else:
            print("==> no checkpoint found at '{}'".format(args.resume))

    # launch training
    cudnn.benchmark = True  # choose the fastest convolution benchmark
    for epoch in range(start_epoch, args.num_epochs):
        print('==> Epoch:[{0}/{1}][training stage]'.format(epoch, args.num_epochs))

        # switch to training mode
        flow_encoder.train()
        rgb_encoder.train()
        linear.train()
        if args.eval_encoders:
            flow_linear.train()
            rgb_linear.train()
            flow_losses = AverageMeter()
            flow_top1 = AverageMeter()
            rgb_losses = AverageMeter()
            rgb_top1 = AverageMeter()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        end = time.time()
        progress = tqdm(train_loader)

        for flow, rgb_dict, label in progress:
            # measure data loading time (since the end of the previous iteration)
            data_time.update(time.time() - end)

            num_clips = len(rgb_dict)
            flow_feature = flow_encoder(flow.cuda())
            # Size buffers from the data: the last batch may be smaller than --batch_size.
            cur_batch = flow_feature.size(0)

            if args.eval_encoders:
                flow_output = flow_linear(flow_feature.detach())
                rgb_output = torch.zeros(cur_batch, NUM_CLASSES, device=flow_feature.device)

            # Average the fused scores over the RGB inputs in rgb_dict, each
            # concatenated with the same flow feature.
            output = torch.zeros(cur_batch, NUM_CLASSES, device=flow_feature.device)
            for i in range(num_clips):
                rgb_feature = rgb_encoder(rgb_dict['img' + str(i)].cuda())
                output += linear(torch.cat((flow_feature, rgb_feature), 1))
                if args.eval_encoders:
                    rgb_output += rgb_linear(rgb_feature.detach())
            output /= num_clips
            label = label.long().cuda()

            loss = criterion(output, label)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, label, topk=(1, 5))
            losses.update(loss.item())
            top1.update(prec1.item())
            top5.update(prec5.item())

            # compute gradient and back-propagate
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # train the uni-modal probes (their inputs are detached, so the
            # encoder update above does not affect this graph)
            if args.eval_encoders:
                rgb_output /= num_clips
                eval_flow_loss = criterion(flow_output, label)
                eval_rgb_loss = criterion(rgb_output, label)
                eval_encoders_loss = eval_flow_loss + eval_rgb_loss
                flow_prec1, _ = accuracy(flow_output, label, topk=(1, 5))
                rgb_prec1, _ = accuracy(rgb_output, label, topk=(1, 5))
                flow_losses.update(eval_flow_loss.item())
                rgb_losses.update(eval_rgb_loss.item())
                flow_top1.update(flow_prec1.item())
                rgb_top1.update(rgb_prec1.item())
                eval_encoders_optimizer.zero_grad()
                eval_encoders_loss.backward()
                eval_encoders_optimizer.step()

            # measure full iteration time
            batch_time.update(time.time() - end)
            end = time.time()

        # record training info
        if args.eval_encoders:
            info = {'Epoch': [epoch],
                    'Flow Loss': [round(flow_losses.avg, 3)],
                    'Flow Prec@1': [round(flow_top1.avg, 4)],
                    'RGB Loss': [round(rgb_losses.avg, 3)],
                    'RGB Prec@1': [round(rgb_top1.avg, 4)],
                    'Loss': [round(losses.avg, 4)],
                    'Prec@1': [round(top1.avg, 4)],
                    'Prec@5': [round(top5.avg, 4)],
                    'lr': optimizer.param_groups[0]['lr']}
            record_info(info, train_record_path, 'train', eval_naive_encoders=args.eval_encoders)
        else:
            info = {'Epoch': [epoch],
                    'Batch Time': [round(batch_time.avg, 3)],
                    'Data Time': [round(data_time.avg, 3)],
                    'Loss': [round(losses.avg, 5)],
                    'Prec@1': [round(top1.avg, 4)],
                    'Prec@5': [round(top5.avg, 4)],
                    'lr': optimizer.param_groups[0]['lr']}
            record_info(info, train_record_path, 'train')

        # validation stage
        print('==> Epoch:[{0}/{1}][validation stage]'.format(epoch, args.num_epochs))

        val_batch_time = AverageMeter()
        val_losses = AverageMeter()
        val_top1 = AverageMeter()
        val_top5 = AverageMeter()
        dict_video_level_pred = {}

        # switch to validation mode
        flow_encoder.eval()
        rgb_encoder.eval()
        linear.eval()
        if args.eval_encoders:
            flow_linear.eval()
            rgb_linear.eval()
            val_flow_top1 = AverageMeter()
            val_rgb_top1 = AverageMeter()
            dict_video_level_pred_flow = {}
            dict_video_level_pred_rgb = {}

        val_end = time.time()
        val_progress = tqdm(test_loader)

        # Video-level labels come from test_video below, so the per-clip label is unused.
        for keys, val_flow, val_rgb, _ in val_progress:
            val_flow = val_flow.cuda()
            val_rgb = val_rgb.cuda()

            # compute output without storing gradient
            with torch.no_grad():
                val_flow_feature = flow_encoder(val_flow)
                val_rgb_feature = rgb_encoder(val_rgb)
                val_output = linear(torch.cat((val_flow_feature, val_rgb_feature), 1))
                if args.eval_encoders:
                    val_flow_output = flow_linear(val_flow_feature)
                    val_rgb_output = rgb_linear(val_rgb_feature)

            # measure elapsed time
            val_batch_time.update(time.time() - val_end)
            val_end = time.time()

            # sum clip scores into video-level predictions
            accumulate_video_preds(dict_video_level_pred, keys, val_output.cpu().numpy())
            if args.eval_encoders:
                accumulate_video_preds(dict_video_level_pred_flow, keys, val_flow_output.cpu().numpy())
                accumulate_video_preds(dict_video_level_pred_rgb, keys, val_rgb_output.cpu().numpy())

        # video-level tensors, rows ordered by sorted video name
        video_names = sorted(dict_video_level_pred)
        video_level_labels = torch.tensor([int(test_video[name]) - 1 for name in video_names], dtype=torch.long)
        video_level_preds = torch.from_numpy(stack_video_preds(dict_video_level_pred, video_names)).float()

        # measure accuracy and record loss
        val_loss = criterion(video_level_preds.cuda(), video_level_labels.cuda())
        val_prec1, val_prec5 = accuracy(video_level_preds, video_level_labels, topk=(1, 5))
        val_losses.update(val_loss.item())
        val_top1.update(val_prec1.item())
        val_top5.update(val_prec5.item())

        if args.eval_encoders:
            video_level_preds_flow = torch.from_numpy(
                stack_video_preds(dict_video_level_pred_flow, video_names)).float()
            video_level_preds_rgb = torch.from_numpy(
                stack_video_preds(dict_video_level_pred_rgb, video_names)).float()
            val_flow_loss = criterion(video_level_preds_flow.cuda(), video_level_labels.cuda())
            val_rgb_loss = criterion(video_level_preds_rgb.cuda(), video_level_labels.cuda())
            val_flow_prec1, _ = accuracy(video_level_preds_flow, video_level_labels, topk=(1, 5))
            val_rgb_prec1, _ = accuracy(video_level_preds_rgb, video_level_labels, topk=(1, 5))
            val_flow_top1.update(val_flow_prec1.item())
            val_rgb_top1.update(val_rgb_prec1.item())

        # record validation info
        if args.eval_encoders:
            info = {'Epoch': [epoch],
                    'Loss': [round(val_losses.avg, 5)],
                    'Prec@1': [round(val_top1.avg, 3)],
                    'Prec@5': [round(val_top5.avg, 3)],
                    'Flow Prec@1': [round(val_flow_top1.avg, 3)],
                    'RGB Prec@1': [round(val_rgb_top1.avg, 3)]
                    }
        else:
            info = {'Epoch': [epoch],
                    'Batch Time': [round(val_batch_time.avg, 3)],
                    'Loss': [round(val_losses.avg, 5)],
                    'Prec@1': [round(val_top1.avg, 3)],
                    'Prec@5': [round(val_top5.avg, 3)]
                    }
        record_info(info, test_record_path, 'test', eval_naive_encoders=args.eval_encoders)

        # track best top1 acc
        is_best = val_top1.avg > best_prec1
        if is_best:
            best_prec1 = val_top1.avg

        # update lr schedulers; each monitors the loss of the model it optimises
        scheduler.step(val_loss)
        if args.eval_encoders:
            eval_encoders_scheduler.step(val_flow_loss.item() + val_rgb_loss.item())

        # save current ckpt and best acc ckpt
        state = {'epoch': epoch,
                 'flow_encoder_state_dict': flow_encoder.state_dict(),
                 'rgb_encoder_state_dict': rgb_encoder.state_dict(),
                 'linear_state_dict': linear.state_dict(),
                 'best_prec1': best_prec1,
                 'optimizer': optimizer.state_dict(),
                 'scheduler': scheduler.state_dict()}
        if args.eval_encoders:
            state['flow_linear_state_dict'] = flow_linear.state_dict()
            state['rgb_linear_state_dict'] = rgb_linear.state_dict()
            state['eval_encoders_optimizer'] = eval_encoders_optimizer.state_dict()
            state['eval_encoders_scheduler'] = eval_encoders_scheduler.state_dict()
        save_checkpoint(state, is_best, last_ckpt_path, best_ckpt_path)
