import os
import json 
from uuid import uuid4 

import numpy as np

import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score

from neuralfaults.utils.metrics import smape, rmse, rmsle, r2, mae, sc_mse, CEExtended

from neuralfaults.models.cnn import ShallowCNN, DeepCNN
from neuralfaults.models.ffnn import ShallowFNN, DeepFNN
from neuralfaults.models.rnn import ShallowRNN, DeepRNN
from neuralfaults.models.lstm import ShallowLSTM, DeepLSTM
from neuralfaults.models.encdec import (ShallowEncDec, DeepEncDec, EncDecSkip,
                          EncDecRNNSkip, EncDecBiRNNSkip,
                          EncDecDiagBiRNNSkip, LightEncDec, LightEncDecSkip)
from neuralfaults.models.unet import UNET_1D

from neuralfaults.impute_models.grud import GRUD
from neuralfaults.impute_models.mrnn import MRNN
from neuralfaults.impute_models.brits import BRITS

from neuralfaults.impute_models.gain import GAINGen, GAINDisc
from neuralfaults.impute_models.gan2stage import GAN2StageGen, GAN2StageDisc, GAN2StageZ
from neuralfaults.impute_models.e2egan import E2EDisc, E2EGen
from neuralfaults.impute_models.sgan import SGANDisc, SGANGen


def get_file_names(opt):
    """Create the run directory and return the weight and log file paths.

    The directory is ``<weights_dir>/<last '/'-component of data_dir>/
    <dataset_name>/<model name>``. A random hex name (``uuid4``) is used for
    the files, and the parsed arguments are dumped next to them as
    ``<name>.json`` so the hyper-parameters of the run can be recovered.

    Args:
        opt (argparse.Namespace): Parsed arguments. Reads ``weights_dir``,
            ``data_dir``, ``dataset_name``, ``model`` and the imputation
            model name (``imputed_model``, falling back to ``impute_model``,
            which is the spelling the other functions in this module read).

    Returns:
        tuple: ``(weight_path, log_path)`` ending in ``.pt`` and ``.log``.

    Raises:
        ValueError: If neither a model nor an imputation model is set, so
            no output directory can be derived.
    """
    dataset_root = os.path.join(
        opt.weights_dir, opt.data_dir.split('/')[-1], opt.dataset_name)

    model_name = getattr(opt, 'model', None)
    impute_name = (getattr(opt, 'imputed_model', None)
                   or getattr(opt, 'impute_model', None))

    # As before, a directory is created for every name that is set and the
    # imputation model's directory wins when both are given.
    dir_path = None
    for name in (model_name, impute_name):
        if name:
            dir_path = os.path.join(dataset_root, name)
            # exist_ok avoids the check-then-create race of the old code.
            os.makedirs(dir_path, exist_ok=True)

    if dir_path is None:
        raise ValueError(
            "Cannot derive an output directory: neither 'model' nor an "
            "imputation model is set in the arguments.")

    fname = uuid4().hex
    weight_path = os.path.join(dir_path, fname + '.pt')
    log_path = os.path.join(dir_path, fname + '.log')

    # The context manager guarantees the handle is closed even if the
    # arguments turn out not to be JSON-serialisable.
    with open(os.path.join(dir_path, fname + '.json'), 'w') as fout:
        json.dump(vars(opt), fout)

    return weight_path, log_path


def initialize_metrics(opt):
    """Build an empty metric accumulator for a training/validation run.

    The result always has a ``'loss'`` entry plus one ``'<quantity>-<metric>'``
    entry for every combination of the comma-separated names in
    ``opt.out_quants`` and ``opt.metrics``. Every value is its own empty list.

    Args:
        opt (argparse.Namespace): Parsed arguments; ``out_quants`` and
            ``metrics`` are read as comma-separated strings.

    Returns:
        dict: Mapping from metric key to an empty list.
    """
    quant_names = opt.out_quants.split(',')
    metric_names = opt.metrics.split(',')

    tracked = {'loss': []}
    tracked.update({
        f'{quant}-{metric}': []
        for quant in quant_names
        for metric in metric_names
    })
    return tracked


def get_mean_metrics(metrics_dict):
    """Collapse each list of recorded metric values to its mean.

    Args:
        metrics_dict (dict): Mapping from metric key to a list of values.

    Returns:
        dict: Mapping with the same keys, each holding ``np.mean`` of the
        corresponding list.
    """
    averaged = {}
    for name, values in metrics_dict.items():
        averaged[name] = np.mean(values)
    return averaged


