import argparse
import os
import pprint
import shutil
import numpy as np
import random

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
from tensorboardX import SummaryWriter

import tools._init_paths
from config import cfg
from config import update_config
from core.loss import JointsMSELoss
from core.function_semi import train, validate
from utils.utils import get_optimizer
from utils.utils import save_checkpoint
from utils.utils import create_logger
from utils.utils import read_labelled_split
from core.consistency_loss import ConsistencyLoss
from utils.utils import num_trainable_params

import dataset
import models


def parse_args():
    """Parse the command-line arguments of the training script.

    Returns:
        argparse.Namespace with:
            cfg: path of the experiment configuration file (required).
            opts: trailing command-line tokens used to modify config options.
            modelDir, logDir, dataDir, prevModelDir: optional directory
                overrides, each an empty string when not supplied.
    """
    parser = argparse.ArgumentParser(description='Train keypoints network')

    # general
    parser.add_argument('--cfg', type=str, required=True,
                        help='experiment configure file name')
    parser.add_argument('opts', nargs=argparse.REMAINDER, default=None,
                        help="Modify config options using the command-line")

    # Optional directory overrides, registered in a fixed order; all are
    # plain strings that default to ''.
    directory_options = (
        ('--modelDir', 'model directory'),
        ('--logDir', 'log directory'),
        ('--dataDir', 'data directory'),
        ('--prevModelDir', 'prev Model directory'),
    )
    for flag, description in directory_options:
        parser.add_argument(flag, type=str, default='', help=description)

    return parser.parse_args()

def _make_model(cfg, final_output_dir, is_train):
    """Initialise models
    Input:
        cfg: config object
        final_output_dir: directory into which the model source file
            ``lib/models/<cfg.MODEL.NAME>.py`` (relative to this script) is
            copied for reference
        is_train: forwarded to ``get_pose_net`` of the configured model module
    Returns:
        models_dict: dictionary with models. Always holds 'spen' (the pose
            model); holds 'kp_class' as well when cfg.MODEL.KP_CLASS is set.
    """
    if cfg.LOSS.RECONSTRUCTION:
        # Use the model from previous work
        pe_model = models.imm_model.AssembleNet(
            in_channels=3, n_filters=32, n_maps=cfg.MODEL.NUM_JOINTS,
            gauss_std=0.1, renderer_stride=2, n_render_filters=32,
            n_final_out=3, max_size=cfg.MODEL.IMAGE_SIZE,
            min_size=cfg.MODEL.HEATMAP_SIZE)
        print('Initialised reconstruction model with {} params'.format(
            num_trainable_params(pe_model)))
    else:
        # Resolve models.<cfg.MODEL.NAME>.get_pose_net by attribute lookup
        # instead of eval(), so a config value can only select an existing
        # attribute and can never execute arbitrary code.
        model_module = models
        for attr_name in cfg.MODEL.NAME.split('.'):
            model_module = getattr(model_module, attr_name)
        pe_model = model_module.get_pose_net(cfg, is_train=is_train)

    models_dict = {'spen': pe_model}

    # copy model file
    # NOTE(review): in the reconstruction branch the network comes from
    # models.imm_model, yet the file copied is still cfg.MODEL.NAME + '.py'.
    # Confirm this is intended.
    this_dir = os.path.dirname(__file__)
    shutil.copy2(os.path.join(this_dir, './lib/models', cfg.MODEL.NAME + '.py'),
                 final_output_dir)

    if cfg.MODEL.KP_CLASS:
        kp_class = models.kp_class.KpClassNet(cfg.MODEL.NUM_FEAT,
                                              cfg.MODEL.KP_EMB,
                                              cfg.MODEL.KP_CLASS_NUM,
                                              cfg.MODEL.NUM_INTER_FEAT,
                                              cfg.MODEL.TUNE_HM,
                                              True)
        models_dict['kp_class'] = kp_class

    return models_dict

def _model_to_gpu(models_dict, cfg):
    if cfg.USE_GPU:
        for name in models_dict:
            models_dict[name] = torch.nn.DataParallel(models_dict[name], device_ids=cfg.GPUS).cuda()

    return models_dict


