import os
import sys
import torch
import torch.nn as nn
import numpy as np
from itertools import chain
from copy import deepcopy
from torch.utils.data import DataLoader
from datetime import datetime

# PCA
from sklearn.decomposition import PCA

# confusion matrix
from sklearn.metrics import confusion_matrix

# tsne
from sklearn.manifold import TSNE
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt


def make_loader(data, args, train=True):
    """Wrap ``data`` in a ``DataLoader`` configured from ``args``.

    Training loaders use ``args.batch_size`` and shuffle every epoch.
    Evaluation loaders (``train`` falsy) use ``args.test_batch_size`` and
    keep the dataset order. ``num_workers`` and ``pin_memory`` are read
    from ``args`` in both cases.
    """
    if not train:
        batch_size, shuffle = args.test_batch_size, False
    else:
        batch_size, shuffle = args.batch_size, True

    return DataLoader(
        data,
        batch_size=batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        shuffle=shuffle,
    )

class Criterion(nn.Module):
    """Loss wrapper that selects CE / BCE / NLL from ``args.loss_f``.

    Args:
        args: config object; ``args.loss_f`` must be ``'ce'``, ``'bce'`` or
            ``'nll'``, and ``args.device`` is used to place one-hot targets
            for BCE.
        net: model whose ``seen_classes`` attribute determines the one-hot
            width for BCE targets.

    Raises:
        NotImplementedError: if ``args.loss_f`` is not one of the supported names.
    """

    def __init__(self, args, net):
        super(Criterion, self).__init__()
        self.args = args
        if args.loss_f == 'ce':
            self.criterion = nn.CrossEntropyLoss()
        elif args.loss_f == 'bce':
            self.criterion = nn.BCELoss()
        elif args.loss_f == 'nll':
            self.criterion = nn.NLLLoss()
        else:
            # Previously the exception was constructed but never raised, so an
            # unknown loss only surfaced later as an AttributeError in forward().
            raise NotImplementedError("Loss {} is not defined".format(args.loss_f))

        # Stores the same object net.seen_classes refers to (no copy); if the
        # network mutates it in place, the BCE one-hot width follows.
        self.seen_classes = net.seen_classes

    def forward(self, x, labels):
        """Compute the configured loss of raw scores ``x`` against integer ``labels``.

        Softmax/sigmoid is applied along dim 1, so ``x`` is presumably
        (batch, n_classes) — confirm against callers.
        """
        labels = self.convert_lab(labels)
        if self.args.loss_f == 'bce':
            # BCELoss expects probabilities, hence the explicit sigmoid.
            return self.criterion(torch.sigmoid(x), labels)
        elif self.args.loss_f == 'nll':
            # NLLLoss expects log-probabilities; functional form avoids
            # constructing a new nn.LogSoftmax module on every call.
            return self.criterion(torch.log_softmax(x, dim=1), labels)
        else:  # 'ce' consumes raw logits directly
            return self.criterion(x, labels)

    def convert_lab(self, labels):
        """Return targets in the form the configured loss expects.

        For ``'bce'`` the integer class indices are turned into one-hot rows
        of width ``len(self.seen_classes)``; otherwise labels pass through.
        """
        if self.args.loss_f == 'bce':
            n_cls = len(self.seen_classes)
            # Build the identity directly on the target device instead of
            # allocating on CPU and copying.
            return torch.eye(n_cls, device=self.args.device)[labels]
        else:  # 'ce', 'nll'
            return labels

