#import comet_ml
import argparse
import collections
import sys
import requests
import socket
import torch
import data_loader.data_loaders as module_data
import model.loss as module_loss
import model.metric as module_metric
import model.model as module_arch
from parse_config import ConfigParser
from trainer import Trainer
from collections import OrderedDict
import random
import torch.nn as nn


class CombinedArchitectureSingle(nn.Module):
    """
    Wrap a single network and apply a divergence-specific final activation.

    The output of ``single_architecture`` is passed through the activation
    selected by ``cost_function_v``:

    * ``"Jensen-Shannon"`` and ``"SL"`` -> sigmoid
    * ``"KL"`` -> identity
    * ``"JS_s"`` and ``"SL_s"`` -> softmax over dimension 1

    Args:
        single_architecture: module whose output is fed to the activation.
        cost_function_v: name of the divergence; must be one of the keys
            listed above. The default value ``1`` is kept for interface
            compatibility only -- it is not a valid key, so a divergence
            name has to be passed explicitly.

    Raises:
        KeyError: if ``cost_function_v`` is not a known divergence name.
    """
    def __init__(self, single_architecture, cost_function_v=1):
        super(CombinedArchitectureSingle, self).__init__()
        self.div_to_act_func = {
            "Jensen-Shannon": nn.Sigmoid(),
            "KL": nn.Identity(),
            "SL": nn.Sigmoid(),
            # An explicit dim avoids the deprecated implicit-dimension
            # softmax (and its warning on every forward pass). dim=1 is the
            # dimension the implicit rule selects for 2-D and 4-D inputs.
            "JS_s": nn.Softmax(dim=1),
            "SL_s": nn.Softmax(dim=1),
        }
        if cost_function_v not in self.div_to_act_func:
            # Same exception type as the plain dict lookup, but with a
            # message that tells the caller which names are accepted.
            raise KeyError(
                f"unknown divergence {cost_function_v!r}; "
                f"expected one of {sorted(self.div_to_act_func)}"
            )
        self.cost_function_version = cost_function_v
        self.single_architecture = single_architecture
        self.final_activation = self.div_to_act_func[cost_function_v]

    def forward(self, input_tensor_1):
        """Run the wrapped network and apply the selected final activation."""
        intermediate_1 = self.single_architecture(input_tensor_1)
        return self.final_activation(intermediate_1)


def build_T(noise_type, percentage, num_classes, e_s, cifar10=None):
    """
    Return the ``T`` value associated with a noise configuration.

    Args:
        noise_type: one of ``"symm"``, ``"binary"``, ``"custom_T_low"`` or
            ``"custom_T_high"``.
        percentage: noise rate; only used for ``"symm"``.
        num_classes: number of classes; only used for ``"symm"``.
        e_s: value returned unchanged for ``"binary"``.
        cifar10: accepted for call compatibility; not used here.

    Returns:
        * ``"symm"``: a flat list of ``num_classes`` entries, each equal to
          ``percentage / (num_classes - 1)``.
        * ``"binary"``: ``e_s`` itself.
        * ``"custom_T_low"`` / ``"custom_T_high"``: a fixed 10x10 list of
          lists (presumably a label-noise transition matrix -- confirm
          against the loss that consumes it).
        * anything else: ``None``.
    """
    def rows_from(off_diagonal, diagonal):
        # Every row shares the same off-diagonal entries; row i only differs
        # by carrying diagonal[i] at position i.
        return [
            [diag_value if col == row else shared
             for col, shared in enumerate(off_diagonal)]
            for row, diag_value in enumerate(diagonal)
        ]

    if noise_type == "symm":
        return [percentage/(num_classes-1)]*num_classes

    if noise_type == "binary":
        return e_s

    if noise_type == "custom_T_low":
        return rows_from(
            [0.02, 0.03, 0.01, 0.023, 0.017, 0.022, 0.021, 0.018, 0.019, 0.02],
            [0.82, 0.83, 0.81, 0.823, 0.817, 0.822, 0.821, 0.818, 0.819, 0.82],
        )

    if noise_type == "custom_T_high":
        return rows_from(
            [0.05, 0.07, 0.04, 0.05, 0.06, 0.04, 0.06, 0.07, 0.08, 0.07],
            [0.46, 0.48, 0.45, 0.46, 0.47, 0.45, 0.47, 0.48, 0.49, 0.48],
        )

    # Unrecognised noise types yield no matrix.
    return None