def _make_loss(cfg, logger):
    # define loss function (criterion) and optimizer
    loss_dict = {}

    if cfg.LOSS.SUPERVISED or 'supervised' in cfg.LOSS.LIST:
        loss_sup = JointsMSELoss(use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT)
        if cfg.USE_GPU: loss_sup = loss_sup.cuda()
        loss_dict['supervised'] = loss_sup

    if cfg.LOSS.CONSISTENCY or 'consistency' in cfg.LOSS.LIST:
        loss_cons = ConsistencyLoss(cfg.LOSS.CONSISTENCY_TYPE)
        if cfg.USE_GPU: loss_cons = loss_cons.cuda()
        loss_dict['consistency'] = loss_cons

    if cfg.LOSS.RECONSTRUCTION or 'reconstruction' in cfg.LOSS.LIST:
        loss_rec =  ConsistencyLoss(loss_type=cfg.LOSS.RECONSTRUCTION_TYPE)
        if cfg.USE_GPU: loss_rec = loss_rec.cuda()
        loss_dict['reconstruction'] = loss_rec

    if cfg.LOSS.KP_CLASS or 'kp_class' in cfg.LOSS.LIST:
        loss_kp_class = nn.CrossEntropyLoss()
        if cfg.USE_GPU: loss_kp_class = loss_kp_class.cuda()
        loss_dict['kp_class'] = loss_kp_class

    if cfg.LOSS.KP_CLASS_CONSISTENCY or 'kp_class_cons' in cfg.LOSS.LIST:
        loss_kp_cons = ConsistencyLoss(loss_type='mse')
        if cfg.USE_GPU: loss_kp_cons = loss_kp_cons.cuda()
        loss_dict['kp_class_cons'] = loss_kp_cons


    logger.info('=> initialised losses: {}'.format(list(loss_dict.keys())))

    return loss_dict

def _make_data(cfg, logger):
    """Initialise train and validation loaders as per config parameters.

    The labelled/unlabelled index split is read from
    ``<cfg.DATASET.ROOT>/annot/labels/<cfg.DATASET.LABELS_SPLIT_FILE>``.
    Training batches are then drawn in one of three modes:

    * ``cfg.DATASET.LABELLED_ONLY``: random batches of labelled indices only.
    * ``cfg.DATASET.UNLABELLED_ONLY``: random batches over labelled and
      unlabelled indices together.
    * otherwise: ``dataset.TwoStreamBatchSampler`` over a
      ``cfg.TRAIN.UNLABELLED_PERCENTAGE`` fraction of the unlabelled indices
      plus the labelled ones.

    All batch sizes are the per-GPU config values times ``len(cfg.GPUS)``.

    Returns:
        (train_loader, valid_loader, valid_dataset)
    """
    def normalize():
        # Fresh Normalize transform with the standard ImageNet mean/std values.
        return dataset.transformers.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        dataset.transformers.Rotate(cfg.DATASET.ROT_FACTOR),
        dataset.transformers.RandomHorizontalFlip(cfg.DATASET.FLIP_PROB,
                                                  cfg.DATASET.SYMM_LDMARKS),
        dataset.transformers.RandomCrop(cfg.MODEL.IMAGE_SIZE[0]),
        dataset.transformers.ToTensor(),
        normalize(),
    ])

    valid_transform = transforms.Compose([
        dataset.transformers.CenterCrop(cfg.MODEL.IMAGE_SIZE[0]),
        dataset.transformers.ToTensor(),
        normalize(),
    ])

    train_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True, train_transform)

    valid_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False, valid_transform)

    # Read training indexes
    labeled_idxs, unlabeled_idxs = read_labelled_split(
        os.path.join(cfg.DATASET.ROOT, 'annot', 'labels',
                     cfg.DATASET.LABELS_SPLIT_FILE))
    logger.info('=> loaded {} labelled examples'.format(len(labeled_idxs)))
    logger.info('=> loaded {} unlabelled examples'.format(len(unlabeled_idxs)))

    train_bs = cfg.TRAIN.BS * len(cfg.GPUS)

    if cfg.DATASET.LABELLED_ONLY:
        logger.info('=> supervised learning with {} labelled examples'.format(len(labeled_idxs)))
        sampler = SubsetRandomSampler(labeled_idxs)
        batch_sampler = BatchSampler(sampler, train_bs, drop_last=True)
    elif cfg.DATASET.UNLABELLED_ONLY:
        logger.info('=> unsupervised learning with {} unlabelled examples'.format(len(unlabeled_idxs)+len(labeled_idxs)))
        # Plain list concatenation. The previous
        # np.concatenate(labeled_idxs, unlabeled_idxs) passed the second
        # list as `axis` and raised TypeError. np.concatenate would also
        # turn an empty list into float indices.
        all_idxs = list(labeled_idxs) + list(unlabeled_idxs)
        sampler = SubsetRandomSampler(all_idxs)
        batch_sampler = BatchSampler(sampler, train_bs, drop_last=True)
    else:
        logger.info('=> semi-supervised learning')
        # Fixed seed so the same unlabelled subset is selected on every run.
        random.Random(4).shuffle(unlabeled_idxs)
        unlabeled_num = int(len(unlabeled_idxs) * cfg.TRAIN.UNLABELLED_PERCENTAGE)
        batch_sampler = dataset.TwoStreamBatchSampler(
            unlabeled_idxs[:unlabeled_num], labeled_idxs,
            train_bs, cfg.TRAIN.BS_LABELLED * len(cfg.GPUS))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_sampler=batch_sampler,
                                               num_workers=cfg.WORKERS,
                                               pin_memory=cfg.PIN_MEMORY)

    # Validation set is the same for all cases
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=cfg.TEST.BS * len(cfg.GPUS),
                                               shuffle=False,
                                               num_workers=cfg.WORKERS,
                                               pin_memory=cfg.PIN_MEMORY)

    return train_loader, valid_loader, valid_dataset

