from __future__ import print_function
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import random
import os
import comet_ml
import argparse
import collections
import requests
import socket
from parse_config import ConfigParser
from collections import OrderedDict
import argparse
import numpy as np
from net import *
import dataloader_IDN as dataloader 
from utils import *
import loss as module_loss
import wandb
import time

# Training
def train(config, epoch, net, dataloader, optimizer, optimizer_uv, optimizer_trans, train_loss, P=None):
    """Run one training epoch over ``dataloader``.

    Each batch's hard labels are one-hot encoded, the model output is scored by
    ``train_loss``, and the model plus the parameters held by ``optimizer_uv`` are
    updated.

    Args:
        config: Config mapping; ``config["num_classes"]`` sets the one-hot width.
        epoch: Current epoch index, forwarded to ``train_loss``.
        net: Model being trained; switched to train mode here.
        dataloader: Iterable of ``(inputs, _, labels, index, true_labels)`` batches;
            only ``inputs``, ``labels`` and ``index`` are used.
        optimizer: Optimizer over ``net``'s parameters; stepped every batch.
        optimizer_uv: Second optimizer stepped every batch after ``optimizer``.
        optimizer_trans: Gradients are zeroed every batch, but it is never stepped.
        train_loss: Callable ``(epoch, index, outputs, labels_onehot, batch_idx, P=P)``
            returning a tuple whose first element is the scalar loss to minimize.
        P: Forwarded unchanged to ``train_loss``.
    """
    net.train()
    num_classes = config["num_classes"]
    # Send batches to wherever the model lives instead of hard-coding ``.cuda()``.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    for batch_idx, (inputs_x, _, labels_x, index, _true_labels) in enumerate(dataloader):
        inputs_x = inputs_x.to(device)
        index = index.to(device)
        # One-hot targets in the default float dtype (same result as zeros(...).scatter_).
        labels_x = F.one_hot(labels_x.view(-1).to(device), num_classes).to(torch.get_default_dtype())
        outputs_x = net(inputs_x)

        # Only the total loss drives the update; the remaining returned values
        # (MSE/CE/volume terms, Terror, det_T in the original unpacking) are unused here.
        loss, *_ = train_loss(epoch, index, outputs_x, labels_x, batch_idx, P=P)

        optimizer.zero_grad()
        optimizer_uv.zero_grad()
        optimizer_trans.zero_grad()
        loss.backward()
        optimizer.step()
        optimizer_uv.step()
        # NOTE(review): optimizer_trans is zeroed above but never stepped, so the
        # parameters it manages are not updated by this loop. Confirm this is an
        # intentional freeze; otherwise add ``optimizer_trans.step()`` here.



def test(epoch, net1, dataloader, optimizer, optimizer_trans, optimizer_uv):
    net1.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs1 = net1(inputs)
            score1, predicted = torch.max(outputs1, 1)
            total += targets.size(0)
            correct += predicted.eq(targets).cpu().sum().item()
    acc = 100. * correct / total

    print("| Test Epoch #%d\t Acc Net: %.2f%%\n" % (epoch, acc))


def create_model(config):
    """Build a ResNet18 with ``config["num_classes"]`` outputs and move it to the GPU.

    ``nn.Module.cuda()`` moves the module in place and returns it, so the moved
    model is returned directly.
    """
    return ResNet18(num_classes=config["num_classes"]).cuda()


