import os
import numpy as np
import time
from tqdm.auto import tqdm
import argparse
import torch.nn as nn
import torch
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils import *
from network import ResNet, FusionModel, Encoder, LinearClassifier
import dataloader

#os.environ['CUDA_VISIBLE_DEVICES'] = '3'


_NUM_CLASSES = 101  # number of UCF101 action classes (classifier output size)


def _str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``'False'``) as ``True``; this converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def parse_args(argv=None):
    """Build the command-line parser and parse ``argv``.

    Args:
        argv: list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='UCF101 two stream distillation training')
    # --name and both distillation weights have no sensible default; without
    # them the script would only fail later on a ``None`` value, so they are
    # required up front.
    parser.add_argument('--name', type=str, required=True, help='name of the experiment')
    parser.add_argument('--ckpt', type=str, default='two_stream/ckpt/distill', help='dir to save ckpt')
    parser.add_argument('--num_epochs', default=500, type=int, metavar='N', help='number of total epochs')
    parser.add_argument('--batch_size', default=64, type=int, metavar='N', help='mini-batch size (default: 64)')
    parser.add_argument('--lr', default=1e-2, type=float, metavar='LR', help='initial learning rate')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--split', default='01', type=str, help='01 02 or 03 split of UCF101')
    parser.add_argument('--rgb_in_channel', default=3, type=int, help='num of input rgb channels')
    parser.add_argument('--flow_in_channel', default=20, type=int, help='num of input flow channels')
    parser.add_argument('--optimizer', default='sgd', type=str, help='sgd or adam')
    parser.add_argument('--pretrain', default=False, type=_str2bool,
                        help='whether using ImageNet pretrained weight or not')
    parser.add_argument('--flow_distill_weight', type=float, required=True,
                        help='weight(lambda) of the flow distillation loss')
    parser.add_argument('--rgb_distill_weight', type=float, required=True,
                        help='weight(lambda) of the rgb distillation loss')
    parser.add_argument('--distill_loss', default='mse', type=str, help='distillation loss: l1 or mse')
    parser.add_argument('--net_idx', default='18', type=str, help='ResNet index: 18, 34, 50, 101, 152')
    parser.add_argument('--flow_teacher_ckpt', type=str,
                        default='two_stream/ckpt/motion/joint_loader/joint_loader_best.pth',
                        help='checkpoint holding the pre-trained flow teacher encoder')
    parser.add_argument('--rgb_teacher_ckpt', type=str,
                        default='two_stream/ckpt/spatial/pretrain/pretrain_best.pth',
                        help='checkpoint holding the pre-trained rgb teacher encoder')
    return parser.parse_args(argv)


def _train_one_epoch(args, epoch, train_loader, flow_teacher, rgb_teacher,
                     flow_student, rgb_student, linear, criterion, d_loss, optimizer):
    """Train the student encoders and the classifier for one epoch and log the averages.

    Each batch yields ``(flow, rgb_dict, label)`` where ``rgb_dict`` is indexed
    by the keys ``'img0'``, ``'img1'``, ...  The flow features are computed
    once and paired with every RGB entry; the classifier scores and the RGB
    distillation losses are averaged over those entries.
    """
    # teachers only provide regression targets: keep them in eval mode
    flow_teacher.eval()
    rgb_teacher.eval()

    flow_student.train()
    rgb_student.train()
    linear.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    ce_losses = AverageMeter()
    rgb_distill_losses = AverageMeter()
    flow_distill_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    for flow, rgb_dict, label in tqdm(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        flow_input = flow.cuda()
        flow_student_feature = flow_student(flow_input)
        with torch.no_grad():
            flow_teacher_feature = flow_teacher(flow_input)

        # Accumulate scores starting from the first classifier output rather
        # than from a (batch_size, 101) zero tensor, so a final batch smaller
        # than args.batch_size does not cause a shape mismatch.
        num_rgb = len(rgb_dict.keys())
        rgb_distill_loss = 0.
        output = None
        for i in range(num_rgb):
            rgb_input = rgb_dict['img' + str(i)].cuda()
            rgb_student_feature = rgb_student(rgb_input)
            joint_student_feature = torch.cat((flow_student_feature, rgb_student_feature), 1)
            with torch.no_grad():
                rgb_teacher_feature = rgb_teacher(rgb_input)
            rgb_distill_loss += args.rgb_distill_weight * d_loss(rgb_student_feature, rgb_teacher_feature)
            scores = linear(joint_student_feature)
            output = scores if output is None else output + scores

        # average over the RGB entries
        rgb_distill_loss /= num_rgb
        output = output / num_rgb
        label = label.long().cuda()

        ce_loss = criterion(output, label)
        flow_distill_loss = args.flow_distill_weight * d_loss(flow_student_feature, flow_teacher_feature)
        loss = ce_loss + flow_distill_loss + rgb_distill_loss

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, label, topk=(1, 5))
        ce_losses.update(ce_loss.item())
        flow_distill_losses.update(flow_distill_loss.item())
        rgb_distill_losses.update(rgb_distill_loss.item())
        top1.update(prec1.item())
        top5.update(prec5.item())

        # compute gradient and back-propagate
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time and restart the clock so the next data_time
        # sample covers only the next batch's loading
        batch_time.update(time.time() - end)
        end = time.time()

    info = {'Epoch': [epoch],
            'Batch Time': [round(batch_time.avg, 3)],
            'Data Time': [round(data_time.avg, 3)],
            'CE Loss': [round(ce_losses.avg, 5)],
            'RGB Distill Loss': [round(rgb_distill_losses.avg, 5)],
            'Flow Distill Loss': [round(flow_distill_losses.avg, 5)],
            'Prec@1': [round(top1.avg, 4)],
            'Prec@5': [round(top5.avg, 4)],
            'lr': optimizer.param_groups[0]['lr']}
    record_info(info, 'record/distill/{}_train.csv'.format(args.name), 'train', distill=True)


def _validate(args, epoch, test_loader, test_video, flow_student, rgb_student, linear, criterion):
    """Evaluate the students at video level and log the result.

    Per-sample scores are summed for every video (the part of the sample key
    before the first ``'-'``); loss and top-k precision are computed on those
    summed scores against the labels in ``test_video`` (stored 1-based).

    Returns:
        ``(val_loss, val_prec1)`` as a float loss and the top-1 meter average.
    """
    val_batch_time = AverageMeter()
    val_losses = AverageMeter()
    val_top1 = AverageMeter()
    val_top5 = AverageMeter()
    dict_video_level_pred = {}

    flow_student.eval()
    rgb_student.eval()
    linear.eval()

    val_end = time.time()
    for keys, val_flow, val_rgb, _ in tqdm(test_loader):
        val_flow = val_flow.cuda()
        val_rgb = val_rgb.cuda()

        # compute output without storing gradient
        with torch.no_grad():
            val_joint_feature = torch.cat((flow_student(val_flow), rgb_student(val_rgb)), 1)
            val_output = linear(val_joint_feature)

        # measure elapsed time
        val_batch_time.update(time.time() - val_end)
        val_end = time.time()

        # accumulate video level prediction
        pred = val_output.cpu().numpy()
        for i in range(pred.shape[0]):
            video_name = keys[i].split('-', 1)[0]  # ApplyMakeup_g01_c01
            if video_name in dict_video_level_pred:
                dict_video_level_pred[video_name] += pred[i, :]
            else:
                # copy: a bare row view would keep the whole batch array alive
                dict_video_level_pred[video_name] = pred[i, :].copy()

    video_level_preds = np.zeros((len(dict_video_level_pred), _NUM_CLASSES))
    video_level_labels = np.zeros(len(dict_video_level_pred))
    for row, video_name in enumerate(sorted(dict_video_level_pred)):
        video_level_preds[row, :] = dict_video_level_pred[video_name]
        video_level_labels[row] = int(test_video[video_name]) - 1

    video_level_labels = torch.from_numpy(video_level_labels).long()
    video_level_preds = torch.from_numpy(video_level_preds).float()

    # measure accuracy and record loss
    val_loss = criterion(video_level_preds.cuda(), video_level_labels.cuda()).item()
    val_prec1, val_prec5 = accuracy(video_level_preds, video_level_labels, topk=(1, 5))
    val_losses.update(val_loss)
    val_top1.update(val_prec1.item())
    val_top5.update(val_prec5.item())

    info = {'Epoch': [epoch],
            'Batch Time': [round(val_batch_time.avg, 3)],
            'Loss': [round(val_losses.avg, 5)],
            'Prec@1': [round(val_top1.avg, 3)],
            'Prec@5': [round(val_top5.avg, 3)]
            }
    record_info(info, 'record/distill/{}_test.csv'.format(args.name), 'test')

    return val_loss, val_top1.avg


def main():
    """Distil pre-trained flow/RGB teacher encoders into student encoders on UCF101."""
    args = parse_args()
    print(args)

    # dir of saving ckpt
    ckpt_dir = os.path.join(args.ckpt, args.name)
    os.makedirs(ckpt_dir, exist_ok=True)
    # NOTE(review): record_info may already create this directory; creating it
    # here is harmless and avoids losing an epoch to a missing folder.
    os.makedirs('record/distill', exist_ok=True)

    # define network
    flow_teacher = Encoder(in_channel=args.flow_in_channel, pretrain=args.pretrain, net_idx=args.net_idx)
    rgb_teacher = Encoder(in_channel=args.rgb_in_channel, pretrain=args.pretrain, net_idx=args.net_idx)
    flow_student = Encoder(in_channel=args.flow_in_channel, pretrain=args.pretrain, net_idx=args.net_idx)
    rgb_student = Encoder(in_channel=args.rgb_in_channel, pretrain=args.pretrain, net_idx=args.net_idx)
    # NOTE(review): 1024 presumably equals 2 x the encoder feature size for
    # ResNet-18/34; confirm it still matches for --net_idx 50/101/152.
    linear = LinearClassifier(1024, _NUM_CLASSES)
    flow_teacher = nn.DataParallel(flow_teacher).cuda()
    rgb_teacher = nn.DataParallel(rgb_teacher).cuda()
    flow_student = nn.DataParallel(flow_student).cuda()
    rgb_student = nn.DataParallel(rgb_student).cuda()
    linear = nn.DataParallel(linear).cuda()

    # load pre-trained encoder weights to teacher encoders
    flow_ckpt = torch.load(args.flow_teacher_ckpt)
    rgb_ckpt = torch.load(args.rgb_teacher_ckpt)
    flow_teacher.load_state_dict(flow_ckpt['encoder_state_dict'])
    rgb_teacher.load_state_dict(rgb_ckpt['encoder_state_dict'])
    print('\n==> Pre-trained weights are successfully loaded to Flow and RGB teacher encoders\n')

    # loss functions and optimizers
    criterion = nn.CrossEntropyLoss()
    if args.distill_loss == 'mse':
        d_loss = nn.MSELoss()
    elif args.distill_loss == 'l1':
        d_loss = nn.L1Loss()
    else:
        raise Exception('Valid distillation loss: mse or l1')

    # only the students and the classifier are optimised
    params = (list(flow_student.parameters()) + list(rgb_student.parameters()) + list(linear.parameters()))
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, args.lr, momentum=0.9)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, args.lr)
    else:
        raise KeyError('optimizer not implemented')

    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=1, verbose=True)

    # prepare data_loader
    data_loader = dataloader.JointDataloader(batch_size=args.batch_size, num_workers=16, num_frames=10,
                                             flow_path='datasets/UCF101/tvl1_flow/',
                                             rgb_path='datasets/UCF101/frame/',
                                             ucf_list='two_stream/UCF_list/',
                                             ucf_split=args.split)
    train_loader, test_loader, test_video = data_loader.run()

    # counters
    start_epoch = 0
    best_prec1 = 0

    # resume training
    if args.resume:
        if os.path.isfile(args.resume):
            print("==> loading checkpoint '{}'".format(args.resume))
            ckpt = torch.load(args.resume)
            # the checkpoint stores the index of the epoch that already
            # finished, so training continues with the following one
            start_epoch = ckpt['epoch'] + 1
            best_prec1 = ckpt['best_prec1']
            flow_student.load_state_dict(ckpt['flow_student_encoder'])
            rgb_student.load_state_dict(ckpt['rgb_student_encoder'])
            linear.load_state_dict(ckpt['linear_state_dict'])
            optimizer.load_state_dict(ckpt['optimizer'])
            # older checkpoints were written without the scheduler state
            if 'scheduler' in ckpt:
                scheduler.load_state_dict(ckpt['scheduler'])
            print("==> loaded checkpoint '{}' (epoch {}) (best_prec1 {})".format(args.resume, ckpt['epoch'],
                                                                                 best_prec1))
        else:
            print("==> no checkpoint found at '{}'".format(args.resume))

    # launch training
    cudnn.benchmark = True  # let cuDNN pick the fastest convolution algorithms
    for epoch in range(start_epoch, args.num_epochs):
        print('==> Epoch:[{0}/{1}][training stage]'.format(epoch, args.num_epochs))
        _train_one_epoch(args, epoch, train_loader, flow_teacher, rgb_teacher,
                         flow_student, rgb_student, linear, criterion, d_loss, optimizer)

        print('==> Epoch:[{0}/{1}][validation stage]'.format(epoch, args.num_epochs))
        val_loss, val_prec1 = _validate(args, epoch, test_loader, test_video,
                                        flow_student, rgb_student, linear, criterion)

        # track best top1 acc
        is_best = val_prec1 > best_prec1
        if is_best:
            best_prec1 = val_prec1

        # update lr scheduler
        scheduler.step(val_loss)

        # save current ckpt and best acc ckpt
        save_checkpoint({'epoch': epoch,
                         'flow_student_encoder': flow_student.state_dict(),
                         'rgb_student_encoder': rgb_student.state_dict(),
                         'linear_state_dict': linear.state_dict(),
                         'best_prec1': best_prec1,
                         'optimizer': optimizer.state_dict(),
                         'scheduler': scheduler.state_dict()},
                        is_best,
                        os.path.join(ckpt_dir, '{}.pth'.format(args.name)),
                        os.path.join(ckpt_dir, '{}_best.pth'.format(args.name)))


if __name__ == '__main__':
    main()