def transform_tensor(tensor):
    r"""
    Convert a torch tensor, numpy array or list into a numpy ndarray.

    Torch tensors are read through ``.data`` and copied to the CPU first when
    they live on a CUDA device; ndarrays are returned unchanged; lists go
    through ``np.asarray``. Any other input type yields ``None``.
    """
    if isinstance(tensor, torch.Tensor):
        raw = tensor.data
        return raw.cpu().numpy() if tensor.is_cuda else raw.numpy()
    if isinstance(tensor, list):
        return np.asarray(tensor)
    return tensor if isinstance(tensor, np.ndarray) else None


def compute_metrics(metrics_dict, loss, predicted, target, opt):
    """Append the metrics of one batch to the running metrics dictionary.

    Args:
        metrics_dict (dict): Accumulator with a ``'loss'`` list and one
            ``'<quantity>-<metric>'`` list per requested combination.
        loss: Loss value of the batch; ``loss.item()`` is recorded.
        predicted: Model output (tensor, ndarray or list).
        target: Ground truth (tensor, ndarray or list).
        opt (argparse.Namespace): Parsed arguments; ``out_quants`` and
            ``metrics`` are read as comma-separated strings.

    Returns:
        dict: The same ``metrics_dict`` object, updated in place.

    Notes:
        * When ``'acc'`` or ``'f1'`` occurs in ``opt.metrics`` (a substring
          test on the raw string), ``predicted`` is reduced with ``argmax``
          over axis 1 and both arrays are flattened. ``acc``/``f1`` are then
          computed on the whole arrays and stored under every quantity key.
        * All other metrics are called as ``metric(target, predicted)`` on the
          ``[:, q, :]`` slice of quantity ``q``, so they need 3-D arrays.
    """
    metrics_map = {
        'smape': smape,
        'r2': r2,
        'rmse': rmse,
        'mae': mae,
        'rmsle': rmsle,
        'acc': accuracy_score,
        'f1': f1_score,
    }
    metrics_dict['loss'].append(loss.item())

    predicted = transform_tensor(predicted)
    target = transform_tensor(target)

    wants_classification = 'acc' in opt.metrics or 'f1' in opt.metrics
    if wants_classification:
        predicted = np.argmax(predicted, axis=1).flatten()
        target = target.flatten()

    quant_names = opt.out_quants.split(',')
    metric_names = opt.metrics.split(',')
    for idx, quant in enumerate(quant_names):
        for name in metric_names:
            history = metrics_dict[f'{quant}-{name}']
            if name == 'acc':
                history.append(metrics_map[name](target, predicted))
            elif name == 'f1':
                history.append(metrics_map[name](target, predicted, average='macro'))
            else:
                history.append(
                    metrics_map[name](target[:, idx, :], predicted[:, idx, :]))

    return metrics_dict


