import torch
import pickle
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np
import sys
sys.path.append("..")
from params import args
from model import AnyGraph
from data_handler import DataHandler

def load_model(path, device='cuda:0', save_dir='/your_saved_dir'):
    """Load a trained AnyGraph checkpoint and move it to ``device``.

    Args:
        path: Run name. Weights are read from ``{save_dir}/Models/{path}.pt``
            and the training history from ``{save_dir}/History/{path}.his``.
        device: Device the returned model is moved to.
        save_dir: Root directory containing the ``Models`` and ``History``
            folders (defaults to the previously hard-coded location).

    Returns:
        The ``AnyGraph`` model with the checkpoint weights loaded, on ``device``.
    """
    model = AnyGraph()
    # Deserialize onto CPU first so a checkpoint saved from another GPU (or
    # loaded on a machine without that GPU) still loads; move afterwards.
    state_dict = torch.load(f'{save_dir}/Models/{path}.pt', map_location='cpu')
    model.load_state_dict(state_dict)
    model = model.to(device)

    # The history is not used here; it is still read so that a missing or
    # unreadable history file raises exactly as before.
    # NOTE: pickle can execute arbitrary code on load -- only use trusted files.
    with open(f'{save_dir}/History/{path}.his', 'rb') as fs:
        pickle.load(fs)
    print('Model Loaded')

    return model

def extract_expert_outputs(model, data_handler, sample_size=2000, device='cuda:0'):
    """Run every expert of ``model`` on the dataset's projected features.

    Args:
        model: Object exposing ``eval()`` and an iterable ``experts`` whose
            items have ``forward(feats)`` returning a tensor.
        data_handler: Object whose ``projectors`` tensor is the expert input.
        sample_size: If more than this many rows are produced in total, a
            reproducible random subset of exactly this many rows is returned.
        device: Device ``projectors`` is moved to before the forward passes.

    Returns:
        ``(outputs, labels)`` as NumPy arrays. ``outputs`` stacks the expert
        outputs row-wise, expert by expert; ``labels[r]`` is the index of the
        expert that produced row ``r``.
    """
    model.eval()
    feats = data_handler.projectors.to(device)
    chunks = []
    counts = []
    with torch.no_grad():
        for expert in model.experts:
            output = expert.forward(feats).cpu().numpy()
            chunks.append(output)
            counts.append(output.shape[0])

    if not chunks:
        # No experts at all: return empty arrays rather than failing below.
        return np.array([]), np.array([])

    # One concatenate instead of accumulating a Python list of per-row arrays.
    outputs = np.concatenate(chunks, axis=0)
    labels = np.repeat(np.arange(len(chunks)), counts)
    if len(outputs) > sample_size:
        # A private RandomState seeded with 123 draws exactly the same indices
        # as np.random.seed(123) + np.random.choice, but without resetting the
        # global NumPy RNG as a hidden side effect.
        rng = np.random.RandomState(123)
        indices = rng.choice(len(outputs), size=sample_size, replace=False)
        outputs = outputs[indices]
        labels = labels[indices]

    return outputs, labels

def plot_tsne(outputs, labels, save_path='tsne_plot.png'):
    """Embed ``outputs`` in 2-D with t-SNE and scatter them coloured by expert.

    The figure is saved to ``save_path`` and then shown.

    Args:
        outputs: 2-D array with one row per sample (e.g. the first value
            returned by ``extract_expert_outputs``).
        labels: 1-D integer array giving the expert index of each row.
        save_path: File path passed to ``plt.savefig``.
    """
    tsne = TSNE(n_components=2, random_state=0)
    tsne_results = tsne.fit_transform(outputs)

    plt.figure(figsize=(8, 8))
    scatter = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], s=10, c=labels, cmap='tab10')
    # num=None yields one handle per distinct label value (sorted); the default
    # 'auto' collapses to a handful of ticks once there are more than 9 values.
    handles, _ = scatter.legend_elements(num=None)
    # Name handles by the expert ids actually present: after subsampling some
    # experts may be missing, so the ids need not be 0..k-1.
    expert_ids = np.unique(labels)
    plt.legend(handles=handles, labels=[f'Expert {i}' for i in expert_ids])
    plt.title('T-SNE of Expert Outputs')
    plt.xlabel('T-SNE 1')
    plt.ylabel('T-SNE 2')
    plt.savefig(save_path)
    plt.show()

