import argparse
import argparse
import os
import shutil

import numpy as np
import pandas as pd
import torch
from scipy.optimize import linear_sum_assignment
from scipy.special import softmax
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, \
    classification_report

from utils.saver import Saver

import sys
sys.path.append('../../../')
sys.path.append('../../../pipeline')

from pipeline.cb_evaluation_api import class_evaluation

class GPU_id:
    """Namespace holding the CUDA device index read when building the module ``device``."""

    # CUDA device ordinal, accessed as ``GPU_id.id``.
    id = 3

######################################################################################################################
# Width of the current terminal, in character columns.
columns = shutil.get_terminal_size().columns
# Use the CUDA device with index ``GPU_id.id`` when CUDA is available, otherwise the CPU.
device = torch.device("cuda", GPU_id.id) if torch.cuda.is_available() else torch.device("cpu")


#####################################################################################################################

class SaverSlave(Saver):
    """Lightweight ``Saver`` that only records ``path`` and calls ``makedir_()``.

    ``Saver.__init__`` is intentionally NOT invoked: only ``self.path`` is set
    before the inherited ``makedir_()`` is called.
    NOTE(review): any other state ``Saver.__init__`` would normally create is
    absent on instances of this class — confirm no inherited method relies on it.
    """

    def __init__(self, path):
        # No call to Saver.__init__ here (the previous ``super(Saver)`` expression only
        # built an unbound super object and never ran the parent initializer).
        self.path = path
        # Inherited from Saver; presumably creates ``self.path`` on disk — TODO confirm.
        self.makedir_()
        # self.make_log()