class Logger:
    """Tee printed output to stdout and to ``./logs/<name>/result.txt``.

    Creating a Logger records the start time, sets ``args.dir`` to the run
    name (side effect on the caller's ``args``), and creates the log
    directory if it does not exist.

    NOTE(review): the default name uses '|' characters, which are not allowed
    in Windows file names; the default only works on POSIX filesystems.
    """

    def __init__(self, args, name=None):
        self.init = datetime.now()
        self.args = args
        if name is None:
            self.name = self.init.strftime("%m|%d|%Y %H|%M|%S")
        else:
            self.name = name

        # Expose the run name to the rest of the program through args.
        self.args.dir = self.name

        self._make_dir()

    def now(self):
        """Log the current timestamp and the time elapsed since construction."""
        time = datetime.now()
        diff = time - self.init
        self.print(time.strftime("%m|%d|%Y %H|%M|%S"), f" | Total: {diff}")

    def print(self, *objects, sep='', end='\n', flush=False):
        """Print ``objects`` to stdout and append the same text to result.txt.

        Note the default ``sep`` is ``''`` (unlike the builtin ``print``).
        """
        # Inside a method, `print` resolves to the builtin, not this method.
        print(*objects, sep=sep, end=end, file=sys.stdout, flush=flush)

        # os.path.join avoids the doubled slash from dir() + '/result.txt'.
        log_path = os.path.join(self.dir(), 'result.txt')
        with open(log_path, 'a', encoding='utf-8') as f:
            print(*objects, sep=sep, end=end, file=f, flush=flush)

    def _make_dir(self):
        """Create ``./logs/<name>/`` (and ``./logs``) if missing."""
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs(self.dir(), exist_ok=True)

    def dir(self):
        """Return the run's log directory path, with a trailing slash."""
        return './logs/{}/'.format(self.name)

    def time_interval(self):
        """Log the total time elapsed since construction."""
        self.print("Total time spent: {}".format(datetime.now() - self.init))

def print_result(mat, task_id, type, print=print):
    """Print a task-by-task result matrix.

    Args:
        mat: 2-D array indexed as ``mat[i, j]``. Entries equal to -100 are
            treated as "not available" and printed as an empty cell.
        task_id: index of the latest task; rows/columns ``0..task_id`` are shown.
        type: ``'acc'`` prints rows ``0..task_id`` followed by each row's last
            column; ``'forget'`` prints the last row's first ``task_id + 1``
            entries and its last column, then (when ``task_id > 0``) the mean
            of ``mat[-1, :task_id]`` as "Average Forgetting".
        print: output function (defaults to the builtin; a Logger's ``print``
            can be passed to also write to file).

    Raises:
        ValueError: if ``type`` is neither ``'acc'`` nor ``'forget'``.
    """
    if type == 'acc':
        # Print accuracy
        for i in range(task_id + 1):
            for j in range(task_id + 1):
                acc = mat[i, j]
                if acc != -100:
                    print("{:.2f}\t".format(acc), end='')
                else:
                    print("\t", end='')
            # Last column of each row (presumably a per-row summary — confirm).
            print("{:.2f}".format(mat[i, -1]))
    elif type == 'forget':
        # Print forgetting and average incremental accuracy
        for i in range(task_id + 1):
            acc = mat[-1, i]
            if acc != -100:
                print("{:.2f}\t".format(acc), end='')
            else:
                print("\t", end='')
        print("{:.2f}".format(mat[-1, -1]))
        if task_id > 0:
            # The current task (column task_id) is excluded from the average.
            forget = np.mean(mat[-1, :task_id])
            print("Average Forgetting: {:.2f}".format(forget))
    else:
        # Previously constructed but never raised, so bad input was ignored.
        raise ValueError("Type must be either 'acc' or 'forget'")

