import logging
import os
from argparse import ArgumentParser
from collections import OrderedDict
import pdb
from functools import partial


import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms

import sys
sys.path.append("../")
from mmcv import Config
from mmcv.runner import DistSamplerSeedHook, Runner

from utils import*
import models as network
from dataset.image_transforms import Low_pass_filter


# Module-level flag: True when PyTorch reports a usable CUDA device.
# batch_processor reads it to decide whether to move labels to the GPU, and
# main reads it to decide whether to call .cuda() on the wrapped model.
cuda = torch.cuda.is_available()

def parse_args():
    """Parse the command line for the CIFAR-10 training script.

    Positional:
        config: path to the mmcv config file.

    Options (flag, type, default):
        --model_name      str    "SimFeedbackNet_b"
        --work_dir        str    "../res/rbm_cifar"
        --load_from       str    None
        --low_pass_value  float  0
        --times           int    1

    Returns:
        argparse.Namespace with one attribute per argument above.
    """
    cli = ArgumentParser(description='Train CIFAR-10 classification')
    cli.add_argument('config', help='train config file path')

    # (flag, converter, default) — registered in this exact order.
    option_specs = (
        ('--model_name', str, "SimFeedbackNet_b"),
        ('--work_dir', str, "../res/rbm_cifar"),
        ('--load_from', str, None),
        ('--low_pass_value', float, 0),
        ('--times', int, 1),
    )
    for flag, converter, fallback in option_specs:
        cli.add_argument(flag, type=converter, default=fallback)

    namespace = cli.parse_args()
    return namespace

def batch_processor(model, data, train_mode, args):
    """Run one batch through the model and package the result for mmcv's Runner.

    Args:
        model: the network (wrapped in DataParallel by ``main``); it is called
            as ``model(images, pred_class=True)`` and must return a 2-tuple
            whose first element is the class scores fed to cross-entropy.
        data: an ``(images, labels)`` pair produced by a DataLoader.
        train_mode: supplied by the Runner; not used here.
        args: parsed CLI namespace, bound with ``functools.partial`` in
            ``main``; not used here.

    Returns:
        dict with keys ``loss`` (the loss tensor), ``log_vars`` (an
        OrderedDict of Python floats under ``loss``, ``acc_top1`` and
        ``acc_top5``) and ``num_samples`` (the batch size, ``images.size(0)``).
    """
    images, targets = data
    # cross_entropy expects integer (int64) class-index targets.
    targets = targets.long()
    if cuda:
        # Only the targets are moved explicitly; the images are left as-is,
        # presumably relying on the model wrapper to place inputs — confirm.
        targets = targets.cuda(non_blocking=True)

    # The second model output is not needed for classification.
    logits, _ = model(images, pred_class=True)

    loss = F.cross_entropy(logits, targets)
    # `accuracy` comes from the utils star-import; it presumably returns
    # one-element tensors (hence the .item() calls below).
    top1, top5 = accuracy(logits, targets, topk=(1, 5))

    log_vars = OrderedDict([
        ('loss', loss.item()),
        ('acc_top1', top1.item()),
        ('acc_top5', top5.item()),
    ])
    return dict(loss=loss, log_vars=log_vars, num_samples=images.size(0))

def _build_transform(low_pass_value):
    """Return ``(transform, in_channels)`` for the CIFAR-10 datasets.

    A non-zero ``low_pass_value`` selects a single-channel pipeline:
    ``Low_pass_filter(low_pass_value, n_channels=1)``, then ToTensor, then
    Normalize with mean 0.5 / std 0.2. Zero selects the 3-channel pipeline:
    ToTensor, then Normalize with the hard-coded per-channel mean/std below.
    The same transform is used for the train and val splits.
    """
    if low_pass_value != 0:
        n_channels = 1
        transform = transforms.Compose([
            Low_pass_filter(low_pass_value, n_channels=n_channels),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.2]),
        ])
        return transform, n_channels

    n_channels = 3
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    return transform, n_channels


def main():
    """Train a CIFAR-10 classifier with mmcv's Runner and save final top-1 acc.

    Steps:
      1. Parse CLI args and load the mmcv config. ``--model_name``,
         ``--times`` (only if the config's ``model_dict`` has a ``times`` key)
         and ``--work_dir`` override the config.
      2. Build train/val CIFAR-10 loaders, optionally low-pass filtered.
      3. Instantiate ``network.<model_name>(**cfg.model_dict)`` and wrap it
         in DataParallel.
      4. Run the configured workflow, then one more validation pass.
      5. Write the resulting ``acc_top1`` to
         ``<work_dir>/acc_txt/std_<low_pass_value>_acc.txt``.
    """
    # Local import: `np` is otherwise only available if `from utils import *`
    # happens to re-export numpy, which nothing in this file guarantees.
    import numpy as np

    args = parse_args()
    cfg = Config.fromfile(args.config)

    cfg.model = args.model_name
    if 'times' in cfg.model_dict:
        cfg.model_dict.times = args.times
    cfg.work_dir = args.work_dir
    # NOTE(review): --load_from is parsed but never applied (the original
    # `cfg.load_from = args.load_from` was commented out) — confirm intent.

    # Called for its (presumed) logging-setup side effect at cfg.log_level;
    # the returned logger is not used in this function.
    get_logger(cfg.log_level)

    num_workers = cfg.data_workers * len(cfg.gpus)
    batch_size = cfg.batch_size

    transform, in_channels = _build_transform(args.low_pass_value)
    # The model's input width must match the channels the transform yields.
    cfg.model_dict.in_channels = in_channels

    train_dataset = datasets.CIFAR10(
        root="~/dataset/CIFAR10",
        train=True,
        transform=transform)
    val_dataset = datasets.CIFAR10(
        root="~/dataset/CIFAR10",
        train=False,
        transform=transform)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers)

    # Look the model class up by name in the local `models` package.
    model = getattr(network, args.model_name)(**cfg.model_dict)
    print("model has been loaded")

    model = DataParallel(model, device_ids=cfg.gpus)
    if cuda:
        model = model.cuda()

    # The Runner calls batch_processor(model, data, train_mode=...); bind the
    # CLI args so batch_processor's extra `args` parameter is satisfied.
    runner = Runner(
        model,
        partial(batch_processor, args=args),
        cfg.optimizer,
        cfg.work_dir,
        log_level=cfg.log_level)
    runner.register_training_hooks(
        lr_config=cfg.lr_config,
        optimizer_config=cfg.optimizer_config,
        checkpoint_config=cfg.checkpoint_config,
        log_config=cfg.log_config)

    runner.run([train_loader, val_loader], cfg.workflow, cfg.total_epochs)
    # Final validation pass; its top-1 accuracy is read back from the log
    # buffer below. NOTE(review): this assumes the logger hooks do not clear
    # log_buffer.output after validation — confirm against log_config.
    runner.val(val_loader)

    txt_dir = os.path.join(args.work_dir, "acc_txt")
    os.makedirs(txt_dir, exist_ok=True)
    acc_path = os.path.join(
        txt_dir, "std_{}_acc.txt".format(args.low_pass_value))
    np.savetxt(acc_path, [runner.log_buffer.output['acc_top1']])


# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