def _make_optimizer(cfg, models_dict):

    models_param = list(models_dict['spen'].parameters())

    if cfg.MODEL.KP_CLASS:
        models_param += list(models_dict['kp_class'].parameters())

    optimizer = torch.optim.Adam(models_param, lr=cfg.TRAIN.LR)

    return optimizer


def _model_state_dicts(models_dict):
    """Return ``{name: state_dict}`` for every model in ``models_dict``.

    Models wrapped in ``nn.DataParallel`` are unwrapped first, so the saved
    keys carry no ``module.`` prefix. They can then be loaded straight back
    into the un-wrapped models on resume.
    """
    return {
        name: (model.module if isinstance(model, nn.DataParallel) else model).state_dict()
        for name, model in models_dict.items()
    }


def main():
    """Run semi-supervised keypoint training end to end.

    Steps:
        1. Parse args, update ``cfg`` and build models, losses and loaders.
        2. When cfg.AUTO_RESUME is set and ``<final_output_dir>/checkpoint.pth``
           exists, restore model weights, optimizer state, start epoch and
           best validation performance from it.
        3. Train for epochs [begin_epoch, cfg.TRAIN.END_EPOCH), validating
           after each one.

    After every epoch the full checkpoint is rewritten. ``<name>_best.pth`` is
    refreshed when validation performance is >= the best so far.
    ``<name>_final.pth`` files are written at the end, even when no epoch ran.
    """
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    try:
        # Initialise models, losses and data loaders
        models_dict = _make_model(cfg, final_output_dir, True)
        loss_dict = _make_loss(cfg, logger)
        train_loader, valid_loader, valid_dataset = _make_data(cfg, logger)

        best_perf = 0.0
        begin_epoch = cfg.TRAIN.BEGIN_EPOCH
        checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

        checkpoint = None
        if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
            logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
            # Load onto the CPU so resuming also works without a GPU. The
            # weights are copied into the (still CPU-resident) models here and
            # moved by _model_to_gpu. Optimizer.load_state_dict moves the
            # optimizer state to its parameters' device.
            checkpoint = torch.load(checkpoint_file, map_location='cpu')
            begin_epoch = checkpoint['epoch']
            # 'perf' is only the last epoch's score. Prefer the stored best,
            # falling back to 'perf' for checkpoints written before it existed.
            best_perf = checkpoint.get('best_perf', checkpoint['perf'])
            for model_name in models_dict:
                models_dict[model_name].load_state_dict(checkpoint['state_dict_' + model_name])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_file, checkpoint['epoch']))

        models_dict = _model_to_gpu(models_dict, cfg)
        optimizer = _make_optimizer(cfg, models_dict)

        if checkpoint is not None and 'optimizer' in checkpoint:
            # Restore Adam's running moments so resumed training continues
            # from the saved optimizer state instead of a fresh one.
            try:
                optimizer.load_state_dict(checkpoint['optimizer'])
            except ValueError as err:
                # e.g. the set of optimised models changed since the checkpoint
                logger.info('=> could not restore optimizer state ({}); '
                            'starting with a fresh optimizer'.format(err))

        for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):

            train(cfg, train_loader,
                  models_dict,
                  loss_dict,
                  optimizer, epoch,
                  final_output_dir, writer_dict)

            # evaluate on validation set
            perf_indicator = validate(
                cfg, valid_loader, valid_dataset,
                models_dict,
                loss_dict,
                final_output_dir, writer_dict
            )

            is_best_model = perf_indicator >= best_perf
            if is_best_model:
                best_perf = perf_indicator

            logger.info('=> saving checkpoint to {}'.format(final_output_dir))
            state_dicts = _model_state_dicts(models_dict)
            checkpoint_dict = {
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'perf': perf_indicator,
                'best_perf': best_perf,
                'optimizer': optimizer.state_dict(),
            }
            for model_name, state_dict in state_dicts.items():
                checkpoint_dict['state_dict_' + model_name] = state_dict

            torch.save(checkpoint_dict, checkpoint_file)
            if is_best_model:
                logger.info('=> saving best model state to {} at epoch {}'.format(final_output_dir, epoch))
                for model_name, state_dict in state_dicts.items():
                    torch.save(state_dict, os.path.join(final_output_dir, model_name + '_best.pth'))

        # Taken from the models directly (not the last checkpoint_dict), so
        # this also works when the loop above ran zero epochs.
        logger.info('=> saving final model state to {}'.format(final_output_dir))
        for model_name, state_dict in _model_state_dicts(models_dict).items():
            torch.save(state_dict, os.path.join(final_output_dir, model_name + '_final.pth'))
    finally:
        writer_dict['writer'].close()


if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()
