import logging
import os
from argparse import ArgumentParser
from collections import OrderedDict
import pdb
from functools import partial


import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms

import sys
sys.path.append("../")
from mmcv import Config
from mmcv.runner import DistSamplerSeedHook, Runner
from dataset.image_transforms import Low_pass_filter

from dataset.cifar10_corrupted import CIFAR10_C_twopass
from utils import*
import models as network

# Evaluated once at import time. batch_processor uses it to decide whether to
# move labels to the GPU, and main uses it to decide whether to move the model.
cuda = torch.cuda.is_available()

def parse_args():
    """Read the evaluation options from the command line.

    Returns:
        argparse.Namespace holding the positional ``config`` path plus the
        optional overrides ``model_name``, ``work_dir``, ``load_from``,
        ``corrupted_name``, ``low_pass_value``, ``times``, ``interact_times``
        and ``severity``.
    """
    cli = ArgumentParser(description='Train CIFAR-10 classification')
    cli.add_argument('config', help='train config file path')

    # (flag, value type, default) for the optional overrides, in registration order.
    optional_flags = (
        ('--model_name', str, "VVSNet"),
        ('--work_dir', str, "../res/rbm_cifar"),
        ('--load_from', str, None),
        ('--corrupted_name', str, "clean"),
        ('--low_pass_value', float, 0),
        ('--times', int, 1),
        ('--interact_times', int, 0),
    )
    for flag, value_type, default_value in optional_flags:
        cli.add_argument(flag, type=value_type, default=default_value)

    # Severity is the only option that carries help text.
    cli.add_argument('--severity', type=int, default=1, help="severity is 1 to 5")
    return cli.parse_args()

def batch_processor(model, data, train_mode, args):
    """Run one two-pathway forward pass and compute losses and accuracies.

    ``data`` is unpacked as ``(fine_img, coarse_img, label)``. The model is
    called with the ``(fine_img, coarse_img)`` tuple and its result is unpacked
    into ``(pred_fine, pred_coarse)``. The returned loss is the sum of the two
    cross-entropy terms. ``train_mode`` and ``args`` are accepted to fit the
    calling convention but are not read here.

    Returns:
        dict with ``loss`` (the summed loss tensor), ``log_vars`` (an
        OrderedDict of ``.item()`` scalars) and ``num_samples``
        (``fine_img.size(0)``).
    """
    fine_img, coarse_img, raw_label = data
    target = raw_label.long()
    if cuda:
        target = target.cuda(non_blocking=True)

    pred_fine, pred_coarse = model((fine_img, coarse_img))

    fine_ce = F.cross_entropy(pred_fine, target)
    fine_top1, fine_top5 = accuracy(pred_fine, target, topk=(1, 5))
    coarse_ce = F.cross_entropy(pred_coarse, target)
    coarse_top1, coarse_top5 = accuracy(pred_coarse, target, topk=(1, 5))

    total_loss = fine_ce + coarse_ce

    # Fixed key order: total loss, per-branch losses, then coarse and fine accuracies.
    named_scalars = (
        ('loss', total_loss),
        ('loss_fine', fine_ce),
        ('loss_coarse', coarse_ce),
        ('coarse_acc_top1', coarse_top1),
        ('coarse_acc_top5', coarse_top5),
        ('fine_acc_top1', fine_top1),
        ('fine_acc_top5', fine_top5),
    )
    log_vars = OrderedDict((key, value.item()) for key, value in named_scalars)

    return dict(loss=total_loss, log_vars=log_vars, num_samples=fine_img.size(0))

def _build_coarse_transform(low_pass_value, mean, std):
    """Return ``(transform, in_channels)`` for the coarse input pathway.

    A non-zero ``low_pass_value`` prepends ``Low_pass_filter`` (configured
    with ``n_channels=1``) and normalizes a single channel. Otherwise the
    coarse image gets the same 3-channel ``mean``/``std`` normalization as
    the fine pathway.
    """
    if low_pass_value != 0:
        transform = transforms.Compose([
            Low_pass_filter(low_pass_value, n_channels=1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.2]),
        ])
        return transform, 1

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    return transform, 3


def main():
    """Evaluate a two-pathway model on one CIFAR-10-C corruption/severity.

    Steps:
    - Build the model class ``--model_name`` from ``cfg.model_dict``.
    - Optionally restore weights: ``resume_from`` in the config takes
      precedence over ``--load_from``.
    - Run a single validation pass through an mmcv ``Runner``.
    - Write the averaged ``fine_acc_top1`` to
      ``<--work_dir>/noise_txt/<corrupted_name>_<severity>.txt``.
    """
    # Imported here so the script does not rely on ``utils`` re-exporting numpy.
    import numpy as np

    args = parse_args()
    cfg = Config.fromfile(args.config)

    cfg.model = args.model_name
    cfg.model_dict.times = args.times
    cfg.model_dict.interact_times = args.interact_times
    cfg.load_from = args.load_from
    logger = get_logger(cfg.log_level)

    # Per-channel statistics applied to the 3-channel inputs.
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]

    coarse_transform, coarse_in_channels = _build_coarse_transform(
        args.low_pass_value, mean, std)
    # The model's coarse branch must accept the channel count the transform emits.
    cfg.model_dict.coarse_in_channels = coarse_in_channels

    fine_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    val_dataset = CIFAR10_C_twopass(
        corrupted_name=args.corrupted_name,
        severity=args.severity,
        coarse_transform=coarse_transform,
        fine_transform=fine_transform,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.data_workers * len(cfg.gpus),
    )

    # Build the model. Weights are restored later through the runner.
    model = getattr(network, args.model_name)(**cfg.model_dict)
    logger.info('model %s has been built', args.model_name)

    model = DataParallel(model, device_ids=cfg.gpus)
    if cuda:
        model = model.cuda()

    runner = Runner(
        model,
        partial(batch_processor, args=args),
        cfg.optimizer,
        cfg.work_dir,
        log_level=cfg.log_level)
    runner.register_training_hooks(
        lr_config=cfg.lr_config,
        optimizer_config=cfg.optimizer_config,
        checkpoint_config=cfg.checkpoint_config,
        log_config=cfg.log_config)

    if cfg.get('resume_from') is not None:
        runner.resume(cfg.resume_from)
        logger.info('resumed from %s', cfg.resume_from)
    elif cfg.get('load_from') is not None:
        runner.load_checkpoint(cfg.load_from)
        logger.info('checkpoint loaded from %s', cfg.load_from)
    else:
        logger.warning('no checkpoint given (--load_from / resume_from); '
                       'evaluating randomly initialised weights')

    # Single pass over the chosen corruption.
    runner.val(val_loader)

    # `output` is only filled once the buffer has been averaged. A logger hook
    # may or may not have done that, so average here if it is not ready yet.
    log_buffer = runner.log_buffer
    if not log_buffer.ready:
        log_buffer.average()
    fine_acc_top1 = log_buffer.output['fine_acc_top1']
    logger.info('%s (severity %d): fine_acc_top1 = %s',
                args.corrupted_name, args.severity, fine_acc_top1)

    # NOTE(review): the runner writes to cfg.work_dir, but this result goes
    # under --work_dir. Confirm the split is intended.
    txt_dir = os.path.join(args.work_dir, "noise_txt")
    os.makedirs(txt_dir, exist_ok=True)
    txt_path = os.path.join(
        txt_dir, "{}_{}.txt".format(args.corrupted_name, args.severity))
    np.savetxt(txt_path, [fine_acc_top1])


# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
