"""
This script is used to train the two-passway model.
the model is consist of one teacher network and one student network.

when not using cache model,
the teacher is trained independently.
the teacher model can teaches the student model iteratively.

when using cache model,
the teacher is trianed under the effect of

"""

import logging
import os
from argparse import ArgumentParser
from collections import OrderedDict
import pdb
from functools import partial

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms

import sys
sys.path.append("../")
from mmcv import Config
from mmcv.runner import DistSamplerSeedHook, Runner
from dataset.image_transforms import Low_pass_filter, GaussianNoise
from mmcv.runner import load_state_dict
from dataset.cifar10_corrupted import CIFAR10_C_twopass
import models as network
from models import CacheMemory
from models import DistillKL

from losses.crd import CRDLoss

sys.path.append("./tools")
from utils import*
from update_cache_hook import UpdateCacheHook

# True when a CUDA device is visible. Read by batch_processor (to move labels
# to the GPU) and by main (to move the DataParallel model to the GPU).
cuda: bool = torch.cuda.is_available()

def parse_args(argv=None):
    """Parse the command-line options of the two-pass training script.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]``; useful for tests and programmatic calls.

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """
    parser = ArgumentParser(description='Train CIFAR-10 classification')
    parser.add_argument('config', help='train config file path')

    parser.add_argument('--model_name', type=str, default="CORnet_TwopassA")
    parser.add_argument('--load_from', type=str, default=None)
    parser.add_argument('--seed', type=int, default=100)

    # ---- fine model options
    parser.add_argument('--fine_times', type=int, default=2)
    parser.add_argument('--fine_top_layer', type=str, default="IT")
    parser.add_argument('--fine_low_layer', type=str, default="V2")
    parser.add_argument('--feedback_mode', type=str, default="upsample_pconv_gate")
    parser.add_argument('--feedback_time', type=int, default=0)

    # ---- coarse model options
    parser.add_argument('--coarse_low_pass_value', type=float, default=2.0)
    parser.add_argument('--coarse_kernel_type', type=str, default="large")

    # ---- corruption applied to the dataset
    parser.add_argument('--corrupted_name', type=str, default="clean")
    parser.add_argument('--severity', type=int, default=1, help="severity is 1 to 5")
    parser.add_argument('--width', type=float, default=0, help="this value is in [0,1]")

    parser.add_argument('--work_dir', type=str, default="../twopass_res/cifar",
                        help="identity of two-pass model.")

    # ---- cache memory
    parser.add_argument('--cache_T', type=float, default=0.01)
    parser.add_argument('--cache_mode', type=str, default="average", help="average or max")
    parser.add_argument('--using_cache', action="store_true", default=False,
                        help="using cache or not")

    # ---- distillation from the fine (teacher) to the coarse (student) pathway
    parser.add_argument('--distill_ratio', type=float, default=0.4)
    parser.add_argument('--distill_T', type=float, default=3)
    # Only the modes handled by batch_processor are accepted, so a typo fails
    # at start-up instead of on the first training batch.
    parser.add_argument('--distill_mode', type=str, default="kn_a",
                        choices=("kn_a", "kn_b"),
                        help="kn_a or kn_b (crd is not implemented)")

    # ---- cache update schedule
    parser.add_argument('--update_cache_interval', type=int, default=2)
    parser.add_argument('--start_update', type=int, default=0,
                        help="which epoch start to update the cache memory")
    parser.add_argument('--add_noise', action="store_true", default=False)
    parser.add_argument('--noise_ratio', type=float, default=0.)

    return parser.parse_args(argv)