def str2bool(v):
    """Parse a command-line style boolean flag.

    Booleans are returned unchanged; strings are matched case-insensitively
    against common yes/no spellings.

    :raises argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def check_ziplen(l, n):
    """Pad ``l`` in place with copies of its last element until ``len(l) % n == 0``.

    Iterative replacement for the former one-call-per-element recursion, which hit
    Python's recursion limit for large ``n``.

    :param l: list to pad (mutated in place); an empty list is returned unchanged.
    :param n: chunk size the length must become a multiple of.
    :return: the same list object ``l``.
    :raises ZeroDivisionError: if ``n`` is 0.
    """
    # Number of elements still missing to reach the next multiple of |n|.
    pad = -len(l) % abs(n)
    if pad:
        l += [l[-1]] * pad
    return l


def remove_duplicates(sequence):
    """Return the items of ``sequence`` in first-seen order with duplicates dropped.

    Hashable items are tracked in a set for O(1) membership tests; unhashable
    items (e.g. lists such as ``[0, 1, 0]``) fall back to a linear scan of the
    items kept so far.

    :param sequence: any iterable.
    :return: a new list of unique items.
    """
    unique = []
    seen = set()  # hashable items already kept
    for item in sequence:
        try:
            if item in seen:
                continue
            seen.add(item)
        except TypeError:
            # Unhashable item: compare against everything kept so far.
            if item in unique:
                continue
        unique.append(item)
    return unique


# LaTeX legend labels keyed by the three 0/1 switches (L_ae, L_c, L_cc).
_ABG_LABELS = {
    (0, 1, 0): r'$\mathcal{L}_c$',
    (1, 0, 0): r'$\mathcal{L}_{ae}$',
    (1, 1, 0): r'$\mathcal{L}_c + \mathcal{L}_{ae}$',
    (0, 1, 1): r'$\mathcal{L}_c + \mathcal{L}_{cc}$',
    (1, 1, 1): r'$\mathcal{L}_c + \mathcal{L}_{ae} + \mathcal{L}_{cc}$',
}


def map_abg(x):
    """Map a 3-element switch vector to the LaTeX label of the active loss terms.

    Position 0 enables ``L_ae``, position 1 ``L_c`` and position 2 ``L_cc``.
    Lists, tuples and 1-D NumPy arrays with int, float or bool entries are accepted.

    :raises ValueError: if ``x`` is not one of the supported combinations.
    """
    try:
        return _ABG_LABELS[tuple(x)]
    except (TypeError, KeyError):
        # TypeError: x not iterable or has unhashable entries; KeyError: unknown combination.
        raise ValueError('Unsupported loss combination: {!r}'.format(x)) from None


def map_losstype(x):
    """Return ``'Symm'`` for noise pattern 0, otherwise ``'Asymm_<x>'``."""
    return 'Symm' if x == 0 else f'Asymm_{x}'


def map_abg_main(x):
    """Encode a weight vector as underscore-joined integers, or ``'Variable'`` for None.

    Example: ``[1, 0.0, 1] -> '1_0_1'``.
    """
    if x is not None:
        return '_'.join(str(int(weight)) for weight in x)
    return 'Variable'


def remove_empty_dir(path):
    """Best-effort removal of ``path`` if it is an empty directory.

    Any ``OSError`` from ``os.rmdir`` (directory not empty, missing, not permitted, ...)
    is ignored and the directory is left as is.
    """
    try:
        os.rmdir(path)
    except OSError:
        return


def remove_empty_dirs(path):
    """Remove every empty sub-directory below ``path`` (``path`` itself is not removed).

    The walk is bottom-up (``topdown=False``), so a directory's children are handled
    before the directory is listed by its parent; a parent that only held empty
    directories is therefore empty when its turn comes.
    """
    for parent, subdirs, _files in os.walk(path, topdown=False):
        candidates = (os.path.realpath(os.path.join(parent, name)) for name in subdirs)
        for candidate in candidates:
            remove_empty_dir(candidate)


def add_noise(x, sigma=0.2, mu=0.):
    """Return ``x`` plus i.i.d. Gaussian noise with mean ``mu`` and std ``sigma``.

    The noise is sampled on ``x.device``. The previous version always sampled it on
    the CPU, which failed for CUDA tensors.

    :param x: input tensor (not modified).
    :param sigma: standard deviation of the noise.
    :param mu: mean of the noise.
    :return: a new tensor ``x + noise`` with the same shape as ``x``.
    """
    noise = mu + torch.randn(x.size(), device=x.device) * sigma
    return x + noise


def readable(num):
    """Format a number with a thousands suffix: '', 'k', 'M', or 'B'.

    Values below 1e9 in magnitude use three decimals (``1500 -> '1.500k'``);
    larger values use one decimal with a ``'B'`` suffix.
    """
    for suffix in ('', 'k', 'M'):
        if abs(num) >= 1e3:
            num /= 1e3
            continue
        return "%3.3f%s" % (num, suffix)
    return "%.1f%s" % (num, 'B')


# Unique labels
def categorizer(y_cont, y_discrete):
    """Combine a discrete label with the sign of the step in a continuous target.

    For each step ``t -> t+1`` a string ``'<y_discrete[t+1]><1 if y_cont rose else 0>'``
    is built, and these strings are encoded as pandas categorical codes
    (categories sorted, so codes follow lexicographic order of the strings).

    NOTE(review): output has one entry fewer than the inputs (first sample dropped by the diff).
    """
    rising = (np.diff(y_cont, axis=0) > 0).astype(int).squeeze()
    discrete_part = y_discrete[1:].astype(int).astype(str)
    combined = [d + r for d, r in zip(list(discrete_part), list(rising.astype(str)))]
    return pd.Series(combined).astype('category').cat.codes


def reset_seed_(seed):
    """Re-seed torch (and CUDA when available) and NumPy's global RNG for fair comparisons.

    The old guard ``device != 'cpu'`` compared a ``torch.device`` with a ``str`` and so
    was always true. The CUDA branch now depends on ``torch.cuda.is_available()``, the
    same condition that decides whether the module-level ``device`` is a CUDA device.

    :param seed: integer seed.
    """
    print('Setting seed: {}'.format(seed))
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    np.random.seed(seed)


def reset_model(model):
    """Call ``reset_parameters()`` on every sub-module of ``model`` that defines it.

    The model is modified in place and also returned for convenience.
    """
    print('Resetting model parameters...')
    resettable = (module for module in model.modules() if hasattr(module, 'reset_parameters'))
    for module in resettable:
        module.reset_parameters()
    return model


def append_results_dict(main, sub):
    """Append each ``sub[key]`` to the list ``main[key]`` (in place) and return ``main``.

    :raises KeyError: if ``sub`` has a key that ``main`` lacks.
    """
    for key, value in sub.items():
        main[key].append(value)
    return main


def flip_label(target, ratio, pattern=0, n_class=None):
    """
    Induce label noise by randomly corrupting labels.

    :param target: list or array of integer class labels (0-indexed).
    :param ratio: float in [0, 1): probability that each label is corrupted.
    :param pattern: integer (Python or NumPy) choosing the type of noise:
            pattern % n_class == 0 (e.g. 0) -> symmetric: with probability ``ratio`` the label
                                               moves to one of the other classes, chosen uniformly
            pattern > 0                     -> asymmetric: with probability ``ratio`` label ``y``
                                               becomes ``(y + pattern) % n_class``
            other negative values (e.g. -1) -> flip: with probability ``ratio`` the label becomes 0
            A string raises ``ValueError``; any other non-integer (including bool) leaves the
            labels unchanged.
    :param n_class: number of classes. If None it is inferred as
            ``max(number of distinct labels, largest label + 1)``, so classes absent from
            ``target`` are still counted.
    :return: (label, mask) -- the noisy integer labels, and an int array that is 1 where the
             noisy label differs from ``target``.
    :raises ValueError: if ``pattern`` is a string.
    """
    assert 0 <= ratio < 1

    target = np.array(target).astype(int)
    label = target.copy()

    if isinstance(pattern, str):
        raise ValueError('pattern must be an integer, got {!r}'.format(pattern))

    # NumPy integer scalars count as integer patterns; bool (an int subclass) does not,
    # so True/False still leave the labels untouched.
    is_int_pattern = isinstance(pattern, (int, np.integer)) and not isinstance(pattern, bool)

    if is_int_pattern and label.size > 0:
        if n_class is None:
            # Counting distinct values alone under-counts when a class is missing from target.
            n_class = max(len(np.unique(label)), int(label.max()) + 1)
        symmetric = (pattern % n_class) == 0
        # With a single class there is no other class to move to under symmetric noise.
        if not (symmetric and n_class < 2):
            for i in range(label.shape[0]):
                if symmetric:
                    p1 = ratio / (n_class - 1) * np.ones(n_class)
                    p1[label[i]] = 1 - ratio
                    label[i] = np.random.choice(n_class, p=p1)
                elif pattern > 0:
                    # Asymm
                    label[i] = np.random.choice([label[i], (target[i] + pattern) % n_class],
                                                p=[1 - ratio, ratio])
                else:
                    # Flip noise
                    label[i] = np.random.choice([label[i], 0], p=[1 - ratio, ratio])

    mask = np.array([int(x != y) for (x, y) in zip(target, label)])

    return label, mask


def cluster_accuracy(y_true, y_predicted, cluster_number=None):
    """
    Calculate clustering accuracy after using the linear_sum_assignment function in SciPy to
    determine reassignments.

    :param y_true: true cluster numbers, 0-indexed integers (list or array)
    :param y_predicted: predicted cluster numbers, 0-indexed integers (list or array)
    :param cluster_number: number of clusters, if None then calculated from input
    :return: reassignment dictionary (predicted cluster -> true cluster), clustering accuracy
    :raises ValueError: if ``y_true`` and ``y_predicted`` have different numbers of elements
    """
    # Lists are accepted as documented; flatten so (n, 1) columns behave like (n,).
    y_true = np.asarray(y_true).ravel()
    y_predicted = np.asarray(y_predicted).ravel()
    if y_true.size != y_predicted.size:
        raise ValueError('y_true and y_predicted must have the same length, got {} and {}'
                         .format(y_true.size, y_predicted.size))
    if cluster_number is None:
        cluster_number = max(y_predicted.max(), y_true.max()) + 1  # assume labels are 0-indexed
    # count_matrix[p, t] = number of samples predicted as cluster p whose true cluster is t.
    count_matrix = np.zeros((cluster_number, cluster_number), dtype=np.int64)
    np.add.at(count_matrix, (y_predicted, y_true), 1)

    # linear_sum_assignment minimises cost, so invert counts to maximise matched samples.
    row_ind, col_ind = linear_sum_assignment(count_matrix.max() - count_matrix)
    reassignment = dict(zip(row_ind, col_ind))
    accuracy = count_matrix[row_ind, col_ind].sum() / y_predicted.size
    return reassignment, accuracy


def evaluate_model_multi(model, dataloder, y_true, x_true,
                         metrics=('mae', 'mse', 'rmse', 'std_ae', 'smape', 'rae', 'mbrae', 'corr', 'r2'), n_class=2):
    """Score the classification output of a model evaluated with ``predict_multi``.

    :param model: model whose forward returns ``(reconstruction, class_scores, embedding)``.
    :param dataloder: iterable of batches passed to ``predict_multi``.
    :param y_true: true integer labels, one per sample (1-D).
    :param x_true: unused; kept for interface compatibility.
    :param metrics: unused; kept for interface compatibility.
    :param n_class: number of classes passed to ``class_evaluation``.
    :return: ``(acc, pre, rec, f1, tp, tn, fp, fn)``. The first four come from ``class_evaluation``.
             The confusion counts treat label 0 as negative and any non-zero label as positive.
    """
    # Only the classification output is scored; the reconstruction is discarded.
    _, yhat = predict_multi(model, dataloder)

    # Classification
    y_hat_proba = softmax(yhat, axis=1)
    y_hat_labels = np.argmax(y_hat_proba, axis=1)

    y_true = np.asarray(y_true, dtype=np.int64)
    y_pred = np.asarray(y_hat_labels, dtype=np.int64)
    assert y_true.ndim == 1
    assert y_pred.ndim == 1
    assert y_true.shape == y_pred.shape

    index = class_evaluation(y_true, y_pred, n_class)

    # Confusion counts via explicit boolean masks. For 0/1 labels this equals the former
    # bitwise-AND formulation; for labels > 1 the bitwise form produced meaningless counts.
    # NOTE(review): confirm "0 = negative, non-zero = positive" is the intended multi-class reading.
    true_pos = y_true != 0
    pred_pos = y_pred != 0
    tp = np.sum(true_pos & pred_pos)
    fp = np.sum(~true_pos & pred_pos)
    fn = np.sum(true_pos & ~pred_pos)
    tn = np.sum(~true_pos & ~pred_pos)

    return index.acc, index.pre, index.rec, index.f1, tp, tn, fp, fn


def evaluate_model(model, dataloder, y_true, n_class):
    """Predict labels with ``predict`` and score them with ``class_evaluation``.

    Scores from ``predict`` are turned into labels by softmax over axis 1 followed by argmax
    (so they are assumed to be shaped ``(n_samples, n_classes)`` -- TODO confirm).

    :return: ``(acc, pre, rec, f1)`` as reported by ``class_evaluation``.
    """
    scores = predict(model, dataloder)
    predicted_labels = np.argmax(softmax(scores, axis=1), axis=1)

    truth = np.asarray(y_true, dtype=np.int64)
    guesses = np.asarray(predicted_labels, dtype=np.int64)
    for arr in (truth, guesses):
        assert arr.ndim == 1
    assert truth.shape == guesses.shape

    report = class_evaluation(truth, guesses, n_class)
    return report.acc, report.pre, report.rec, report.f1


def evaluate_class_recons(model, x, Y, Y_clean, dataloader, ni, saver, network='Model', datatype='Train', correct=False, n_class=2):
    """Evaluate a reconstruction + classification model and collect its scores in a dict.

    The confusion matrix against ``Y_clean`` and the title string that used to be built
    here were never used, so they are no longer computed.

    :param model: model evaluated through ``evaluate_model_multi``.
    :param x: input data, forwarded to ``evaluate_model_multi`` as ``x_true``.
    :param Y: labels the predictions are scored against.
    :param Y_clean: currently unused; kept for interface compatibility.
    :param dataloader: batches fed to the model.
    :param ni: currently unused; kept for interface compatibility.
    :param saver: currently unused; kept for interface compatibility.
    :param network: currently unused; kept for interface compatibility.
    :param datatype: split name, used only in the printed header.
    :param correct: currently unused; kept for interface compatibility.
    :param n_class: number of classes forwarded to ``evaluate_model_multi``.
    :return: dict with keys 'acc', 'pre', 'rec', 'f1', 'fp', 'fn', 'tp', 'tn'.
    """
    print(f'{datatype} score')

    accuracy, precision, recall, f1, tp, tn, fp, fn = evaluate_model_multi(model, dataloader, Y, x, n_class=n_class)

    return {
        'acc': accuracy,
        'pre': precision,
        'rec': recall,
        'f1': f1,
        'fp': fp,
        'fn': fn,
        'tp': tp,
        'tn': tn,
    }


def evaluate_class(args, model, x, Y, Y_clean, dataloader, ni, saver, network='Model', datatype='Train', correct=False,
                   plt_cm=True, plt_lables=True):
    """Evaluate a classifier with ``evaluate_model`` and collect its scores in a dict.

    The confusion matrix against ``Y_clean`` and the title string that used to be built
    here were never used, so they are no longer computed.

    :param args: namespace providing ``n_class``.
    :param model: model evaluated through ``evaluate_model``.
    :param Y: labels the predictions are scored against.
    :param dataloader: batches fed to the model.
    :param datatype: split name, used only in the printed header.
    :param x, Y_clean, ni, saver, network, correct, plt_cm, plt_lables: currently unused;
        kept for interface compatibility.
    :return: dict with keys 'acc', 'pre', 'rec', 'f1'.
    """
    print(f'{datatype} score')

    accuracy, precision, recall, f1 = evaluate_model(model, dataloader, Y, args.n_class)

    return {'acc': accuracy, 'pre': precision, 'rec': recall, 'f1': f1}


def predict(model, test_data):
    """Run ``model`` in eval mode without gradients over ``test_data``.

    Takes element 0 of each batch as the input, casts it to float and moves it to the
    module-level ``device``. The per-batch outputs are concatenated along axis 0.

    :return: NumPy array of stacked model outputs.
    """
    batch_outputs = []
    model.eval()
    with torch.no_grad():
        for batch in test_data:
            inputs = batch[0].float().to(device)
            batch_outputs.append(model(inputs).cpu().numpy())
    return np.concatenate(batch_outputs, axis=0)


def predict_multi(model, test_data):
    """Run a three-output model in eval mode without gradients over ``test_data``.

    The model is called on element 0 of each batch (cast to float, moved to the
    module-level ``device``) and must return ``(reconstruction, class_output, embedding)``.
    The embedding is ignored.

    :return: ``(reconstruction, prediction)`` NumPy arrays, each concatenated along axis 0.
    """
    recon_batches, class_batches = [], []
    model.eval()
    with torch.no_grad():
        for batch in test_data:
            inputs = batch[0].float().to(device)
            out_ae, out_class, _embedding = model(inputs)
            class_batches.append(out_class.cpu().numpy())
            recon_batches.append(out_ae.cpu().numpy())

    stacked_class = np.concatenate(class_batches, axis=0)
    stacked_recon = np.concatenate(recon_batches, axis=0)
    return stacked_recon, stacked_class


def __f_beta_score__(pre, rec, beta=1.0):
    """Return the F-beta score ``(1 + beta^2) * pre * rec / (beta^2 * pre + rec)``.

    Returns 0 when the denominator is 0 (e.g. ``pre == rec == 0``, or ``beta == 0`` with
    ``rec == 0``). The latter previously raised ZeroDivisionError.

    :param pre: precision.
    :param rec: recall.
    :param beta: weight of recall relative to precision.
    """
    denominator = beta ** 2 * pre + rec
    if denominator != 0:
        return (1 + beta ** 2) * pre * rec / denominator
    return 0