import os
import numpy as np
import pickle
import time
from tqdm.auto import tqdm
import argparse
import torch.nn as nn
import torch
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils import *
from network import ResNet, Encoder, LinearClassifier
import dataloader

#os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1, 2, 3'


# Number of UCF101 action classes (classifier width and video-level score width).
_NUM_CLASSES = 101
# Directory where per-epoch train/test CSV records are written via record_info.
_RECORD_DIR = os.path.join('record', 'spatial')


def _str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``'False'``) into ``True``; this converter accepts the usual spellings
    instead. The empty string maps to ``False``, as ``bool('')`` did.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got {!r}'.format(value))


def _parse_args(argv=None):
    """Build the CLI parser and parse ``argv`` (``sys.argv[1:]`` when None)."""
    parser = argparse.ArgumentParser(description='UCF101 spatial stream')
    # Required: it names the checkpoint sub-directory and the record CSVs.
    parser.add_argument('--name', type=str, required=True, help='name of the experiment')
    parser.add_argument('--ckpt', type=str, default='two_stream/ckpt/spatial', help='dir to save ckpt')
    parser.add_argument('--num_epochs', default=500, type=int, metavar='N', help='number of total epochs')
    parser.add_argument('--batch_size', default=64, type=int, metavar='N', help='mini-batch size (default: 64)')
    parser.add_argument('--lr', default=1e-2, type=float, metavar='LR', help='initial learning rate')
    parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
    parser.add_argument('--split', default='01', type=str, help='01 02 or 03 split of UCF101')
    parser.add_argument('--in_channel', default=3, type=int, help='num of input rgb channels')
    parser.add_argument('--optimizer', default='sgd', type=str, help='sgd or adam')
    parser.add_argument('--pretrain', default=False, type=_str2bool,
                        help='whether using ImageNet pretrained weight or not (true/false)')
    parser.add_argument('--net_idx', default='18', type=str, help='ResNet index: 18, 34, 50, 101, 152')
    # NOTE(review): must equal the Encoder's output feature size for the chosen
    # net_idx; 512 reproduces the previously hard-coded value.
    parser.add_argument('--feat_dim', default=512, type=int,
                        help='encoder output dim fed to the linear classifier (default: 512)')
    return parser.parse_args(argv)