def batch_processor(model, data, train_mode, args, distill_model=None, cache_labels=None, **kwargs):
    """Run one forward pass of the two-pass model and compute its losses.

    Used as the mmcv ``Runner`` batch processor (``args``, ``distill_model``
    and ``cache_labels`` are bound with ``functools.partial`` in ``main``).

    Args:
        model: the two-pass network wrapped in ``DataParallel``
            (``model.module.update_flag`` is read when the cache is enabled).
        data: ``(fine_img, coarse_img, label)`` batch.
        train_mode: unused; part of the mmcv batch-processor signature.
        args: parsed command-line options (see ``parse_args``).
        distill_model: callable ``(student_logits, teacher_logits) -> loss``,
            required when ``args.distill_mode == "kn_a"``.
        cache_labels: labels of the samples in the cache memory; only used
            when ``args.using_cache`` is set and the cache is active.

    Returns:
        dict with ``loss`` (tensor to back-propagate), ``log_vars`` (scalar
        metrics) and ``num_samples`` (batch size).

    Raises:
        ValueError: if ``args.distill_mode`` is neither ``"kn_a"`` nor ``"kn_b"``.
    """
    fine_img, coarse_img, label = data
    label = label.long()
    if cuda:
        label = label.cuda(non_blocking=True)

    # With a pretrained checkpoint (offline distillation) the fine/teacher
    # pathway is kept fixed and only the coarse/student pathway is trained.
    fine_training = args.load_from is None

    fine_pred, coarse_pred, (fine_hids, coarse_hids, cache_pred) = model(
        [fine_img, coarse_img],
        fine_training=fine_training,
        coarse_training=True,
        feedback_time=args.feedback_time,
        pre_feats=True)

    coarse_cross_loss = F.cross_entropy(coarse_pred, label)
    if args.distill_mode == "kn_a":
        # Logit distillation: cross-entropy and distillation terms are mixed
        # convexly with weights (1 - ratio) and ratio.
        coarse_distill_loss = distill_model(coarse_pred, fine_pred.detach())
        cross_weight = 1 - args.distill_ratio
    elif args.distill_mode == "kn_b":
        # Feature distillation. fine_hids is noted as
        # [it_out1, out_pool1, pred1, it_out2, out_pool2], so [-4] is out_pool1.
        fine_feat = fine_hids[-4]
        coarse_feat = coarse_hids[-1]
        coarse_distill_loss = F.mse_loss(coarse_feat, fine_feat.detach(), reduction='mean')
        # The cross-entropy term keeps full weight in this mode.
        cross_weight = 1.0
    else:
        raise ValueError("no such distill mode !")
    coarse_loss = cross_weight * coarse_cross_loss + args.distill_ratio * coarse_distill_loss

    cache_active = args.using_cache and model.module.update_flag
    if cache_active:
        # Turn the cache similarities into class scores by summing them per
        # class of the cached samples. The one-hot matrix follows cache_pred's
        # device/dtype, so CPU-only runs work as well.
        num_classes = fine_pred.size(-1)
        cache_onehot = one_hot(np.array(cache_labels), num_classes)
        cache_onehot = cache_onehot.to(device=cache_pred.device, dtype=cache_pred.dtype)
        cache_scores = torch.matmul(cache_pred, cache_onehot)
        coarse_cache_acc_top1, _ = accuracy(cache_scores, label, topk=(1, 5))

    fine_loss = F.cross_entropy(fine_pred, label)
    loss = fine_loss + coarse_loss

    fine_acc_top1, _ = accuracy(fine_pred, label, topk=(1, 5))
    coarse_acc_top1, _ = accuracy(coarse_pred, label, topk=(1, 5))

    log_vars = OrderedDict()
    log_vars['loss'] = loss.item()
    # Log the weighted terms exactly as they are summed into coarse_loss.
    log_vars['coarse_cross_loss'] = coarse_cross_loss.item() * cross_weight
    log_vars['coarse_distill_loss'] = coarse_distill_loss.item() * args.distill_ratio
    log_vars['fine_loss'] = fine_loss.item()
    log_vars['fine_acc_top1'] = fine_acc_top1.item()
    log_vars['coarse_acc_top1'] = coarse_acc_top1.item()
    if cache_active:
        log_vars['coarse_cache_acc_top1'] = coarse_cache_acc_top1.item()

    return dict(loss=loss, log_vars=log_vars, num_samples=fine_img.size(0))

