# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import argparse
import os
import pprint
import shutil
import copy
import numpy as np
import logging

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter

import _init_paths
from config import cfg
from config import update_config
from core.loss import FinetuneUnsupLoss
from utils.utils import get_optimizer
from utils.utils import save_checkpoint
from utils.utils import create_logger
from utils.utils import get_model_summary

from core.evaluate import accuracy
from core.inference import get_final_preds
from utils.transforms import flip_back
from utils.vis import save_debug_images

import dataset
import models

# Module-level logger used by `test_time_training`; `main` rebinds a local
# `logger` from `create_logger` and logs through that one instead.
logger = logging.getLogger(__name__)

def _store_final_preds(config, heatmaps, center, scale, all_preds, idx):
    """Decode the first sample of `heatmaps` and write it into row `idx` of `all_preds`.

    Columns 0:2 receive the coordinates from `get_final_preds`, column 2 its maxvals.
    """
    preds, maxvals = get_final_preds(
        config, heatmaps[0:1].clone().detach().cpu().numpy(), center, scale)
    all_preds[idx:idx + 1, :, 0:2] = preds[:, :, 0:2]
    all_preds[idx:idx + 1, :, 2:3] = maxvals


def test_time_training(config, loader, dataset, model, model_state_dict,
          criterion, optimizer, optimizer_state_dict, output_dir,
          tb_log_dir, writer_dict):
    """Run per-video test-time training (TTT) and evaluate predictions before/after.

    Only every 8th batch of `loader` (those with ``i % 8 == 7``) is used. For each
    used batch the model predicts, takes one optimizer step on the unsupervised
    loss computed by `criterion`, and predicts again. The first sample of the
    batch is decoded and stored both before (`all_preds0`) and after
    (`all_preds1`) the update; if the update increased the unsupervised loss,
    the pre-update heatmaps are stored in `all_preds1` instead. Model and
    optimizer state are restored from the given snapshots whenever a batch has
    ``meta['is_last_frame'][0]`` set.

    Args:
        config: experiment config (reads ``MODEL.NUM_JOINTS``).
        loader: iterable yielding ``(input, target, target_weight, meta)``;
            channel 0 / 1 of ``input`` along dim 1 are the reference / current
            images (exact tensor layout not visible here — TODO confirm).
        dataset: object exposing ``length`` and ``evaluate(...)``.
        model: called as ``model(images, ref_images)`` -> ``(heatmaps, pred_images)``.
        model_state_dict: model snapshot restored at the end of each video.
        criterion: called as ``criterion(None, None, None, images, pred_images, None)``.
        optimizer: stepped once per used batch.
        optimizer_state_dict: optimizer snapshot restored at the end of each video.
        output_dir: passed to ``dataset.evaluate`` and used as debug-image prefix.
        tb_log_dir: unused; kept for interface compatibility.
        writer_dict: ``{'writer': ..., 'train_global_steps': int}``; the step
            counter is incremented once per used batch.

    Returns:
        The performance indicator from ``dataset.evaluate`` on `all_preds1`.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses0 = AverageMeter()
    losses1 = AverageMeter()
    acc0 = AverageMeter()
    acc1 = AverageMeter()

    # switch to train mode; per the model's design, `model.pose_net` (except
    # for the unsup head) is expected to be put back into eval mode by its own
    # `train()` override — NOTE(review): not verifiable from this file.
    model.train()

    # for ttt dataset, dataset.length == len(dataset) // ttt_batchsize;
    # we also only use every 8th batch, otherwise it's too slow
    num_samples = dataset.length // 8
    all_preds0 = np.zeros(
        (num_samples, config.MODEL.NUM_JOINTS, 3),
        dtype=np.float32
    )
    all_preds1 = np.zeros_like(all_preds0)
    idx = 0

    end = time.time()
    for i, (inputs, target, _, meta) in enumerate(loader):
        if i % 8 == 7:
            # data loading time since the last used batch (includes the
            # 7 skipped batches in between)
            data_time.update(time.time() - end)

            ref_images = inputs[:, 0, :, :].cuda(non_blocking=True)
            images = inputs[:, 1, :, :].cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            target_np = target[0:1].detach().cpu().numpy()

            # prediction before the TTT update
            output, pred_images = model(images, ref_images)
            _, avg_acc, cnt, pred0 = accuracy(
                output[0:1].detach().cpu().numpy(), target_np)
            acc0.update(avg_acc, cnt)

            # one optimizer step on the unsupervised loss
            loss0 = criterion(None, None, None, images, pred_images, None)
            optimizer.zero_grad()
            loss0.backward()
            optimizer.step()

            # prediction after the TTT update
            with torch.no_grad():
                new_output, new_pred_images = model(images, ref_images)
                loss1 = criterion(None, None, None, images, new_pred_images, None)
                _, avg_acc, cnt, pred1 = accuracy(
                    new_output[0:1].cpu().numpy(), target_np)
                acc1.update(avg_acc, cnt)

            # record losses
            batch_size = inputs.size(0)
            loss0_val = loss0.item()
            loss1_val = loss1.item()
            losses0.update(loss0_val, batch_size)
            losses1.update(loss1_val, batch_size)
            # if the update made the unsupervised loss worse, keep the
            # pre-update heatmaps as the "after" prediction
            if loss1_val > loss0_val:
                new_output = output

            c = meta['center'][0:1].numpy()
            s = meta['scale'][0:1].numpy()
            _store_final_preds(config, output, c, s, all_preds0, idx)
            _store_final_preds(config, new_output, c, s, all_preds1, idx)
            idx += 1

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # only log batches where the TTT update changed the accuracy
            if acc0.val != acc1.val:
                msg = 'Iter: [{0}/{1}]\t' \
                    'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                    'Speed {speed:.1f} samples/s\t' \
                    'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
                    'Loss0 {loss0.val:.5f} ({loss0.avg:.5f})\t' \
                    'Loss1 {loss1.val:.5f} ({loss1.avg:.5f})\t' \
                    'Accuracy0 {acc0.val:.3f} ({acc0.avg:.3f})\t' \
                    'Accuracy1 {acc1.val:.3f} ({acc1.avg:.3f})'.format(
                        idx, num_samples, batch_time=batch_time,
                        speed=batch_size / batch_time.val,
                        data_time=data_time, loss0=losses0, loss1=losses1,
                        acc0=acc0, acc1=acc1)
                logger.info(msg)

            writer = writer_dict['writer']
            global_steps = writer_dict['train_global_steps']
            writer.add_scalar('train_loss0', losses0.val, global_steps)
            writer.add_scalar('train_acc0', acc0.val, global_steps)
            writer.add_scalar('train_loss1', losses1.val, global_steps)
            writer.add_scalar('train_acc1', acc1.val, global_steps)
            writer_dict['train_global_steps'] = global_steps + 1

            # save debug images only when the TTT update hurt accuracy
            if acc0.val > acc1.val:
                prefix = '{}_{}'.format(os.path.join(output_dir, 'train'), i)
                meta_first_sample = {key: meta[key][0:1] for key in meta}
                # `* 4` presumably rescales heatmap coords to input resolution — confirm
                save_debug_images(config, images[0:1], meta_first_sample, target[0:1],
                                  pred0 * 4, output[0:1], prefix, ref_images[0:1],
                                  pred_images[0:1], new_pred=pred1 * 4)

        # reset weights at the end of each video
        if meta["is_last_frame"][0]:
            logger.info("reaching end of a video, reiniting weight")
            model.load_state_dict(model_state_dict)
            optimizer.load_state_dict(optimizer_state_dict)

    _, perf_before = dataset.evaluate(
        config, all_preds0, output_dir, downsample=8
    )
    _, perf_indicator = dataset.evaluate(
        config, all_preds1, output_dir, downsample=8
    )
    logger.info('=> perf indicator before TTT: {}, after TTT: {}'.format(
        perf_before, perf_indicator))

    return perf_indicator


def parse_args():
    """Parse the command line of the test-time-training script.

    Returns:
        argparse.Namespace with:
          - ``cfg``: required path of the experiment config file;
          - ``opts``: all remaining tokens after the options (config
            overrides, per the help text);
          - ``modelDir``, ``logDir``, ``dataDir``, ``prevModelDir``:
            optional directory strings, each defaulting to ''.
    """
    parser = argparse.ArgumentParser(description='Train keypoints network')

    # general
    parser.add_argument('--cfg', type=str, required=True,
                        help='experiment configure file name')
    parser.add_argument('opts', nargs=argparse.REMAINDER, default=None,
                        help="Modify config options using the command-line")

    # philly: directory overrides, all empty strings unless given
    philly_dirs = (
        ('--modelDir', 'model directory'),
        ('--logDir', 'log directory'),
        ('--dataDir', 'data directory'),
        ('--prevModelDir', 'prev Model directory'),
    )
    for flag, description in philly_dirs:
        parser.add_argument(flag, help=description, type=str, default='')

    return parser.parse_args()


def main():
    """Build the TTT model, dataset and optimizer from the CLI config and run TTT.

    Loads `cfg.TEST.MODEL_FILE` (non-strict) into the full model and
    `cfg.TEST.POSE_NET_FILE` (strict) into `model.pose_net`, freezes everything
    except `model.pose_net`, and runs `test_time_training` on
    `cfg.DATASET.TEST_SET`. Only a single GPU is supported.

    Raises:
        NotImplementedError: if more than one GPU is configured.
    """
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    # reproducibility
    torch.manual_seed(0)
    np.random.seed(0)

    # look the model factory up by name with getattr rather than eval()
    model = getattr(models, cfg.MODEL.NAME).get_pose_net(
        cfg, is_train=True, freeze_bn=True, is_ttt=True,
    )

    logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
    model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False)
    model.pose_net.load_state_dict(torch.load(cfg.TEST.POSE_NET_FILE), strict=True)

    # only the pose network is adapted at test time
    for p in model.parameters():
        p.requires_grad = False
    for p in model.pose_net.parameters():
        p.requires_grad = True

    if len(cfg.GPUS) > 1:
        raise NotImplementedError(
            'test-time training currently supports a single GPU only')
    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    # define loss function (criterion)
    criterion = FinetuneUnsupLoss(cfg).cuda()

    # Data loading code
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    # named `ttt_dataset` so the local does not shadow the `dataset` package
    # used to look up the dataset class
    ttt_dataset = getattr(dataset, cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]),
        is_ttt=True,
    )

    loader = torch.utils.data.DataLoader(
        ttt_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY
    )

    optimizer = get_optimizer(cfg, model)

    # use these to re-init after every video
    model_state_dict = copy.deepcopy(model.state_dict())
    optimizer_state_dict = copy.deepcopy(optimizer.state_dict())

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
    }
    try:
        perf_indicator = test_time_training(
            cfg, loader, ttt_dataset, model, model_state_dict, criterion,
            optimizer, optimizer_state_dict, final_output_dir, tb_log_dir,
            writer_dict)
        logger.info('=> final perf indicator: {}'.format(perf_indicator))
    finally:
        # always flush/close the event file, even if TTT fails midway
        writer_dict['writer'].close()



class AverageMeter(object):
    """Keep the latest value and a weighted running mean of a metric.

    Attributes:
        val: value passed to the most recent ``update`` call.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated weights ``n``.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` with weight `n` and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count == 0:
            self.avg = 0
        else:
            self.avg = self.sum / self.count



if __name__ == '__main__':
    # Run test-time training only when executed as a script, not on import.
    main()