def _train_one_epoch(encoder, linear, train_loader, criterion, optimizer):
    """Train ``encoder`` + ``linear`` for one pass over ``train_loader``.

    Each batch is ``(_, data_dict, label)``. ``data_dict`` is indexed with the
    keys ``'img0'``, ``'img1'``, ... (one per entry). The logits of all frames
    are averaged before the loss is computed.

    Returns:
        dict with the epoch averages ``batch_time``, ``data_time``, ``loss``,
        ``prec1`` and ``prec5``.
    """
    encoder.train()
    linear.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    for _, data_dict, label in tqdm(train_loader):
        data_time.update(time.time() - end)

        # Average the per-frame logits. The sum is seeded from the first frame's
        # output, so its batch dimension matches the actual batch even if the
        # loader yields a short final batch (a fixed args.batch_size would not).
        num_frames = len(data_dict)
        output = None
        for i in range(num_frames):
            frames = data_dict['img' + str(i)].cuda()
            logits = linear(encoder(frames))
            output = logits if output is None else output + logits
        output = output / num_frames
        label = label.long().cuda()

        loss = criterion(output, label)

        prec1, prec5 = accuracy(output, label, topk=(1, 5))
        losses.update(loss.item())
        top1.update(prec1.item())
        top5.update(prec5.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

    return {'batch_time': batch_time.avg, 'data_time': data_time.avg,
            'loss': losses.avg, 'prec1': top1.avg, 'prec5': top5.avg}


def _aggregate_video_preds(clip_names, clip_preds, video_preds):
    """Add clip-level score rows into per-video sums, in place.

    Args:
        clip_names: iterable of clip keys; the video name is the part before
            the first ``'-'``.
        clip_preds: 2-D numpy array, one score row per clip name.
        video_preds: dict ``video name -> summed score row``; updated in place.
            New entries are copies, so ``clip_preds`` is never mutated.
    """
    for clip_name, scores in zip(clip_names, clip_preds):
        video = clip_name.split('-', 1)[0]
        if video in video_preds:
            video_preds[video] += scores
        else:
            video_preds[video] = scores.copy()


def _video_level_tensors(video_preds, test_video, num_classes=_NUM_CLASSES):
    """Stack per-video scores and labels, in sorted video-name order.

    ``test_video[name]`` is converted with ``int(...) - 1``, i.e. the stored
    labels are presumably 1-based.

    Returns:
        ``(preds, labels)``: a float tensor of shape ``(num_videos,
        num_classes)`` and a long tensor of shape ``(num_videos,)``.
    """
    names = sorted(video_preds)
    preds = np.zeros((len(names), num_classes))
    labels = np.zeros(len(names))
    for row, name in enumerate(names):
        preds[row, :] = video_preds[name]
        labels[row] = int(test_video[name]) - 1
    return torch.from_numpy(preds).float(), torch.from_numpy(labels).long()


def _validate(encoder, linear, test_loader, test_video, criterion):
    """Evaluate video-level accuracy by summing clip logits per video.

    Returns:
        dict with ``batch_time`` (mean per test batch), ``loss``, ``prec1``
        and ``prec5`` (all Python floats).
    """
    encoder.eval()
    linear.eval()

    batch_time = AverageMeter()
    video_preds = {}

    end = time.time()
    for keys, _, val_data, _ in tqdm(test_loader):
        with torch.no_grad():
            val_output = linear(encoder(val_data.cuda()))

        batch_time.update(time.time() - end)
        end = time.time()

        _aggregate_video_preds(keys, val_output.cpu().numpy(), video_preds)

    preds, labels = _video_level_tensors(video_preds, test_video)

    # NOTE(review): the loss is taken on *summed* clip logits, whereas training
    # averages frame logits, so its scale depends on clips per video. It drives
    # ReduceLROnPlateau; averaging may be preferable — confirm intent.
    val_loss = criterion(preds.cuda(), labels.cuda())
    prec1, prec5 = accuracy(preds, labels, topk=(1, 5))
    return {'batch_time': batch_time.avg, 'loss': val_loss.item(),
            'prec1': prec1.item(), 'prec5': prec5.item()}


if __name__ == '__main__':

    args = _parse_args()
    print(args)

    # dirs for checkpoints and CSV records
    ckpt_dir = os.path.join(args.ckpt, args.name)
    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(_RECORD_DIR, exist_ok=True)

    # define network
    encoder = Encoder(in_channel=args.in_channel, pretrain=args.pretrain, net_idx=args.net_idx)
    linear = LinearClassifier(args.feat_dim, _NUM_CLASSES)
    encoder = nn.DataParallel(encoder).cuda()
    linear = nn.DataParallel(linear).cuda()

    # define loss function and optimizer
    criterion = nn.CrossEntropyLoss()

    params = list(encoder.parameters()) + list(linear.parameters())
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, args.lr, momentum=0.9)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, args.lr)
    else:
        raise KeyError('optimizer not implemented')

    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=1, verbose=True)

    # prepare data_loader
    data_loader = dataloader.JointDataloader(batch_size=args.batch_size, num_workers=16, num_frames=10,
                                             flow_path='datasets/UCF101/tvl1_flow/',
                                             rgb_path='datasets/UCF101/frame/',
                                             ucf_list='two_stream/UCF_list/',
                                             ucf_split=args.split)
    train_loader, test_loader, test_video = data_loader.run()

    # counters
    start_epoch = 0
    best_prec1 = 0

    # resume training
    if args.resume:
        if os.path.isfile(args.resume):
            print("==> loading checkpoint '{}'".format(args.resume))
            ckpt = torch.load(args.resume)
            # The checkpoint stores the index of the epoch it finished, so
            # training continues with the next one instead of repeating it.
            start_epoch = ckpt['epoch'] + 1
            best_prec1 = ckpt['best_prec1']
            encoder.load_state_dict(ckpt['encoder_state_dict'])
            linear.load_state_dict(ckpt['linear_state_dict'])
            optimizer.load_state_dict(ckpt['optimizer'])
            # Older checkpoints have no scheduler state; keep a fresh one then.
            if 'scheduler' in ckpt:
                scheduler.load_state_dict(ckpt['scheduler'])
            print("==> loaded checkpoint '{}' (epoch {}) (best_prec1 {})".format(args.resume, ckpt['epoch'],
                                                                                 best_prec1))
        else:
            print("==> no checkpoint found at '{}'".format(args.resume))

    # launch training
    cudnn.benchmark = True  # let cuDNN pick the fastest convolution algorithms
    for epoch in range(start_epoch, args.num_epochs):
        print('==> Epoch:[{0}/{1}][training stage]'.format(epoch, args.num_epochs))
        train_stats = _train_one_epoch(encoder, linear, train_loader, criterion, optimizer)

        # record training info (lr is read before this epoch's scheduler step)
        info = {'Epoch': [epoch],
                'Batch Time': [round(train_stats['batch_time'], 3)],
                'Data Time': [round(train_stats['data_time'], 3)],
                'Loss': [round(train_stats['loss'], 5)],
                'Prec@1': [round(train_stats['prec1'], 4)],
                'Prec@5': [round(train_stats['prec5'], 4)],
                'lr': optimizer.param_groups[0]['lr']}
        record_info(info, os.path.join(_RECORD_DIR, '{}_train.csv'.format(args.name)), 'train')

        # validation stage
        print('==> Epoch:[{0}/{1}][validation stage]'.format(epoch, args.num_epochs))
        val_stats = _validate(encoder, linear, test_loader, test_video, criterion)

        info = {'Epoch': [epoch],
                'Batch Time': [round(val_stats['batch_time'], 3)],
                'Loss': [round(val_stats['loss'], 5)],
                'Prec@1': [round(val_stats['prec1'], 3)],
                'Prec@5': [round(val_stats['prec5'], 3)]
                }
        record_info(info, os.path.join(_RECORD_DIR, '{}_test.csv'.format(args.name)), 'test')

        # track best top-1 accuracy
        is_best = val_stats['prec1'] > best_prec1
        if is_best:
            best_prec1 = val_stats['prec1']

        # update lr scheduler
        scheduler.step(val_stats['loss'])

        # save current ckpt and best acc ckpt
        save_checkpoint({'epoch': epoch,
                         'encoder_state_dict': encoder.state_dict(),
                         'linear_state_dict': linear.state_dict(),
                         'best_prec1': best_prec1,
                         'optimizer': optimizer.state_dict(),
                         'scheduler': scheduler.state_dict()},
                        is_best,
                        os.path.join(ckpt_dir, '{}.pth'.format(args.name)),
                        os.path.join(ckpt_dir, '{}_best.pth'.format(args.name)))