def log_params(conf: OrderedDict, parent_key: str = None):
    """
    Recursively log every leaf of a nested ``OrderedDict`` as a parameter.

    Keys of nested dictionaries are joined with ``-`` (``parent-child``).
    Only values that are ``OrderedDict`` instances are descended into; every
    other value is passed to ``mlflow.log_param`` under the joined key.

    NOTE(review): ``mlflow`` is not imported in the visible part of this
    file, so reaching a leaf value would raise ``NameError`` unless the name
    is provided elsewhere -- confirm before calling this function.
    """
    for name, entry in conf.items():
        full_key = name if parent_key is None else f'{parent_key}-{name}'

        if isinstance(entry, OrderedDict):
            # Nested section: recurse with the joined key as the new prefix.
            log_params(entry, full_key)
            continue

        mlflow.log_param(full_key, entry)


def main(config: ConfigParser):
    """
    Build data loaders, model, losses and optimizer from ``config`` and train.

    Side effects on ``config``: ``config['train_loss']['args']['num_examp']``
    is always written, and ``config['train_loss']['args']['T']`` is written
    when the run is symmetric (none of the ``asym`` / ``instance`` / ``real``
    / ``binary`` / ``custom_T_low`` / ``custom_T_high`` trainer flags is set)
    or when one of the two ``custom_T_*`` flags is set. For the remaining
    flags ``T`` is left as found in the configuration.
    """
    logger = config.get_logger('train')

    loader_cls = getattr(module_data, config['data_loader']['type'])
    loader_args = config['data_loader']['args']

    # Training loader, fully driven by the configuration.
    data_loader = loader_cls(
        loader_args['data_dir'],
        batch_size=loader_args['batch_size'],
        shuffle=loader_args['shuffle'],
        validation_split=loader_args['validation_split'],
        num_batches=loader_args['num_batches'],
        training=True,
        num_workers=loader_args['num_workers'],
        pin_memory=loader_args['pin_memory'],
        T=loader_args['T'],
    )
    valid_data_loader = data_loader.split_validation()

    # Test loader: fixed batch size / workers, no shuffling, no validation split.
    test_data_loader = loader_cls(
        loader_args['data_dir'],
        batch_size=128,
        shuffle=False,
        validation_split=0.0,
        training=False,
        num_workers=2,
        T=loader_args['T'],
    ).split_validation()

    # Build the architecture; optionally wrap it with the divergence-specific
    # output activation.
    model = config.initialize('arch', module_arch)
    loss_args = config["train_loss"]["args"]
    if loss_args["change_var"]:
        model = CombinedArchitectureSingle(model, cost_function_v=loss_args["div"])

    logger.info(config.config)

    # Prefer the dataset's own raw example count when it exposes one.
    dataset = data_loader.dataset
    num_examp = dataset.num_raw_example if hasattr(dataset, 'num_raw_example') else len(dataset)
    loss_args['num_examp'] = num_examp

    # Select which T to hand to the training loss, if any.
    trainer_cfg = config['trainer']
    noise_flags = ('asym', 'instance', 'real', 'binary', 'custom_T_low', 'custom_T_high')
    if not any(trainer_cfg[flag] for flag in noise_flags):
        noise_type = "symm"
    elif trainer_cfg['custom_T_low']:
        noise_type = "custom_T_low"
    elif trainer_cfg['custom_T_high']:
        noise_type = "custom_T_high"
    else:
        noise_type = None

    if noise_type is not None:
        loss_args['T'] = build_T(noise_type, trainer_cfg['percent'],
                                 loss_args['num_classes'], None,
                                 cifar10=data_loader)

    # Loss and metric handles.
    train_loss = config.initialize('train_loss', module_loss)
    val_loss = config.initialize('val_loss', module_loss)
    metrics = [getattr(module_metric, metric_name) for metric_name in config['metrics']]

    # Only parameters that require gradients are handed to the optimizer.
    # Remove every line mentioning lr_scheduler to disable the scheduler.
    params_with_grad = list(
        filter(lambda p: getattr(p, 'requires_grad', False), model.parameters())
    )
    trainable_params = [{'params': params_with_grad}]

    optimizer = config.initialize('optimizer', torch.optim, trainable_params)
    lr_scheduler = config.initialize('lr_scheduler', torch.optim.lr_scheduler, optimizer)

    trainer = Trainer(
        model, train_loss, metrics, optimizer,
        config=config,
        data_loader=data_loader,
        valid_data_loader=valid_data_loader,
        test_data_loader=test_data_loader,
        lr_scheduler=lr_scheduler,
        val_criterion=val_loss,
    )
    trainer.train()

    # NOTE(review): neither of the two bindings below is used afterwards;
    # they are kept so the calls made after training stay the same.
    logger = config.get_logger('trainer', config['trainer']['verbosity'])
    cfg_trainer = config['trainer']