def tsne(train_f_cross, train_y_cross, name='tsne',
         n_components=2, verbose=0, learning_rate=1, perplexity=9, n_iter=1000, logger=None):
    """Embed features with t-SNE and save a scatter plot coloured by label.

    Args:
        train_f_cross: feature matrix X (numpy array) passed to ``TSNE.fit_transform``.
        train_y_cross: label vector y (numpy array), one entry per row of X.
        name: output file name; prefixed with ``logger.dir()`` when a logger is given.
        n_components, verbose, learning_rate, perplexity: forwarded to ``TSNE``.
            Only the first two embedding dimensions are plotted, so
            ``n_components`` must be at least 2.
        n_iter: number of optimisation iterations. Passed as ``max_iter`` on
            scikit-learn versions that accept it (``n_iter`` was deprecated in
            1.5 and removed in 1.7), otherwise as ``n_iter``.
        logger: optional object whose ``dir()`` gives the output directory.
    """
    import inspect

    num_y = len(set(train_y_cross))

    # Pick whichever iteration keyword this scikit-learn install accepts.
    iter_kw = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'
    embedder = TSNE(n_components=n_components, verbose=verbose,
                    learning_rate=learning_rate, perplexity=perplexity,
                    **{iter_kw: n_iter})
    tsne_results = embedder.fit_transform(train_f_cross)

    df_subset = pd.DataFrame(data={'tsne-2d-one': tsne_results[:, 0],
                                   'tsne-2d-two': tsne_results[:, 1]})
    df_subset['y'] = train_y_cross

    fig = plt.figure(figsize=(16, 10))
    try:
        sn.scatterplot(
            x="tsne-2d-one", y="tsne-2d-two",
            hue="y",
            palette=sn.color_palette("hls", num_y),
            data=df_subset,
            legend="full",
            alpha=0.3
        )

        out_dir = '' if logger is None else logger.dir()  # '' -> current folder
        plt.savefig(out_dir + name)
    finally:
        # Release the figure even if plotting/saving fails.
        plt.close(fig)

def plot_confusion(true_lab, pred_lab, label_names, task_id=None, p_task_id=None, name=None, print=print, logger=None):
    """Save a confusion-matrix heatmap and print its off-diagonal sums.

    Args:
        true_lab, pred_lab: 1-D arrays of true / predicted class ids.
        label_names: indexable mapping class id -> display name.
        task_id, p_task_id: when ``task_id`` is given, the image file name and
            the printed line include both ids.
        name: file name used when ``task_id`` is None (falls back to
            ``'confusion_matrix'`` if also None).
        print: output function; replaced by ``logger.print`` when a logger is given.
        logger: optional object providing ``dir()`` and ``print``.
    """
    classes = sorted(set(np.concatenate([true_lab, pred_lab])))
    labs = [label_names[c] for c in classes]

    # Explicit labels guarantee matrix rows/cols line up with `labs`
    # (rows = true class, columns = predicted class).
    cm = confusion_matrix(true_lab, pred_lab, labels=classes)

    out_dir = '' if logger is None else logger.dir()  # '' -> current folder
    emit = logger.print if logger is not None else print

    if task_id is not None:
        fname = "Total Task {}, current task {} is learned".format(task_id, p_task_id)
    else:
        # Previously `dir + None` raised TypeError when no name was given.
        fname = name if name is not None else 'confusion_matrix'

    fig = plt.figure(figsize=(15, 14))
    try:
        # fmt='d' prints integer counts; seaborn's default '.2g' renders
        # counts >= 100 in scientific notation (e.g. 1.2e+02).
        hmap = sn.heatmap(cm, annot=True, fmt='d')
        hmap.set_xticks(np.arange(len(classes)) + 0.5)
        hmap.set_yticks(np.arange(len(classes)) + 0.5)
        hmap.set_xticklabels(labs, rotation=90)
        hmap.set_yticklabels(labs, rotation=0)
        plt.savefig(out_dir + fname)
    finally:
        plt.close(fig)

    # Upper triangle: predicted class sorts after the true class;
    # lower triangle: predicted class sorts before it.
    upper, lower = np.triu(cm, 1).sum(), np.tril(cm, -1).sum()
    if task_id is not None:
        emit("{}/{} | upper/lower triangular sum: {}/{}".format(task_id, p_task_id,
                                                               upper, lower))
    else:
        emit("Upper/lower triangular sum: {}/{}".format(upper, lower))

