"""
This script is used to train the two-passway model.
the model is consist of one teacher network and one student network.

when not using cache model,
the teacher model can teaches the student model iteratively.

when using cache model,
the teacher model can be trained with or without noise.
the noise is added to the intermedia layer of the feedback low layer,
and can be used to force the feedback network to take the advantage of
feedback connections.
"""

import logging
import os
from argparse import ArgumentParser
from collections import OrderedDict
import pdb
from functools import partial

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms

import sys
sys.path.append("../")
from mmcv import Config
from mmcv.runner import DistSamplerSeedHook, Runner
from dataset.image_transforms import Low_pass_filter, GaussianNoise
from mmcv.runner import load_state_dict
from dataset.cifar10_corrupted import CIFAR10_C_twopass
import models as network
from models import CacheMemory
from models import DistillKL

sys.path.append("../tools")
from utils import*
from update_cache_hook import UpdateCacheHook

# Whether a CUDA device is available; checked before moving tensors/models to the GPU.
cuda = torch.cuda.is_available()

def parse_args():
    """Declare the command-line interface of this script and parse ``sys.argv``.

    The options are kept in a single ordered table so the groups (fine model,
    coarse model, dataset noise, cache memory, distillation, cache update)
    can be read at a glance; they are registered in exactly that order.

    Returns:
        argparse.Namespace: one attribute per option listed in the table.
    """
    option_table = (
        ('config', dict(help='train config file path')),

        ('--model_name', dict(type=str, default="CORnet_Twopass")),
        ('--load_from', dict(type=str, default=None)),
        ('--seed', dict(type=int, default=100)),

        # fine model dict
        ('--fine_times', dict(type=int, default=1)),
        ('--fine_top_layer', dict(type=str, default="IT")),
        ('--fine_low_layer', dict(type=str, default="V1")),
        ('--feedback_time', dict(type=int, default=0)),

        # coarse model dict
        ('--coarse_low_pass_value', dict(type=float, default=3.0)),
        ('--coarse_kernel_type', dict(type=str, default="large")),

        # noise type of dataset
        ('--corrupted_name', dict(type=str, default="clean")),
        ('--severity', dict(type=int, default=1, help="severity is 1 to 5")),
        ('--width', dict(type=float, default=0, help="this value is in [0,1]")),

        ('--work_dir', dict(type=str, default="../twopass_res/rbm_cifar",
                            help="identity of two-pass model.")),

        # cache memory
        ('--cache_T', dict(type=float, default=0.01)),
        ('--cache_mode', dict(type=str, default="average", help="average or max")),
        ('--using_cache', dict(action="store_true", default=False,
                               help="using cache or not")),

        ('--distill_ratio', dict(type=float, default=0.4)),
        ('--distill_T', dict(type=float, default=5)),

        # update model
        ('--update_cache_interval', dict(type=int, default=2)),
        ('--add_noise', dict(action="store_true", default=False)),
        ('--noise_ratio', dict(type=float, default=0.0)),
    )

    cli = ArgumentParser(description='Train CIFAR-10 classification')
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

