"""
This script trains the two-pathway model.
The model consists of one teacher network and one student network.

When the cache model is not used,
the teacher model can teach the student model iteratively.

When the cache model is used,
the teacher model can be trained with or without noise.
The noise is added to the intermediate layer of the feedback network's low layer,
and can be used to force the feedback network to take advantage of its
feedback connections.
"""

import logging
import os
from argparse import ArgumentParser
from collections import OrderedDict
import pdb
from functools import partial

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms

import sys
sys.path.append("../")
from mmcv import Config
from mmcv.runner import DistSamplerSeedHook, Runner
from dataset.image_transforms import Low_pass_filter, GaussianNoise
from mmcv.runner import load_state_dict
from dataset.cifar10_corrupted import CIFAR10_C_twopass
import models as network
from models import CacheMemory
from models import DistillKL

from losses.crd import CRDLoss

sys.path.append("../tools")
from utils import*
from update_cache_hook import UpdateCacheHook

# Whether a CUDA device is available; checked before moving labels and the
# model to the GPU.
cuda = bool(torch.cuda.is_available())

def parse_args():
    """Parse the command-line options for two-pathway CIFAR-10 training.

    Returns:
        argparse.Namespace holding the config path plus the model/checkpoint,
        seed, coarse-pathway, dataset-corruption, output-directory and
        distillation options, in the order they are registered below.
    """
    cli = ArgumentParser(description='Train CIFAR-10 classification')
    cli.add_argument('config', help='train config file path')

    # (flag, add_argument keyword arguments); registration order is kept so
    # the --help listing is unchanged.
    option_specs = (
        # model selection, optional checkpoint, RNG seed
        ('--model_name', dict(type=str, default="CORnet_TwopassB")),
        ('--load_from', dict(type=str, default=None)),
        ('--seed', dict(type=int, default=100)),
        # coarse pathway settings
        ('--coarse_low_pass_value', dict(type=float, default=3.0)),
        ('--coarse_kernel_type', dict(type=str, default="large")),
        # corruption applied to the dataset
        ('--corrupted_name', dict(type=str, default="clean")),
        ('--severity', dict(type=int, default=1, help="severity is 1 to 5")),
        ('--width', dict(type=float, default=0, help="this value is in [0,1]")),
        # output location
        ('--work_dir', dict(type=str, default="../twopass_res/rbm_cifar",
                            help="identity of two-pass model.")),
        # distillation weights and temperature
        ('--distill_fused', dict(type=float, default=0.)),
        ('--distill_ensem', dict(type=float, default=0.)),
        ('--distill_T', dict(type=float, default=5)),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()

def batch_processor(model, data, train_mode, args, distill_model=None, **kwargs):
    """Run one forward pass of the two-pathway model and compute its losses.

    Passed (via ``functools.partial``) to the mmcv ``Runner`` as its batch
    processor. ``train_mode`` is part of that callback signature and is not
    used here.

    Args:
        model: The two-pathway network. It is called with
            ``[fine_img, coarse_img]`` and must return
            ``(fused_pred, fine_pred, coarse_pred, feats)``.
        data: A ``(fine_img, coarse_img, label)`` batch from the data loader.
        train_mode: Unused; supplied by the runner.
        args: Parsed command-line arguments. ``args.distill_fused`` and
            ``args.distill_ensem`` weight the two distillation terms.
        distill_model: Distillation criterion, called as
            ``distill_model(student_logits, teacher_logits)``. Required when
            either distillation weight is positive.

    Returns:
        dict with ``loss`` (tensor to back-propagate), ``log_vars`` (python
        scalars for logging) and ``num_samples`` (batch size).

    Raises:
        ValueError: if a distillation weight is positive but no
            ``distill_model`` was given.
    """
    fine_img, coarse_img, label = data
    label = label.long()
    if cuda:
        label = label.cuda(non_blocking=True)

    pred, fine_pred, coarse_pred, _ = model(
        [fine_img, coarse_img],
        fine_training=True,
        coarse_training=True,
        pre_feats=True,
    )

    # Average of the two pathways' logits.
    ensemble_pred = (fine_pred + coarse_pred) / 2

    # Every head (ensemble, fine, fused, coarse) is supervised by the labels.
    cross_loss = (
        F.cross_entropy(ensemble_pred, label)
        + F.cross_entropy(fine_pred, label)
        + F.cross_entropy(pred, label)
        + F.cross_entropy(coarse_pred, label)
    )

    if (args.distill_fused > 0 or args.distill_ensem > 0) and distill_model is None:
        raise ValueError(
            "distill_model is required when distill_fused or distill_ensem is positive")

    # Zero placeholder on the predictions' device/dtype, so disabled terms
    # work without CUDA (a hard-coded "cuda" device fails on CPU-only runs).
    zero = pred.new_zeros(())

    if args.distill_fused > 0:
        # The fused head learns from the (detached) ensemble of the pathways.
        fused_distill_loss = distill_model(pred, ensemble_pred.detach())
    else:
        fused_distill_loss = zero

    if args.distill_ensem > 0:
        # Each pathway learns from the (detached) fused prediction.
        teacher = pred.detach()
        ensem_distill_loss = (distill_model(fine_pred, teacher)
                              + distill_model(coarse_pred, teacher))
    else:
        ensem_distill_loss = zero

    distill_loss = args.distill_ensem * ensem_distill_loss + args.distill_fused * fused_distill_loss
    loss = cross_loss + distill_loss

    fine_acc_top1, _ = accuracy(fine_pred, label, topk=(1, 5))
    coarse_acc_top1, _ = accuracy(coarse_pred, label, topk=(1, 5))
    # NOTE(review): logged as "ensemble" but measured on the fused ``pred``,
    # not on ``ensemble_pred`` — confirm this is intended.
    ensemble_acc_top1, _ = accuracy(pred, label, topk=(1, 5))

    log_vars = OrderedDict()
    log_vars['loss'] = loss.item()
    log_vars['cross_loss'] = cross_loss.item()
    log_vars['distill_loss'] = distill_loss.item()
    log_vars['fine_acc_top1'] = fine_acc_top1.item()
    log_vars['coarse_acc_top1'] = coarse_acc_top1.item()
    log_vars['ensemble_acc_top1'] = ensemble_acc_top1.item()

    return dict(loss=loss, log_vars=log_vars, num_samples=fine_img.size(0))

def _build_transforms(args, mean, std):
    """Build the image transforms for both pathways.

    Args:
        args: Parsed CLI arguments; ``width`` (passed to ``GaussianNoise``)
            and ``coarse_low_pass_value`` are read.
        mean, std: Per-channel RGB normalisation statistics.

    Returns:
        ``(fine_train, fine_val, coarse_train, coarse_val, coarse_in_channels)``.
        When ``coarse_low_pass_value`` is non-zero, the coarse input goes
        through ``Low_pass_filter(..., n_channels=1)`` and is normalised as a
        single channel (``coarse_in_channels == 1``). Otherwise the coarse
        pathway uses the same RGB pipeline as the fine one
        (``coarse_in_channels == 3``).
    """
    # NOTE(review): the fine and coarse pipelines each draw their own random
    # crop/flip; whether the two views stay aligned depends on how
    # CIFAR10_C_twopass applies them — confirm.
    fine_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        GaussianNoise(args.width),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    fine_val = transforms.Compose([
        GaussianNoise(args.width),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    if args.coarse_low_pass_value != 0:
        coarse_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            GaussianNoise(args.width),
            Low_pass_filter(args.coarse_low_pass_value, n_channels=1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.2]),
        ])
        coarse_val = transforms.Compose([
            GaussianNoise(args.width),
            Low_pass_filter(args.coarse_low_pass_value, n_channels=1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.2]),
        ])
        return fine_train, fine_val, coarse_train, coarse_val, 1

    coarse_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        GaussianNoise(args.width),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    coarse_val = transforms.Compose([
        GaussianNoise(args.width),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    return fine_train, fine_val, coarse_train, coarse_val, 3


def _save_val_accuracies(runner, args):
    """Write the averaged fine/coarse/ensemble top-1 accuracies to text files.

    Files are written to
    ``save_ensemble/<last component of work_dir>/T_<corruption>_<level>_<branch>.txt``.
    """
    import numpy as np  # explicit, instead of relying on `from utils import *`

    runner.log_buffer.average()

    # The custom "gaussian_noise2" corruption is parameterised by --width
    # rather than --severity, so that value names the output files.
    if args.corrupted_name in ["gaussian_noise2"]:
        level = args.width
    else:
        level = args.severity

    # normpath drops a trailing separator, so "runs/exp/" still yields "exp"
    # instead of an empty name.
    run_name = os.path.basename(os.path.normpath(args.work_dir))
    txt_dir = os.path.join("save_ensemble", run_name)
    os.makedirs(txt_dir, exist_ok=True)

    for branch in ("fine", "coarse", "ensemble"):
        filename = "T_{}_{}_{}.txt".format(args.corrupted_name, level, branch)
        np.savetxt(os.path.join(txt_dir, filename),
                   [runner.log_buffer.output[branch + '_acc_top1']])


def main():
    """Train and evaluate the two-pathway model, then save its accuracies.

    The function:

    1. Reads the mmcv config named on the command line and overrides it with
       the CLI options.
    2. Builds the CIFAR10_C_twopass train/val loaders and the model, and
       optionally restores "finenet" weights from ``--load_from``.
    3. Runs the configured workflow with an mmcv ``Runner`` and then runs one
       more validation pass.
    4. Writes the fine/coarse/ensemble top-1 accuracies to text files.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    set_seed(args.seed)

    # Command-line options override the config file.
    cfg.model = args.model_name
    cfg.work_dir = args.work_dir
    cfg.model_dict.coarse_model_dict["kernel_type"] = args.coarse_kernel_type

    # NOTE(review): the returned logger is never used; the call is kept in
    # case get_logger configures logging as a side effect — confirm.
    get_logger(cfg.log_level)

    # RGB normalisation statistics (the values commonly used for CIFAR-10).
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]

    (fine_train_tf, fine_val_tf, coarse_train_tf, coarse_val_tf,
     coarse_in_channels) = _build_transforms(args, mean, std)
    cfg.model_dict.coarse_model_dict["in_channels"] = coarse_in_channels

    train_dataset = CIFAR10_C_twopass(
        corrupted_name=args.corrupted_name,
        severity=args.severity,
        train=True,
        coarse_transform=coarse_train_tf,
        fine_transform=fine_train_tf)
    val_dataset = CIFAR10_C_twopass(
        corrupted_name=args.corrupted_name,
        severity=args.severity,
        train=False,
        coarse_transform=coarse_val_tf,
        fine_transform=fine_val_tf)

    num_workers = cfg.data_workers * len(cfg.gpus)
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=num_workers)
    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=num_workers)

    model = getattr(network, cfg.model)(**cfg.model_dict)

    if args.load_from is not None:
        # The model is still on the CPU here; map_location lets GPU-saved
        # checkpoints load on CPU-only machines as well.
        checkpoint = torch.load(args.load_from, map_location="cpu")
        # NOTE(review): filter_dict presumably keeps only "finenet" entries —
        # confirm; strict=False tolerates keys the filtered dict lacks.
        model_state_dict = filter_dict(checkpoint["state_dict"], "finenet")
        load_state_dict(model, model_state_dict, strict=False, logger=None)
        print("model has been loaded")

    model = DataParallel(model, device_ids=cfg.gpus)
    if cuda:
        model = model.cuda()

    distill_model = DistillKL(args.distill_T)
    batch_processor1 = partial(batch_processor, args=args, distill_model=distill_model)

    runner = Runner(
        model,
        batch_processor1,
        cfg.optimizer,
        cfg.work_dir,
        log_level=cfg.log_level)

    runner.register_training_hooks(
        lr_config=cfg.lr_config,
        optimizer_config=cfg.optimizer_config,
        checkpoint_config=cfg.checkpoint_config,
        log_config=cfg.log_config)

    runner.run([train_loader, val_loader], cfg.workflow, cfg.total_epochs)
    # Final evaluation whose averaged metrics are written out below.
    runner.val(val_loader)

    _save_val_accuracies(runner, args)

if __name__ == "__main__":  # only train when executed as a script
    main()