if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` must not be used with argparse: ``bool("False")`` is
        ``True`` because any non-empty string is truthy, so ``--asym False``
        would silently enable the flag.
        """
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    args = argparse.ArgumentParser(description='PyTorch Template')
    args.add_argument('-c', '--config', default='config_cifar10.json', type=str,
                      help='config file path (default: config_cifar10.json)')
    args.add_argument('-r', '--resume', default=None, type=str,
                      help='path to latest checkpoint (default: None)')
    args.add_argument('-d', '--device', default=None, type=str,
                      help='indices of GPUs to enable (default: all)')

    # Custom CLI options that override values from the JSON configuration.
    # `target` is the key path of the overridden entry inside the config.
    CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
    options = [
        CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')),
        CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size')),
        CustomArgs(['--percent', '--percent'], type=float, target=('trainer', 'percent')),
        CustomArgs(['--asym', '--asym'], type=_str2bool, target=('trainer', 'asym')),
        CustomArgs(['--instance', '--instance'], type=_str2bool, target=('trainer', 'instance')),
        CustomArgs(['--name', '--exp_name'], type=str, target=('name',)),
        CustomArgs(['--seed', '--seed'], type=int, target=('seed',)),
        CustomArgs(['--consist', '--ratio_consistency'], type=float, target=('train_loss', 'args', 'ratio_consistency')),
        CustomArgs(['--balance', '--ratio_balance'], type=float, target=('train_loss', 'args', 'ratio_balance'))
    ]
    config = ConfigParser.get_instance(args, options)

    # Experiment grid. Alternatives used previously:
    #   divergences: ["JS_s", "SL_s"]
    #   noise_percentages: [0.2, 0.4, 0.6, 0.8]
    #   real_noise_types (set via config['trainer']['real']):
    #       ["aggre", "rand1", "rand2", "rand3", "worst", "clean"],
    #       ["clean100", "noisy100"], ["clean100"]
    divergences = ["Jensen-Shannon", "SL"]
    noise_percentages = [0.2, 0.3, 0.4]

    for noise_percentage in noise_percentages:
        for div in divergences:
            # Re-seed before every run so each (noise, divergence) pair
            # starts from the same random state.
            random.seed(config['seed'])
            torch.manual_seed(config['seed'])
            torch.cuda.manual_seed_all(config['seed'])
            # The same config object is mutated in place and reused per run.
            config['train_loss']['args']['div'] = div
            config['trainer']['percent'] = noise_percentage
            main(config)
