import torch
import torch.backends.cudnn as cudnn
import numpy as np
from argument_parser import argument_parser
from datasets.datasets import Datasets
from torch.utils.data.dataloader import DataLoader
import pickle
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import os
from os.path import join, isfile
from utils import model_loader
from metrics.manifold import tsne
from analysis.plot import check_plot_save
from analysis.plot import PlotConfusionMatrix
import metrics

def eval(model, loader, evaluate=False, eval_metrics=None, debug=False):
    """Cluster the data in ``loader`` with ``model`` and optionally score it.

    Args:
        model: Clustering model exposing ``eval()`` and ``cluster(loader)``.
            ``cluster`` must return a 4-tuple
            ``(predicted_clusters, _, true_labels, cluster_result)``; the
            second item is ignored here.
        loader: Data source forwarded unchanged to ``model.cluster``.
        evaluate: When True, compute the metrics listed in ``eval_metrics``.
        eval_metrics: Metric keys to compute. Recognised keys are ``'nmi'``,
            ``'pur'``, ``'ari'``, ``'ce'``, ``'c_acc'`` and ``'cm'``; any
            other key is silently ignored. When None, no metrics are computed.
        debug: When True (and metrics are computed), print every computed
            scalar metric.

    Returns:
        ``(eval_dict, cluster_result)``. ``eval_dict`` maps each computed
        metric key to its value; requesting ``'cm'`` also adds
        ``'cm_cutoff'`` (``int(max(true_labels) + 1)``). When ``model`` or
        ``loader`` is None nothing is run and ``({}, None)`` is returned.
    """
    eval_dict = {}
    if model is None or loader is None:
        # Always return a pair so callers can unpack the result unconditionally.
        return eval_dict, None

    model.eval()
    predicted_clusters, _, true_labels, cluster_result = model.cluster(loader)

    if not evaluate or eval_metrics is None:
        return eval_dict, cluster_result

    for metric in eval_metrics:
        if metric == 'nmi':
            eval_dict['nmi'] = metrics.cluster.nmi(true_labels, predicted_clusters)
        elif metric == 'pur':
            eval_dict['pur'] = metrics.cluster.purity(true_labels, predicted_clusters)
        elif metric == 'ari':
            eval_dict['ari'] = metrics.cluster.ari(true_labels, predicted_clusters)
        elif metric == 'ce':
            eval_dict['ce'] = metrics.cluster.predict_to_clustering_error(true_labels, predicted_clusters)
        elif metric == 'c_acc':
            eval_dict['c_acc'] = metrics.cluster.acc(true_labels, predicted_clusters)
        elif metric == 'cm':
            cm = metrics.cluster.predict_to_confusion(true_labels, predicted_clusters)
            # NOTE(review): by its name, maximize_trace reorders the matrix to
            # maximise the diagonal (best cluster/label matching) - confirm.
            eval_dict['cm'] = metrics.cluster.maximize_trace(cm)
            # Largest true label + 1; callers use it to trim the matrix rows.
            eval_dict['cm_cutoff'] = int(max(true_labels) + 1)

    if debug:
        # Print every computed scalar metric; 'cm' is not printed.
        for key, label in (
                ('nmi', "Normalized Mutual Information (NMI)"),
                ('pur', "Purity"),
                ('ari', "Adjusted Rand Index (ARI)"),
                ('ce', "Clustering Error (CE)"),
                ('c_acc', "Clustering Accuracy"),
        ):
            if key in eval_dict:
                print("%s: %0.3f" % (label, eval_dict[key]))

    return eval_dict, cluster_result

# Script entry: load a pre-trained AutoencoderSOM model, cluster the test
# split, print clustering metrics, and write the raw cluster results plus the
# confusion matrix (PNG and TXT) into ``args.out_folder``.
args = argument_parser()

out_folder = args.out_folder if args.out_folder.endswith("/") else args.out_folder + "/"
# exist_ok=True already tolerates an existing directory; no pre-check needed.
os.makedirs(out_folder, exist_ok=True)

use_cuda = torch.cuda.is_available() and args.cuda

if use_cuda:
    torch.cuda.init()

device = torch.device('cuda:0' if use_cuda else 'cpu')
ngpu = int(args.ngpu)

root = args.root
test_root = args.test_root
dataset_path = args.dataset
debug = args.debug
n_samples = args.n_samples

save = args.save
load = args.load
model = args.model

if not load or model is None:
    # TODO: support running without a pre-trained model.
    print("For now, you must have a pre-trained model beforehand to visualize the relevances.")
    # SystemExit rather than the site-injected ``exit`` helper, which is not
    # guaranteed to exist (e.g. under ``python -S``).
    raise SystemExit(0)

combined_model, _, _, _, _, _ = model_loader.load_autoencodersom_model(model, device)

if use_cuda:
    combined_model.cuda()
    cudnn.benchmark = True

dataset = Datasets(dataset=dataset_path, root_folder=root,
                   debug=debug, n_samples=n_samples)

# NOTE(review): train_loader is not used below in this script.
train_loader = DataLoader(dataset.train_data, shuffle=True)
test_loader = DataLoader(dataset.test_data, shuffle=False)

eval_dict, cluster_result = eval(combined_model, test_loader, evaluate=True,
                                 eval_metrics=['nmi', 'pur', 'ari', 'ce', 'c_acc', 'cm'],
                                 debug=False)

# Report every scalar metric that was computed, in a fixed order.
for metric_key, metric_label in (
        ('nmi', "Normalized Mutual Information (NMI)"),
        ('pur', "Purity"),
        ('ari', "Adjusted Rand Index (ARI)"),
        ('ce', "Clustering Error (CE)"),
        ('c_acc', "Clustering Accuracy"),
):
    if metric_key in eval_dict:
        print("%s: %0.3f" % (metric_label, eval_dict[metric_key]))

# Drop only the final extension: split('.')[0] would cut "my.data.csv" to
# "my" and "./data/x.csv" to "".
dataset_stem = os.path.splitext(dataset_path)[0]

combined_model.write_output(join(out_folder, dataset_path + '.results'), cluster_result)
PlotConfusionMatrix().save_cm(join(out_folder, dataset_stem + '.png'),
                              eval_dict['cm'], eval_dict['cm_cutoff'])

# The text dump keeps only the first ``cm_cutoff`` rows (largest true label + 1).
PlotConfusionMatrix().save_txt(join(out_folder, dataset_stem + '.txt'),
                               eval_dict['cm'][0:eval_dict['cm_cutoff']])