def plot_tsne_multiple(datasets, titles, save_path='tsne_all.pdf'):
    """Draw one t-SNE scatter panel per dataset in a single row and save as PDF.

    Args:
        datasets: Iterable of ``(outputs, labels)`` pairs (e.g. results of
            ``extract_expert_outputs``); each pair becomes one panel.
        titles: Panel titles, indexed in the same order as ``datasets``.
        save_path: Output path; the figure is always written in PDF format.
    """
    datasets = list(datasets)
    n_panels = max(len(datasets), 1)
    # 6 inches per panel: identical to the previous fixed (18, 6) for three
    # datasets, and no longer limited to exactly three panels.
    plt.figure(figsize=(6 * n_panels, 6))
    for idx, (outputs, labels) in enumerate(datasets):
        tsne = TSNE(n_components=2, random_state=0)
        tsne_results = tsne.fit_transform(outputs)

        ax = plt.subplot(1, n_panels, idx + 1)
        scatter = ax.scatter(tsne_results[:, 0], tsne_results[:, 1], s=10, c=labels, cmap='tab10')
        ax.set_title(titles[idx], fontsize=20)
        ax.tick_params(axis='both', which='major', labelsize=16)
        # Hard-coded legend placement for the 'Products-Home' panel.
        legend_loc = 'lower right' if titles[idx] == 'Products-Home' else 'best'
        # One handle per distinct label, named by the expert id it represents
        # (ids may be non-contiguous after subsampling).
        handles, _ = scatter.legend_elements(num=None)
        ax.legend(handles=handles,
                  labels=[f'Expert {i}' for i in np.unique(labels)],
                  loc=legend_loc, fontsize=12)
    plt.tight_layout()
    plt.savefig(save_path, format='pdf', bbox_inches='tight')
    plt.show()

def extract_topk_expert_outputs(model, data_handler, device='cuda:0'):
    """Run the experts returned by ``model.summon(0)`` on the dataset features.

    Args:
        model: Object exposing ``eval()`` and ``summon(idx)``; ``summon`` must
            return a 2-tuple whose first element is an iterable of experts.
        data_handler: Object whose ``projectors`` tensor is fed to each expert.
        device: Device ``projectors`` is moved to before the forward passes.

    Returns:
        List of NumPy arrays, one per summoned expert, in summon order.
    """
    model.eval()
    expert_input = data_handler.projectors.to(device)
    summoned_experts, _unused = model.summon(0)
    with torch.no_grad():
        results = [
            expert.forward(expert_input).cpu().numpy()
            for expert in summoned_experts
        ]
    return results

def compute_expert_sim():
    """Print the mean cosine similarity between two summoned experts' outputs.

    For each dataset in the list below, experts are assigned for that dataset,
    the first two experts from ``model.summon(0)`` are run on its projected
    features (presumably the top-k routed experts, since ``topk_expert`` is
    forced to 2 -- confirm), and the cosine similarity of their outputs
    (along dim 1) is averaged. The printed value is the mean of these
    per-dataset averages. Note: ``args.devices`` and ``args.topk_expert`` are
    overwritten on the shared ``args`` object.
    """
    args.devices = ['cuda:0', 'cuda:0']
    args.topk_expert = 2
    print(args)
    link_datasets = [
        'products_tech', 'yelp2018', 'yelp_textfeat', 'products_home', 'steam_textfeat', 'amazon_textfeat',
        'amazon-book', 'citation-2019', 'citation-classic', 'pubmed', 'citeseer', 'ppa', 'p2p-Gnutella06',
        'soc-Epinions1', 'email-Enron',
    ]
    model = load_model('your_model_dir', 'cuda:0')

    per_dataset_sims = []
    for dataset_name in link_datasets:
        handlers = [DataHandler(dataset_name)]
        model.assign_experts(handlers, reca=False)
        expert_outs = extract_topk_expert_outputs(model, handlers[0], device=args.devices[0])
        first = torch.from_numpy(expert_outs[0])
        second = torch.from_numpy(expert_outs[1])
        similarity = torch.nn.functional.cosine_similarity(first, second)
        per_dataset_sims.append(similarity.mean().item())

    print(sum(per_dataset_sims) / len(link_datasets))

if __name__ == '__main__':
    # Device configuration for this run; print the effective settings.
    args.devices = ['cuda:0', 'cuda:0']
    print(args)
    model = load_model('your_model_dir', 'cuda:0')

    # (dataset key, panel title) for each t-SNE panel, left to right.
    panels = [
        ('cora', 'Cora'),
        ('products_home', 'Products-Home'),
        ('proteins_spec1', 'Proteins-1'),
    ]
    # Construct every handler before any extraction so side effects happen in
    # the same order as building them one by one up front.
    handlers = [DataHandler(key) for key, _ in panels]
    datasets = [extract_expert_outputs(model, handler, sample_size=2500) for handler in handlers]
    titles = [title for _, title in panels]

    plot_tsne_multiple(datasets, titles, save_path='tsne_cora_home_protein.pdf')




    # experts = model.experts
    # cosine_sim = 0
    # l2_dist = 0
    # for name, param in experts[0].named_parameters():
    #     # if 'weight' in name:
    #     print(name, param.shape)

    # for i in range(len(experts)):
    #     for j in range(i+1, len(experts)):
    #         for name, param in experts[i].named_parameters():
    #             if 'linear.weight' in name:
    #                 cosine_sim += torch.nn.functional.cosine_similarity(param, experts[j].state_dict()[name]).mean().item()
    #                 l2_dist += torch.dist(param, experts[j].state_dict()[name]).mean().item()
    #
    # print(cosine_sim / (len(experts) * (len(experts) - 1) / 2))
    # print(l2_dist / (len(experts) * (len(experts) - 1) / 2))