def _build_transforms(args, mean, std):
    """Build the fine/coarse train/val image transforms.

    Every pipeline applies ``GaussianNoise(args.width)``. The train pipelines
    also use a random 32x32 crop (padding 4) and a horizontal flip. The fine
    pathway is normalized with ``mean``/``std``. When
    ``args.coarse_low_pass_value`` is non-zero, the coarse pathway applies
    ``Low_pass_filter(..., n_channels=1)`` and a one-channel normalization.
    Otherwise it mirrors the fine pipelines.

    Returns:
        ``(fine_train, fine_val, coarse_train, coarse_val, coarse_in_channels)``,
        where the last item is the number of input channels (1 or 3) the
        coarse network must accept.
    """
    fine_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        GaussianNoise(args.width),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    fine_val = transforms.Compose([
        GaussianNoise(args.width),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    if args.coarse_low_pass_value != 0:
        coarse_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            GaussianNoise(args.width),
            Low_pass_filter(args.coarse_low_pass_value, n_channels=1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, ], std=[0.2, ]),
        ])
        coarse_val = transforms.Compose([
            GaussianNoise(args.width),
            Low_pass_filter(args.coarse_low_pass_value, n_channels=1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, ], std=[0.2, ]),
        ])
        coarse_in_channels = 1
    else:
        coarse_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            GaussianNoise(args.width),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        coarse_val = transforms.Compose([
            GaussianNoise(args.width),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        coarse_in_channels = 3

    return fine_train, fine_val, coarse_train, coarse_val, coarse_in_channels


def _save_results(runner, args):
    """Write the runner's averaged accuracies to ``save_distil/<work_dir name>/``."""
    # gaussian_noise2 is parameterized by --width rather than --severity, so the
    # width is used as the severity tag in the file names.
    if args.corrupted_name == "gaussian_noise2":
        severity_tag = args.width
    else:
        severity_tag = args.severity

    # normpath drops a trailing slash, which would otherwise give an empty name.
    run_name = os.path.basename(os.path.normpath(args.work_dir))
    txt_dir = os.path.join("save_distil", run_name)
    os.makedirs(txt_dir, exist_ok=True)

    prefix = "T_{}_{}_{}".format(args.cache_T, args.corrupted_name, severity_tag)
    outputs = runner.log_buffer.output
    np.savetxt(os.path.join(txt_dir, prefix + "_fine.txt"), [outputs['fine_acc_top1']])
    np.savetxt(os.path.join(txt_dir, prefix + "_coarse.txt"), [outputs['coarse_acc_top1']])
    if args.using_cache:
        np.savetxt(os.path.join(txt_dir, prefix + "_cache.txt"),
                   [outputs['coarse_cache_acc_top1']])


def main():
    """Train and evaluate the two-pass model, then save the final accuracies.

    Steps:
      1. Parse the options and let them override the mmcv config.
      2. Build the train, cache and validation datasets.
      3. Build the model, optionally warm-starting its finenet from
         ``--load_from``.
      4. Run the mmcv ``Runner`` workflow plus a final validation pass.
      5. Write the averaged accuracies to text files.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    set_seed(args.seed)

    # Command-line options override the corresponding config values.
    cfg.model = args.model_name
    cfg.work_dir = args.work_dir

    cfg.model_dict.fine_model_dict["times"] = args.fine_times
    cfg.model_dict.fine_model_dict["top_layer"] = args.fine_top_layer
    cfg.model_dict.fine_model_dict["low_layer"] = args.fine_low_layer
    cfg.model_dict.fine_model_dict["feedback_mode"] = args.feedback_mode

    cfg.model_dict.coarse_model_dict["kernel_type"] = args.coarse_kernel_type

    cfg.model_dict.cache_model_dict["mode"] = args.cache_mode
    cfg.model_dict.cache_model_dict["T"] = args.cache_T

    # The returned logger is not used here. The call is kept in case
    # get_logger configures logging as a side effect (TODO confirm).
    get_logger(cfg.log_level)

    # ---- datasets and data loaders
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]
    num_workers = cfg.data_workers * len(cfg.gpus)
    batch_size = cfg.batch_size

    (fine_train_transform, fine_val_transform,
     coarse_train_transform, coarse_val_transform,
     coarse_in_channels) = _build_transforms(args, mean, std)
    cfg.model_dict.coarse_model_dict["in_channels"] = coarse_in_channels

    train_dataset = CIFAR10_C_twopass(
        corrupted_name=args.corrupted_name,
        severity=args.severity,
        train=True,
        coarse_transform=coarse_train_transform,
        fine_transform=fine_train_transform)

    # Training images with the deterministic (val) transforms, fed to the
    # cache memory.
    cache_dataset = CIFAR10_C_twopass(
        corrupted_name=args.corrupted_name,
        severity=args.severity,
        train=True,
        coarse_transform=coarse_val_transform,
        fine_transform=fine_val_transform)

    val_dataset = CIFAR10_C_twopass(
        corrupted_name=args.corrupted_name,
        severity=args.severity,
        train=False,
        coarse_transform=coarse_val_transform,
        fine_transform=fine_val_transform)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers)

    # Unshuffled, presumably so the cache order matches
    # cache_dataset.test_labels (passed to batch_processor below) — verify.
    cache_loader = DataLoader(
        cache_dataset,
        batch_size=1024,
        shuffle=False,
        num_workers=num_workers)

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers)

    print("cache dataset shape is, ", cache_dataset.test_data.shape)
    print("train_dataset dataset shape is, ", train_dataset.test_data.shape)
    print("val_loader dataset shape is, ", val_dataset.test_data.shape)

    # ---- model
    cfg.model_dict.dataloader = cache_loader
    cfg.model_dict.using_cache = args.using_cache

    model = getattr(network, cfg.model)(**cfg.model_dict)

    if args.load_from is not None:
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        # The weights are copied into the (still CPU-resident) model anyway.
        checkpoint = torch.load(args.load_from, map_location="cpu")
        model_state_dict = filter_dict(checkpoint["state_dict"], "finenet")
        load_state_dict(model, model_state_dict, strict=False, logger=None)
        print("model has been loaded")

    model = DataParallel(model, device_ids=cfg.gpus)
    if cuda:
        model = model.cuda()

    # ---- training
    distill_model = DistillKL(args.distill_T)
    batch_processor1 = partial(batch_processor, args=args, distill_model=distill_model,
                               cache_labels=cache_dataset.test_labels)

    runner = Runner(
        model,
        batch_processor1,
        cfg.optimizer,
        cfg.work_dir,
        log_level=cfg.log_level)
    runner.register_training_hooks(
        lr_config=cfg.lr_config,
        optimizer_config=cfg.optimizer_config,
        checkpoint_config=cfg.checkpoint_config,
        log_config=cfg.log_config)

    if args.using_cache:
        runner.register_hook(UpdateCacheHook(args.update_cache_interval, args.start_update))

    runner.run([train_loader, val_loader], cfg.workflow, cfg.total_epochs)

    # ---- final evaluation and result files
    runner.val(val_loader)
    runner.log_buffer.average()
    _save_results(runner, args)

# Script entry point: training runs only when this file is executed directly,
# not when it is imported.
if __name__ == '__main__':
    main()