def batch_processor(model, data, train_mode, args, cache_labels=None):
    """Run one batch through the two-pass model and collect top-1 accuracies.

    Args:
        model: wrapped two-pass network. It is called with
            ``[fine_img, coarse_img]`` and must return
            ``(fine_pred, coarse_pred, (fine_hids, coarse_hids, cache_pred))``.
            ``model.module.finenet`` is additionally used when
            ``args.corrupted_name == 'adversarial_noise'``.
        data: ``(fine_img, coarse_img, label)`` batch.
        train_mode: not used by this function; accepted because the runner
            passes it.
        args: parsed CLI namespace; ``feedback_time``, ``using_cache`` and
            ``corrupted_name`` are read.
        cache_labels: labels of the cached samples. Required when
            ``args.using_cache`` is set.

    Returns:
        dict: ``log_vars`` (OrderedDict of accuracy scalars) and
        ``num_samples`` (batch size taken from ``fine_img``).

    Raises:
        ValueError: if ``args.using_cache`` is set but ``cache_labels`` is None.
    """
    if args.using_cache and cache_labels is None:
        raise ValueError("cache_labels must be provided when args.using_cache is set")

    fine_img, coarse_img, label = data
    label = label.long()
    if cuda:
        label = label.cuda(non_blocking=True)
    # Everything moved by hand below follows the label's device, so the
    # function also works when CUDA is unavailable (no unconditional .cuda()).
    device = label.device

    fine_pred, coarse_pred, (_, _, cache_pred) = model(
        [fine_img, coarse_img],
        fine_training=True,
        coarse_training=True,
        feedback_time=args.feedback_time,
        pre_feats=True)

    log_vars = OrderedDict()

    fine_acc_top1, _ = accuracy(fine_pred, label, topk=(1, 5))
    coarse_acc_top1, _ = accuracy(coarse_pred, label, topk=(1, 5))
    log_vars['fine_acc_top1'] = fine_acc_top1.item()
    log_vars['coarse_acc_top1'] = coarse_acc_top1.item()

    if args.corrupted_name == 'adversarial_noise':
        # Accuracy of the fine network evaluated on its own.
        fine_solo_pred, _ = model.module.finenet(fine_img.to(device))
        finesolo_acc_top1, _ = accuracy(fine_solo_pred, label, topk=(1, 5))
        log_vars['finesolo_acc_top1'] = finesolo_acc_top1.item()

    if args.using_cache:
        # One-hot encode the cached labels over 10 classes (CIFAR-10).
        # NOTE(review): cache_pred presumably holds per-sample weights over the
        # cached entries, so the product yields class scores - confirm.
        cache_onehot = one_hot(np.array(cache_labels), 10).to(device)
        cache_scores = torch.matmul(cache_pred, cache_onehot)
        coarse_cache_acc_top1, _ = accuracy(cache_scores, label, topk=(1, 5))
        log_vars['coarse_cache_acc_top1'] = coarse_cache_acc_top1.item()

    return dict(log_vars=log_vars, num_samples=fine_img.size(0))

