import copy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.manifold import TSNE
from umap import UMAP
# from cuml import TSNE
# from cuml import UMAP

import torch
import optuna

import visualization_metric

class Objective:
    """Hyper-parameter search objective for 2-D embeddings (t-SNE / UMAP).

    Each ``*_search`` method takes a trial object exposing ``suggest_int``
    (presumably an ``optuna.trial.Trial`` -- confirm against the caller),
    fits an embedding of ``self.x``, scores it against ``self.y`` with
    ``visualization_metric.get_nmi_sc`` and records the score, the
    hyper-parameters and the embedding in ``self.scores``, ``self.hps`` and
    ``self.embs`` (one entry per trial, in call order).

    The ``plot_results*`` methods read attributes that ``__init__`` does NOT
    set (see ``_PLOT_STATE``); they must be assigned on the instance by the
    caller before plotting.
    """

    # Attributes the plotting methods read but ``__init__`` never assigns.
    _PLOT_STATE = ('aucs', 'f1s', 'deltas', 'n_noise', 'objective', 'k', 'temp_name')

    def __init__(self, exp_name, dataset_name, dataset_feature_name, method='k-gap', dataset_path='', save_path='./res', features_n_labels=None):
        """Load the data for one experiment.

        Args:
            exp_name: ``'gene'`` selects the 5-tuple archive layout
                ``(cdist, y, x, ind, selected_labels)``; any other value
                selects the 4-tuple layout ``(cdist, y, ind, selected_labels)``
                and reads the raw features from a separate ``*_x_y.tar`` file.
            dataset_name: Dataset identifier. In the non-gene branch the last
                two ``'-'``-separated fields are stripped to obtain the name
                of the raw feature file.
            dataset_feature_name: Unused; kept for interface compatibility.
            method: Stored as ``self.method`` (only used in plot titles and
                file names).
            dataset_path: Path passed to ``torch.load``.
            save_path: Root directory used by ``save_result``.
            features_n_labels: Unused; kept for interface compatibility.
        """
        self.exp_name = exp_name
        self.dataset_name = dataset_name
        self.method = method
        self.save_path = save_path

        # NOTE(security): ``weights_only=False`` unpickles arbitrary objects;
        # only load archives from trusted sources.
        if exp_name == 'gene':
            cdist, y, x, ind, selected_labels = torch.load(dataset_path, weights_only=False)
            x = x.cpu().numpy()
            print(type(x))
        else:
            cdist, y, ind, selected_labels = torch.load(dataset_path, weights_only=False)
            print(y)
            d_name = '-'.join(dataset_name.split('-')[:-2])

            x, y_ = torch.load(f"/mnt/data01/****/****/prepare_data/uci_tabular/downloaded_data/{d_name}_x_y.tar",
                               weights_only=False)
            features = x.to_numpy().astype(float)
            # isinstance also accepts ndarray subclasses, which have no
            # ``to_numpy`` method.
            if not isinstance(y_, np.ndarray):
                y_ = y_.to_numpy()
            labels = y_.reshape(-1)

            # Keep only rows whose label is in ``selected_labels``, then
            # apply the stored index ``ind`` to that subset.
            selected_indices = np.isin(labels, selected_labels)
            print(len(selected_indices))
            x_selected = features[selected_indices]
            print(len(x_selected))

            print(labels, selected_labels, ind)
            x = x_selected[ind]

        # ``cdist`` is stored but not used by the search methods in this class.
        self.cdist = cdist

        self.ind = ind
        self.x = x
        self.y = y

        self.hps = []
        self.scores = []
        self.embs = []

    # ------------------------------------------------------------------
    # search helpers
    # ------------------------------------------------------------------
    def _perplexity_bounds(self):
        """Return ``(low, high)`` for the perplexity search, clamped to the data.

        scikit-learn's TSNE rejects ``perplexity >= n_samples``, so the
        default 10..80 range is capped at ``n_samples - 1``.
        """
        high = max(1, min(80, len(self.x) - 1))
        low = min(10, high)
        return low, high

    def _score_and_record(self, z, hps, average_with_sc):
        """Score embedding ``z`` and append score, ``hps`` and ``z`` to the history.

        The score is ``(nmi + sc) / 2`` when ``average_with_sc`` is true and
        ``nmi`` alone otherwise, where ``nmi, sc`` are the two values returned
        by ``visualization_metric.get_nmi_sc(z, self.y)``.
        """
        nmi, sc = visualization_metric.get_nmi_sc(z, self.y)
        results = (nmi + sc) / 2 if average_with_sc else nmi
        print(results)
        self.scores.append(results)
        self.hps.append(hps)
        self.embs.append(z)
        return results

    def _run_tsne_trial(self, trial, average_with_sc):
        """Fit t-SNE with a suggested perplexity and record the scored result."""
        low, high = self._perplexity_bounds()
        perplexity = trial.suggest_int('perplexity', low, high, step=1)
        init = 'random'

        tsne = TSNE(perplexity=perplexity, n_components=2, verbose=0, init=init, random_state=42)
        z = tsne.fit_transform(self.x)

        hps = {
            'perplexity': perplexity,
            'init': init,
        }
        return self._score_and_record(z, hps, average_with_sc)

    def tsne_search(self, trial):
        """Run one t-SNE trial scored by the mean of ``nmi`` and ``sc``.

        Args:
            trial: Object providing ``suggest_int(name, low, high, step=...)``.

        Returns:
            ``(nmi + sc) / 2`` for the fitted embedding.
        """
        print(self.x.shape)
        return self._run_tsne_trial(trial, average_with_sc=True)

    def tsne_search_nmi(self, trial):
        """Run one t-SNE trial scored by ``nmi`` only.

        Args:
            trial: Object providing ``suggest_int(name, low, high, step=...)``.

        Returns:
            The ``nmi`` value for the fitted embedding.
        """
        return self._run_tsne_trial(trial, average_with_sc=False)

    def umap_search(self, trial):
        """Run one UMAP trial scored by ``nmi`` only.

        ``n_neighbors`` is searched in ``[10, min(100, len(self.y))]``; the
        lower bound is reduced when fewer than 10 labels are available so the
        range never becomes empty.

        Args:
            trial: Object providing ``suggest_int(name, low, high, step=...)``.

        Returns:
            The ``nmi`` value for the fitted embedding.
        """
        max_neighbors = min(100, len(self.y))
        min_neighbors = min(10, max_neighbors)
        n_neighbors = trial.suggest_int('n_neighbors', min_neighbors, max_neighbors, step=1)
        init = 'spectral'

        reducer = UMAP(n_neighbors=n_neighbors, n_components=2, verbose=0, init=init, random_state=42)
        z = reducer.fit_transform(self.x)

        hps = {
            'n_neighbors': n_neighbors,
            'init': init,
        }
        return self._score_and_record(z, hps, average_with_sc=False)

    # ------------------------------------------------------------------
    # plotting helpers
    # ------------------------------------------------------------------
    def _require_plot_state(self):
        """Raise ``AttributeError`` naming every plotting attribute not yet set."""
        missing = [name for name in self._PLOT_STATE if not hasattr(self, name)]
        if missing:
            raise AttributeError(
                f"{type(self).__name__} is missing attribute(s) required for plotting: "
                + ', '.join(missing)
            )

    def _plot_title(self):
        """Build the figure title shared by both plotting methods."""
        n_noise = self.n_noise
        if self.objective == 'delta':
            return f'dataset: {self.dataset_name}, delta method: {self.method}, n_noise={n_noise}, k={self.k}, temp_method: {self.temp_name}'
        return f'dataset: {self.dataset_name}, objective: {self.objective}, n_noise={n_noise}, k={self.k}'

    def _plot_filename(self, ad_method):
        """Build the PNG file name (without directory) shared by both plotting methods."""
        n_noise = self.n_noise
        return f'ad_method_{ad_method}-dataset_{self.dataset_name}-{self.objective}-delta method_{self.method}-n_noise={n_noise}-k_{self.k}-temp_method_{self.temp_name}.png'

    @staticmethod
    def _best_so_far_indices(deltas):
        """For each position, return the index where the running max of ``deltas`` was first reached."""
        ddf = pd.DataFrame({'d': deltas}, index=range(len(deltas)))
        ddf['cummax'] = ddf.d.cummax()
        ddf['idx'] = ddf.index

        first_hit = ddf.groupby('cummax')[['idx']].first().reset_index()
        # Both frames carry an ``idx`` column, so the merge suffixes them;
        # ``idx_y`` is the one coming from ``first_hit``.
        merged = ddf.merge(first_hit, on='cummax')
        return merged['idx_y'].to_numpy()

    def plot_results(self, ad_method, root='./', log=False, is_sort=False, is_cummax=True):
        """Plot ``deltas`` (left axis) with ``aucs`` / ``f1s`` (right axis) and save a PNG.

        Requires the attributes listed in ``_PLOT_STATE`` to be set on the
        instance.

        Args:
            ad_method: Label embedded in the output file name.
            root: Directory the PNG is written to.
            log: Plot ``log10(deltas)`` instead of ``deltas``.
            is_sort: Order points by ascending ``deltas``. Ignored when
                ``is_cummax`` is true.
            is_cummax: At each iteration plot the entry that holds the
                running maximum of ``deltas`` so far. Takes precedence over
                ``is_sort``.

        Raises:
            AttributeError: If any required plotting attribute is missing.
        """
        self._require_plot_state()
        aucs = np.array(self.aucs)
        f1s = np.array(self.f1s)
        deltas = np.array(self.deltas)

        if is_sort:
            order = deltas.argsort()
        else:
            order = np.arange(len(deltas))
        if is_cummax:
            order = self._best_so_far_indices(deltas)

        fig, ax1 = plt.subplots()
        try:
            x = np.arange(len(aucs))
            gap = np.log10(deltas[order]) if log else deltas[order]
            ax1.plot(x, gap, label='gap', c='b')
            # The label must go on ax1: the x-axis of a twinx() axes is hidden.
            ax1.set_xlabel('iter')

            ax2 = ax1.twinx()
            ax2.plot(x, aucs[order], label='aucs', c='r')
            ax2.plot(x, f1s[order], label='f1', c='orange')

            ax2.set_title(self._plot_title())
            fig.legend(bbox_to_anchor=(1, 0.22), bbox_transform=ax1.transAxes)
            fig.savefig(root + '/' + self._plot_filename(ad_method))
        finally:
            # Release the figure so repeated calls do not accumulate open figures.
            plt.close(fig)

    def plot_results_auc(self, ad_method, log=False):
        """Plot ``deltas`` and ``f1s`` against ascending ``aucs`` and save a PNG in ``./``.

        Requires the attributes listed in ``_PLOT_STATE`` to be set on the
        instance.

        Args:
            ad_method: Label embedded in the output file name.
            log: Plot ``log10(deltas)`` instead of ``deltas``.

        Raises:
            AttributeError: If any required plotting attribute is missing.
        """
        self._require_plot_state()
        aucs = np.array(self.aucs)
        f1s = np.array(self.f1s)
        deltas = np.array(self.deltas)
        order = aucs.argsort()

        fig, ax = plt.subplots()
        try:
            if log:
                ax.plot(aucs[order], np.log10(deltas[order]), label='gap')
                ax.set_ylabel('log-scale')
            else:
                ax.plot(aucs[order], deltas[order], label='gap')
            ax.plot(aucs[order], f1s[order], label='f1')

            ax.set_title(self._plot_title())
            ax.set_xlabel('auc')
            ax.legend()
            fig.savefig('./' + self._plot_filename(ad_method))
        finally:
            # Release the figure so repeated calls do not accumulate open figures.
            plt.close(fig)

    def save_result(self, ad_method):
        """Save the recorded scores, embeddings and hyper-parameters with ``torch.save``.

        The file is written to
        ``<save_path>/data/visual-method-<ad_method>_dataset-<dataset_name>``;
        the parent directory is created if it does not exist.

        Args:
            ad_method: Label embedded in the output file name.
        """
        import os

        f_name = self.save_path + '/' + f'data/visual-method-{ad_method}_dataset-{self.dataset_name}'
        # torch.save does not create missing parent directories.
        os.makedirs(os.path.dirname(f_name), exist_ok=True)
        save = {
            'scores': self.scores,
            'embs': self.embs,
            'hps': self.hps,
        }
        torch.save(save, f_name)

# No script entry point is implemented: running this file directly performs
# no work beyond the imports and definitions above.
if __name__ == "__main__":
    pass
    

    