def get_model(opt):
    """Instantiate the network selected by ``opt.model`` / ``opt.impute_model``.

    ``opt.model`` is checked first; ``opt.impute_model`` is only consulted
    when ``opt.model`` matches none of the known names.

    Args:
        opt (argparse.Namespace): Parsed arguments. Reads ``inp_quants`` and
            ``out_quants`` (comma-separated; their lengths give the channel
            counts), ``dataset_name``, ``loss``, ``act``, ``model``,
            ``impute_model``, ``window``, ``hidden_size``, ``gpu`` and
            ``num_gpus``.

    Returns:
        torch.nn.Module: The model. It is moved to device ``opt.gpu`` when
        ``opt.gpu > -1`` (and wrapped in ``nn.DataParallel`` when
        ``opt.num_gpus > 1``); otherwise it is returned on the CPU.

    Raises:
        SystemExit: If no known model name was passed.
    """
    inp_channels = len(opt.inp_quants.split(','))
    out_channels = len(opt.out_quants.split(','))

    # Hard-coded override for FaultVibration with cross-entropy loss
    # (presumably the number of fault classes -- TODO confirm).
    if opt.dataset_name == 'FaultVibration' and opt.loss == 'ce':
        out_channels = 5

    act = opt.act

    if opt.model == 'shallow_fnn':
        # Fully connected nets take the flattened window as input.
        model = ShallowFNN(inp_channels * opt.window, out_channels, act)
    elif opt.model == 'deep_fnn':
        model = DeepFNN(inp_channels * opt.window, out_channels, act)
    elif opt.model == 'shallow_cnn':
        model = ShallowCNN(inp_channels, out_channels, act)
    elif opt.model == 'deep_cnn':
        model = DeepCNN(inp_channels, out_channels, act)
    elif opt.model == 'shallow_rnn':
        model = ShallowRNN(inp_channels, out_channels, opt.hidden_size, act)
    elif opt.model == 'deep_rnn':
        model = DeepRNN(inp_channels, out_channels, opt.hidden_size, act)
    elif opt.model == 'shallow_lstm':
        model = ShallowLSTM(inp_channels, out_channels, opt.hidden_size, act)
    elif opt.model == 'deep_lstm':
        model = DeepLSTM(inp_channels, out_channels, opt.hidden_size, act)
    elif opt.model == 'light_encdec':
        model = LightEncDec(inp_channels, out_channels, act)
    elif opt.model == 'shallow_encdec':
        model = ShallowEncDec(inp_channels, out_channels, act)
    elif opt.model == 'deep_encdec':
        model = DeepEncDec(inp_channels, out_channels, act)
    elif opt.model == 'light_encdec_skip':
        model = LightEncDecSkip(inp_channels, out_channels, act)
    elif opt.model == 'encdec_skip':
        model = EncDecSkip(inp_channels, out_channels, act)
    elif opt.model == 'encdec_rnn_skip':
        model = EncDecRNNSkip(inp_channels, out_channels, act)
    elif opt.model == 'encdec_birnn_skip':
        model = EncDecBiRNNSkip(inp_channels, out_channels, act)
    elif opt.model == 'encdec_diag_birnn_skip':
        model = EncDecDiagBiRNNSkip(inp_channels, out_channels, act)
    elif opt.model == 'unet':
        model = UNET_1D(inp_channels, out_channels, 128, 7, 3)
    elif opt.impute_model == 'grud':
        model = GRUD(inp_channels, 128, out_channels)
    elif opt.impute_model == 'mrnn':
        model = MRNN(inp_channels, opt.window, 32, device=opt.gpu)
    elif opt.impute_model == 'brits':
        model = BRITS(inp_channels, 64)
    else:
        # SystemExit with a message exits with a non-zero status and does not
        # depend on the site-provided ``exit`` helper.
        raise SystemExit("Incorrect model passed in argument.")

    print('Parameters :', sum(p.numel() for p in model.parameters()))

    if opt.gpu > -1:
        model = model.to(opt.gpu)
        if opt.num_gpus > 1:
            model = nn.DataParallel(model, device_ids=list(range(opt.num_gpus)))

    # The model is already on opt.gpu at this point; the former unconditional
    # ``.cuda(opt.gpu)`` made the CPU configuration (gpu == -1) fail.
    return model

def get_gain_model(opt):
    """Build the generator/discriminator pair for ``gain`` or ``sgan``.

    Args:
        opt (argparse.Namespace): Parsed arguments. Reads ``inp_quants``
            (comma-separated; its length gives the channel count),
            ``impute_model``, ``window``, ``gpu``, ``num_gpus`` and, for
            ``sgan``, ``weight_path``.

    Returns:
        tuple: ``(model_g, model_d)``. Both are moved to device ``opt.gpu``
        when ``opt.gpu > -1`` (and wrapped in ``nn.DataParallel`` when
        ``opt.num_gpus > 1``); otherwise they stay on the CPU.

    Raises:
        SystemExit: If ``opt.impute_model`` is neither ``'gain'`` nor
            ``'sgan'``.
    """
    inp_channels = len(opt.inp_quants.split(','))

    if opt.impute_model == 'gain':
        # GAIN works on the flattened window.
        inp_len = inp_channels * opt.window
        model_g = GAINGen(inp_len)
        model_d = GAINDisc(inp_len)
    elif opt.impute_model == 'sgan':
        model_g = SGANGen(inp_channels, opt.window)
        model_d = SGANDisc(inp_channels, opt.window)
        if opt.weight_path:
            # NOTE: torch.load unpickles the file -- only load checkpoints
            # from trusted sources.
            model_g.encdec = torch.load(opt.weight_path)
            # Generator and discriminator share the same loaded module.
            model_d.encdec = model_g.encdec
    else:
        # SystemExit with a message exits with a non-zero status and does not
        # depend on the site-provided ``exit`` helper.
        raise SystemExit("Incorrect model passed in argument.")

    print('Model G Parameters :', sum(p.numel() for p in model_g.parameters()))
    print('Model D Parameters :', sum(p.numel() for p in model_d.parameters()))

    if opt.gpu > -1:
        model_g = model_g.to(opt.gpu)
        model_d = model_d.to(opt.gpu)
        if opt.num_gpus > 1:
            device_ids = list(range(opt.num_gpus))
            model_g = nn.DataParallel(model_g, device_ids=device_ids)
            model_d = nn.DataParallel(model_d, device_ids=device_ids)

    # Already on opt.gpu when requested; the former unconditional
    # ``.cuda(opt.gpu)`` made the CPU configuration (gpu == -1) fail.
    return model_g, model_d