def main():
    """Evaluate a trained two-pass checkpoint and save its accuracies as text.

    Steps: parse the CLI, override config entries from it, build the cache
    dataset (clean, ``train=True``) and the validation dataset, load the
    checkpoint given by ``--load_from``, run ``runner.val`` once, and write the
    averaged top-1 accuracies under ``save/<basename of work_dir>/``.

    Raises:
        ValueError: if ``--load_from`` was not given.
    """
    args = parse_args()
    if args.load_from is None:
        # Without this check torch.load(None) fails later with an obscure error.
        raise ValueError("--load_from is required: path of the checkpoint to evaluate")

    cfg = Config.fromfile(args.config)
    set_seed(args.seed)

    cfg.model = args.model_name
    cfg.work_dir = args.work_dir

    # fine model overrides
    fine_cfg = cfg.model_dict.fine_model_dict
    fine_cfg["times"] = args.fine_times
    fine_cfg["top_layer"] = args.fine_top_layer
    fine_cfg["low_layer"] = args.fine_low_layer

    # coarse model overrides
    coarse_cfg = cfg.model_dict.coarse_model_dict
    coarse_cfg["kernel_type"] = args.coarse_kernel_type

    # cache model overrides
    cache_cfg = cfg.model_dict.cache_model_dict
    cache_cfg["mode"] = args.cache_mode
    cache_cfg["T"] = args.cache_T

    # Return value is not needed here.
    # NOTE(review): presumably called for its logging setup - confirm.
    get_logger(cfg.log_level)

    if args.using_cache:
        # These arrays must already exist in work_dir; they are only read here.
        cfg.model_dict.source_data = np.load(os.path.join(args.work_dir, "source_data.npy"))
        cfg.model_dict.target_data = np.load(os.path.join(args.work_dir, "target_data.npy"))

    # build datasets and dataloaders
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]

    num_workers = cfg.data_workers * len(cfg.gpus)
    batch_size = cfg.batch_size

    fine_transform = transforms.Compose([
        GaussianNoise(args.width),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    if args.coarse_low_pass_value != 0:
        # Coarse pathway sees a low-pass filtered, single-channel image.
        coarse_transform = transforms.Compose([
            GaussianNoise(args.width),
            Low_pass_filter(args.coarse_low_pass_value, n_channels=1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, ], std=[0.2, ]),
        ])
        coarse_cfg["in_channels"] = 1
    else:
        # No low-pass filtering: both pathways share the same 3-channel pipeline.
        coarse_transform = fine_transform
        coarse_cfg["in_channels"] = 3

    cache_dataset = CIFAR10_C_twopass(
        corrupted_name="clean",
        severity=args.severity,
        train=True,
        coarse_transform=coarse_transform,
        fine_transform=fine_transform)

    val_dataset = CIFAR10_C_twopass(
        corrupted_name=args.corrupted_name,
        severity=args.severity,
        train=False,
        work_dir=args.work_dir,
        coarse_transform=coarse_transform,
        fine_transform=fine_transform)

    cache_loader = DataLoader(
        cache_dataset,
        batch_size=1024,
        shuffle=False,
        num_workers=num_workers)

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers)

    # build model
    cfg.model_dict.dataloader = cache_loader
    cfg.model_dict.using_cache = args.using_cache

    model = getattr(network, cfg.model)(**cfg.model_dict)

    # Load onto the CPU first so a checkpoint saved on a GPU can also be
    # restored on a host without CUDA; the model is moved to the GPU below.
    checkpoint = torch.load(args.load_from, map_location="cpu")
    load_state_dict(model, checkpoint["state_dict"], strict=True, logger=None)

    print("model has been loaded")

    model = DataParallel(model, device_ids=cfg.gpus)
    if cuda:
        model = model.cuda()

    # build runner and register hooks
    # NOTE(review): cache_dataset is built with train=True but its
    # `test_labels` attribute is read - confirm against CIFAR10_C_twopass.
    bound_processor = partial(
        batch_processor, args=args, cache_labels=cache_dataset.test_labels)

    runner = Runner(
        model,
        bound_processor,
        cfg.optimizer,
        cfg.work_dir,
        log_level=cfg.log_level)
    runner.register_training_hooks(
        lr_config=cfg.lr_config,
        optimizer_config=cfg.optimizer_config,
        checkpoint_config=cfg.checkpoint_config,
        log_config=cfg.log_config)

    runner.val(val_loader)
    runner.log_buffer.average()

    # Value used in the output file names to identify the noise level.
    if args.corrupted_name == "gaussian_noise2":
        noise_strength = args.width  # custom noise type parameterised by width
    elif args.corrupted_name == "clean":
        noise_strength = 0
    else:
        noise_strength = args.severity

    # normpath drops a trailing slash so the run name is never empty.
    run_name = os.path.basename(os.path.normpath(args.work_dir))
    txt_dir = os.path.join("save", run_name)
    os.makedirs(txt_dir, exist_ok=True)

    prefix = "T_{}_{}_{}".format(args.cache_T, args.corrupted_name, noise_strength)
    results = [("fine", 'fine_acc_top1'), ("coarse", 'coarse_acc_top1')]
    if args.corrupted_name == 'adversarial_noise':
        results.append(("finesolo", 'finesolo_acc_top1'))
    if args.using_cache:
        results.append(("cache", 'coarse_cache_acc_top1'))

    for suffix, key in results:
        out_path = os.path.join(txt_dir, "{}_{}.txt".format(prefix, suffix))
        np.savetxt(out_path, [runner.log_buffer.output[key]])

# Script entry point: run main() only when executed directly, not on import.
if __name__ == '__main__':
    main()