def main(config):
    """Build data, model, loss and optimizers from ``config`` and run training.

    Each epoch trains ``net1`` together with the loss module's ``u``/``v``
    parameters, evaluates on the test split, then advances the LR schedulers.
    """
    data_cfg = config['data']
    loader_args = config['data_loader']['args']
    loader = dataloader.cifar_dataloader(data_cfg['dataset'], r=data_cfg['noise_rate'],
                                         noise_type=data_cfg['noise_type'],
                                         noise_mode=data_cfg['noise_mode'],
                                         batch_size=loader_args['batch_size'],
                                         num_workers=loader_args['num_workers'],
                                         root_dir=data_cfg["data_path"],
                                         noise_file='%s/%s_%s.json' % (data_cfg["noise_path"], data_cfg["noise_type"], data_cfg["noise_rate"]))

    print('| Building net')
    net1 = create_model(config)

    train_loss = config.initialize('train_loss', module_loss).cuda()

    # Only trainable network parameters go to the main optimizer.
    net_params = [{'params': [p for p in net1.parameters() if getattr(p, 'requires_grad', False)]}]
    uv_args = config['optimizer_uv']['args']
    uv_params = [
        {'params': train_loss.u, 'lr': uv_args['lr_u'], 'weight_decay': uv_args['weight_decay']},
        {'params': train_loss.v, 'lr': uv_args['lr_v'], 'weight_decay': uv_args['weight_decay']},
    ]
    trans_args = config['optimizer_trans']['args']
    trans_params = [{'params': train_loss.trans, 'lr': trans_args["lr"], 'weight_decay': trans_args['weight_decay']}]

    optimizer_uv = optim.SGD(uv_params)
    optimizer_trans = optim.Adam(trans_params)
    optimizer = config.initialize('optimizer', torch.optim, net_params)

    scheduler_net = config.initialize('lr_scheduler_net', torch.optim.lr_scheduler, optimizer)
    scheduler_uv = config.initialize('lr_scheduler_uv', torch.optim.lr_scheduler, optimizer_uv)
    # The trans scheduler is optional: a null 'type' in the config disables it.
    if config['lr_scheduler_trans']['type'] is not None:
        scheduler_trans = config.initialize('lr_scheduler_trans', torch.optim.lr_scheduler, optimizer_trans)
    else:
        scheduler_trans = None

    train_loader, P = loader.run('train')
    test_loader = loader.run('test')

    for epoch in range(config['trainer']['epochs']):
        train(config, epoch, net1, train_loader, optimizer, optimizer_uv, optimizer_trans, train_loss, P)
        test(epoch, net1, test_loader, optimizer, optimizer_trans, optimizer_uv)
        scheduler_net.step()
        scheduler_uv.step()
        if scheduler_trans is not None:
            scheduler_trans.step()


if __name__ == '__main__':
    args = argparse.ArgumentParser(description='PyTorch Template')
    args.add_argument('-c', '--config', default=None, type=str,
                      help='config file path (default: None)')
    args.add_argument('-r', '--resume', default=None, type=str,
                      help='path to latest checkpoint (default: None)')
    args.add_argument('-d', '--device', default=None, type=str,
                      help='indices of GPUs to enable (default: all)')

    # Each option overrides the config entry found at its ``target`` key path.
    CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
    options = [
        CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')),
        # main() reads the u/v learning rates from config['optimizer_uv']['args'];
        # top-level 'lr_u'/'lr_v' keys are not read anywhere in this script, so the
        # overrides must target the nested path to take effect.
        CustomArgs(['--lr_u', '--learning_rate_u'], type=float, target=('optimizer_uv', 'args', 'lr_u')),
        CustomArgs(['--lr_v', '--learning_rate_v'], type=float, target=('optimizer_uv', 'args', 'lr_v')),
        CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size')),
        CustomArgs(['--rate', '--noise_rate'], type=float, target=('data', 'noise_rate')),
        CustomArgs(['--type', '--noise_type'], type=str, target=('data', 'noise_type')),
        CustomArgs(['--seed'], type=int, target=('seed',)),
        CustomArgs(['--consist', '--ratio_consistency'], type=float, target=('train_loss', 'args', 'ratio_consistency')),
        CustomArgs(['--balance', '--ratio_balance'], type=float, target=('train_loss', 'args', 'ratio_balance')),
    ]
    config = ConfigParser.get_instance(args, options)

    # Seed every global RNG the pipeline may draw from (NumPy included, since the
    # data pipeline may use it — not visible from this file).
    seed = config['seed']
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    main(config)