def get_e2e_model(opt):
    """Build the generator/discriminator pair for the ``e2e`` imputation GAN.

    Args:
        opt (argparse.Namespace): Parsed arguments. Reads ``inp_quants``
            (comma-separated; its length gives the channel count),
            ``impute_model``, ``window``, ``gpu`` and ``num_gpus``.

    Returns:
        tuple: ``(model_g, model_d)``. Both are moved to device ``opt.gpu``
        when ``opt.gpu > -1`` (and wrapped in ``nn.DataParallel`` when
        ``opt.num_gpus > 1``); otherwise they stay on the CPU.

    Raises:
        SystemExit: If ``opt.impute_model`` is not ``'e2e'``.
    """
    if opt.impute_model != 'e2e':
        # SystemExit with a message exits with a non-zero status and does not
        # depend on the site-provided ``exit`` helper.
        raise SystemExit("Incorrect model passed in argument.")

    inp_channels = len(opt.inp_quants.split(','))
    model_g = E2EGen(64, opt.window, inp_channels, 64, inp_channels)
    model_d = E2EDisc(opt.window, inp_channels, 64, inp_channels)

    print('Model G Parameters :', sum(p.numel() for p in model_g.parameters()))
    print('Model D Parameters :', sum(p.numel() for p in model_d.parameters()))

    if opt.gpu > -1:
        model_g = model_g.to(opt.gpu)
        model_d = model_d.to(opt.gpu)
        if opt.num_gpus > 1:
            device_ids = list(range(opt.num_gpus))
            model_g = nn.DataParallel(model_g, device_ids=device_ids)
            model_d = nn.DataParallel(model_d, device_ids=device_ids)

    # Already on opt.gpu when requested; the former unconditional
    # ``.cuda(opt.gpu)`` made the CPU configuration (gpu == -1) fail.
    return model_g, model_d


def get_gan_model(opt):
    """Build the three networks of the ``gan2stage`` imputation model.

    Args:
        opt (argparse.Namespace): Parsed arguments. Reads ``inp_quants``
            (comma-separated; its length gives the channel count),
            ``impute_model``, ``window``, ``gpu`` and ``num_gpus``.

    Returns:
        tuple: ``(model_g, model_d, model_z)``. All are moved to device
        ``opt.gpu`` when ``opt.gpu > -1`` (and wrapped in ``nn.DataParallel``
        when ``opt.num_gpus > 1``); otherwise they stay on the CPU.

    Raises:
        SystemExit: If ``opt.impute_model`` is not ``'gan2stage'``.
    """
    if opt.impute_model != 'gan2stage':
        # SystemExit with a message exits with a non-zero status and does not
        # depend on the site-provided ``exit`` helper.
        raise SystemExit("Incorrect model passed in argument.")

    inp_channels = len(opt.inp_quants.split(','))
    model_g = GAN2StageGen(64, opt.window, inp_channels, 64, inp_channels)
    model_d = GAN2StageDisc(opt.window, inp_channels, 64, inp_channels)
    model_z = GAN2StageZ(64)

    print('Model G Parameters :', sum(p.numel() for p in model_g.parameters()))
    print('Model D Parameters :', sum(p.numel() for p in model_d.parameters()))
    print('Model Z Parameters :', sum(p.numel() for p in model_z.parameters()))

    if opt.gpu > -1:
        model_g = model_g.to(opt.gpu)
        model_d = model_d.to(opt.gpu)
        model_z = model_z.to(opt.gpu)
        if opt.num_gpus > 1:
            device_ids = list(range(opt.num_gpus))
            model_g = nn.DataParallel(model_g, device_ids=device_ids)
            model_d = nn.DataParallel(model_d, device_ids=device_ids)
            model_z = nn.DataParallel(model_z, device_ids=device_ids)

    # Already on opt.gpu when requested; the former unconditional
    # ``.cuda(opt.gpu)`` made the CPU configuration (gpu == -1) fail.
    return model_g, model_d, model_z

def get_gan_loss_functions(opt):
    """Return the generator and discriminator criteria for GAN training.

    Args:
        opt (argparse.Namespace): Parsed arguments; only ``loss`` is read and
            ``'mse,bce'`` is the single supported value.

    Returns:
        tuple: ``(criterion_g, criterion_d)`` -- ``nn.MSELoss`` for the
        generator and ``nn.BCEWithLogitsLoss`` for the discriminator, both
        with mean reduction.

    Raises:
        ValueError: If ``opt.loss`` is not ``'mse,bce'`` (previously this
            surfaced as an ``UnboundLocalError``).
    """
    if opt.loss != 'mse,bce':
        raise ValueError(
            f"Unsupported GAN loss {opt.loss!r}; expected 'mse,bce'.")

    criterion_g = nn.MSELoss(reduction="mean")
    criterion_d = nn.BCEWithLogitsLoss(reduction="mean")
    return criterion_g, criterion_d

def get_loss_function(opt):
    """Return the training criterion selected by ``opt.loss``.

    Supported values are ``'mse'`` (``nn.MSELoss``), ``'sc_mse'`` (the
    ``sc_mse`` function) and ``'ce'`` (``CEExtended``), the latter only when
    ``opt.dataset_name`` is ``'FaultVibration'``.

    Args:
        opt (argparse.Namespace): Parsed arguments; reads ``loss`` and, for
            ``'ce'``, ``dataset_name``.

    Returns:
        The criterion (a module instance or a plain function).

    Raises:
        ValueError: If the loss / dataset combination is not supported
            (previously this surfaced as an ``UnboundLocalError``).
    """
    if opt.loss == 'mse':
        return nn.MSELoss()
    if opt.loss == 'sc_mse':
        return sc_mse
    if opt.loss == 'ce' and opt.dataset_name == 'FaultVibration':
        return CEExtended()

    raise ValueError(
        f"Unsupported loss {opt.loss!r} for dataset {opt.dataset_name!r}.")


def get_model_from_weight(opt):
    """Load and return the object stored at ``opt.weight_file``.

    Args:
        opt (argparse.Namespace): Parsed arguments; only ``weight_file`` is
            read.

    Returns:
        Whatever ``torch.load`` deserialises from the file.

    Note:
        ``torch.load`` relies on pickle, so the file should come from a
        trusted source.
    """
    return torch.load(opt.weight_file)


class Log(object):
    """Small file-backed logger for training metadata.

    The file is opened on construction and stays open until ``close`` is
    called. The class can also be used as a context manager, which closes the
    file on exit even when an exception is raised::

        with Log(path, 'w') as log:
            log.log_train_metrics(metrics, epoch)

    Args:
        log_file_path (str): Path of the log file.
        op (str): Mode passed to ``open``. Defaults to ``'r'``; pass a
            writable mode such as ``'w'`` or ``'a'`` to use the write methods.

    Attributes:
        log (file object): The open file handle.
        op (str): The mode the file was opened with.
    """

    def __init__(self, log_file_path, op='r'):
        self.log = open(log_file_path, op)
        self.op = op

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never suppress an exception raised inside the ``with`` body.
        return False

    def write_model(self, model):
        """Write the model's string form and its total parameter count."""
        self.log.write('\n##MODEL START##\n')
        self.log.write(str(model))
        self.log.write('\n##MODEL END##\n')

        self.log.write('\n##MODEL SIZE##\n')
        self.log.write(str(sum(p.numel() for p in model.parameters())))
        self.log.write('\n##MODEL SIZE##\n')

    def _write_metrics(self, marker, metrics, epoch):
        """Write one ``key=value`` line per metric between two marker lines."""
        self.log.write('\n' + marker + '\n')
        self.log.write('@epoch:' + str(epoch) + '\n')
        for k, v in metrics.items():
            self.log.write(k + '=' + str(v) + '\n')
        self.log.write('\n' + marker + '\n')

    def log_train_metrics(self, metrics, epoch):
        """Write the training metrics of ``epoch``."""
        self._write_metrics('##TRAIN METRICS##', metrics, epoch)

    def log_validation_metrics(self, metrics, epoch):
        """Write the validation metrics of ``epoch``."""
        self._write_metrics('##VALIDATION METRICS##', metrics, epoch)

    def close(self):
        """Close the underlying file."""
        self.log.close()
