import dhg
import torch
import numpy as np
from collections import defaultdict
from dhg.data import Cora, Pubmed, Citeseer
from dhg.data import CoauthorshipCora, CoauthorshipDBLP
from dhg.data import CocitationCora, CocitationPubmed, CocitationCiteseer
from dhg.data import News20, DBLP4k, IMDB4k, Recipe100k, Recipe200k
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE


class MultiExpMetric:
    """Collect metric values across repeated experiments for a teacher and a student.

    Each call to :meth:`update` appends one value per metric name. ``str()``
    reports, per metric, the mean and (population) standard deviation of all
    values collected so far.
    """

    def __init__(self):
        # metric name -> list of values, one list per model role
        self.t = defaultdict(list)
        self.s = defaultdict(list)

    def update(self, res):
        """Append one run's results.

        Args:
            res: mapping with keys ``'t'`` (teacher) and ``'s'`` (student), each
                a mapping of metric name -> value.
        """
        for store, role in ((self.t, 't'), (self.s, 's')):
            self._update(store, res[role])

    def _update(self, data, new_res):
        """Append every ``name -> value`` pair of ``new_res`` into ``data``."""
        for name, value in new_res.items():
            data[name].append(value)

    def __str__(self):
        lines = []
        for title, store in (('Teacher:', self.t), ('Student:', self.s)):
            lines.append(title)
            for name, values in store.items():
                arr = np.array(values)
                lines.append(f"\t{name} -> {arr.mean():.5f} - {arr.std():.5f}")
        return '\n'.join(lines)

import scipy.sparse as sp
from ogb.nodeproppred import NodePropPredDataset

 
def load_data(name):
    """Load a dhg benchmark dataset together with the hyperedge list to use.

    The dataset name is printed first. Most datasets use their ``'edge_list'``
    entry. The DBLP4k variants select a single relation (``edge_by_paper`` /
    ``edge_by_term`` / ``edge_by_conf``). ``imdb_aw`` concatenates the actor
    and director hyperedges.

    Args:
        name: dataset key, e.g. ``'cora'``, ``'cc_citeseer'``, ``'dblp4k_term'``.

    Returns:
        ``(data, edge_list)``: the dataset object and the selected edges.

    Raises:
        NotImplementedError: if ``name`` is not a recognised key.
    """
    print(name)

    def _take(dataset, key):
        return dataset, dataset[key]

    def _imdb_actor_director():
        dataset = IMDB4k()
        return dataset, dataset['edge_by_actor'] + dataset['edge_by_director']

    # Lambdas defer dataset construction until the requested entry is chosen,
    # so only one dataset is ever instantiated per call.
    loaders = {
        'cora': lambda: _take(Cora(), 'edge_list'),
        'pubmed': lambda: _take(Pubmed(), 'edge_list'),
        'citeseer': lambda: _take(Citeseer(), 'edge_list'),
        'ca_cora': lambda: _take(CoauthorshipCora(), 'edge_list'),
        'coauthorship_dblp': lambda: _take(CoauthorshipDBLP(), 'edge_list'),
        'cc_cora': lambda: _take(CocitationCora(), 'edge_list'),
        'cc_citeseer': lambda: _take(CocitationCiteseer(), 'edge_list'),
        'news20': lambda: _take(News20(), 'edge_list'),
        'dblp4k_paper': lambda: _take(DBLP4k(), 'edge_by_paper'),
        'dblp4k_term': lambda: _take(DBLP4k(), 'edge_by_term'),
        'dblp4k_conf': lambda: _take(DBLP4k(), 'edge_by_conf'),
        'imdb_aw': _imdb_actor_director,
        'recipe_100k': lambda: _take(Recipe100k(), 'edge_list'),
        'recipe_200k': lambda: _take(Recipe200k(), 'edge_list'),
    }
    loader = loaders.get(name)
    if loader is None:
        raise NotImplementedError
    return loader()


def product_split(train_mask, val_mask, test_mask, test_ind_ratio):
    """Split test nodes into an inductive (unseen) and a transductive (observed) part.

    A random ``test_ind_ratio`` fraction of the test nodes (chosen with
    ``torch.randperm``) becomes inductive. The observed node set is train, then
    val, then the remaining transductive test nodes, in that order.

    Args:
        train_mask, val_mask, test_mask: node masks over the full node set.
        test_ind_ratio: fraction of test nodes held out as inductive.

    Returns:
        ``(obs_idx, obs_train_mask, obs_val_mask, obs_test_mask,
        test_ind_mask, test_tran_mask)``:

        * ``obs_idx`` is a Python list of observed node ids.
        * The ``obs_*`` masks are bool masks over positions in ``obs_idx``.
        * The ``test_*`` masks are bool masks shaped like ``train_mask``.
    """
    train_idx = train_mask.nonzero(as_tuple=True)[0]
    val_idx = val_mask.nonzero(as_tuple=True)[0]
    test_idx = test_mask.nonzero(as_tuple=True)[0]

    perm = torch.randperm(len(test_idx))
    n_ind = int(len(test_idx) * test_ind_ratio)
    ind_idx = test_idx[perm[:n_ind]]
    tran_idx = test_idx[perm[n_ind:]]

    obs_idx = torch.cat([train_idx, val_idx, tran_idx]).numpy().tolist()

    test_ind_mask = torch.zeros_like(train_mask, dtype=torch.bool)
    test_ind_mask[ind_idx] = True
    test_tran_mask = torch.zeros_like(train_mask, dtype=torch.bool)
    test_tran_mask[tran_idx] = True

    # Observed positions are laid out as [train | val | transductive test].
    n_train = len(train_idx)
    n_train_val = n_train + len(val_idx)
    position = torch.arange(len(obs_idx))
    obs_train_mask = position < n_train
    obs_val_mask = (position >= n_train) & (position < n_train_val)
    obs_test_mask = position >= n_train_val
    return obs_idx, obs_train_mask, obs_val_mask, obs_test_mask, test_ind_mask, test_tran_mask


def re_index(vec):
    """Relabel consecutive runs of equal values in a 1-D tensor as 0, 1, 2, ...

    A new label starts whenever an element differs from its predecessor, so
    ``[5, 5, 3, 3, 7]`` becomes ``[0, 0, 1, 1, 2]`` and ``[1, 2, 1]`` becomes
    ``[0, 1, 2]``. Values are not sorted or deduplicated globally. The result
    is a new tensor with ``vec``'s dtype and device; ``vec`` is not modified.
    An empty input returns an empty tensor (the previous loop raised
    IndexError on ``res[0]``).
    """
    res = vec.clone()
    if res.numel() == 0:
        return res
    res[0] = 0
    # Each position where the value changes starts a new run; the running
    # count of such boundaries is the run label (vectorised, no .item() loop).
    res[1:] = (vec[1:] != vec[:-1]).cumsum(dim=0)
    return res


def sub_hypergraph(hg: dhg.Hypergraph, v_idx):
    """Induce the sub-hypergraph of ``hg`` on the vertices listed in ``v_idx``.

    Vertex ``v_idx[i]`` is relabelled to ``i``. Each hyperedge is restricted to
    its members that appear in ``v_idx``. Hyperedges that keep at least one
    vertex are retained, together with their original weight.

    Args:
        hg: source hypergraph; ``hg.e`` is unpacked as (edge list, weight list).
        v_idx: sequence of distinct vertex ids of ``hg`` to keep. It must be
            re-iterable (e.g. a list), not a one-shot iterator.

    Returns:
        A new ``dhg.Hypergraph`` with ``len(v_idx)`` vertices.

    Raises:
        ValueError: if ``v_idx`` contains duplicates. The relabelling would be
            ambiguous and would produce vertex ids outside ``[0, num_v)``.
    """
    v_map = {v: idx for idx, v in enumerate(v_idx)}
    if len(v_map) != len(v_idx):
        raise ValueError("sub_hypergraph: v_idx must not contain duplicate vertex ids")
    e_list, w_list = [], []
    for e, w in zip(*hg.e):
        new_e = [v_map[v] for v in e if v in v_map]
        if new_e:
            e_list.append(tuple(new_e))
            w_list.append(w)
    return dhg.Hypergraph(len(v_map), e_list, w_list)


def fix_iso_v(G: dhg.Hypergraph):
    """Give every isolated vertex (degree 0) a singleton hyperedge, in place.

    Args:
        G: hypergraph to patch; it is mutated via ``G.add_hyperedges``.

    Returns:
        The same ``G`` object, for chaining.
    """
    isolated = np.nonzero(np.asarray(G.deg_v) == 0)[0]
    if isolated.size > 0:
        G.add_hyperedges([(v,) for v in isolated])
    return G


def ho_topology_score(X, G: dhg.Hypergraph):
    """Average spread of vertex features around their hyperedge centroids.

    For each hyperedge, the centroid is the mean of its members' features
    (``G.v2e(X, aggr='mean')``). The edge's spread is the mean L2 distance from
    each member to that centroid. The score is the mean spread over all
    hyperedges. A ``dhg.Graph`` is first converted to a hypergraph.
    """
    if isinstance(G, dhg.Graph):
        G = dhg.Hypergraph.from_graph(G)
    centroids = G.v2e(X, aggr='mean')

    def _edge_spread(edge):
        return np.mean([
            torch.norm(centroids[edge] - X[member], p=2).item()
            for member in G.nbr_v(edge)
        ])

    return np.mean([_edge_spread(edge) for edge in range(G.num_e)])

def _prediction_entropy(logits):
    """Shannon entropy (nats) of ``softmax(logits)`` per row, shape ``(N, 1)``.

    ``xlogy(p, p)`` defines ``0 * log 0 = 0``, so a class with zero probability
    contributes nothing instead of turning the whole row into NaN. Any NaN that
    still appears (e.g. from NaN logits) is mapped to 0.
    """
    prob = logits.softmax(dim=-1)
    entropy = -torch.xlogy(prob, prob).sum(1, keepdim=True)
    entropy[entropy.isnan()] = 0
    return entropy


def ho_topology_scores_my(model, X, G, noise_level=1.0, tau=1.0):
    """Per-hyperedge reliability: how stable prediction entropy is under input noise.

    The model is put in eval mode (and left there). It is run on ``X`` and on a
    perturbed copy ``X * (1 + N(0, 1)) * noise_level``. Per-vertex prediction
    entropies are averaged per hyperedge with ``G.v2e(..., aggr="mean")``. The
    score is ``1 - |dH_e| / max_e |dH_e|``: 1 means the edge's entropy did not
    change, 0 means it changed the most. If no edge changed at all, every score
    is 1 (previously 0/0 = NaN).

    Args:
        model: callable ``model(X, G) -> logits`` with an ``eval()`` method.
        X: node features.
        G: structure passed to ``model``; must provide ``v2e(x, aggr="mean")``.
        noise_level: multiplier applied to the perturbed features (see note).
        tau: currently unused; kept for call-site compatibility.

    Returns:
        1-D tensor with one score per hyperedge. ``squeeze(-1)`` keeps the
        ``(E,)`` shape even when ``E == 1``. This assumes ``G.v2e`` keeps the
        trailing feature dim; TODO confirm for non-dhg structures.
    """
    model.eval()
    # Inference only: skip building autograd graphs for both forward passes.
    with torch.no_grad():
        logits = model(X, G)
        # NOTE(review): noise_level scales the whole perturbed input, not just
        # the noise term; with the default 1.0 this is X * (1 + eps). If it was
        # meant to scale eps only, use X * (1 + noise_level * eps) -- confirm.
        X_noise = X.clone() * (torch.randn_like(X) + 1) * noise_level
        logits_noise = model(X_noise, G)

    entropy_e = G.v2e(_prediction_entropy(logits), aggr="mean")
    entropy_e_noise = G.v2e(_prediction_entropy(logits_noise), aggr="mean")

    delta = (entropy_e_noise - entropy_e).abs()
    if delta.numel() > 0:
        peak = delta.max()
        # Guard the all-unchanged case so it yields 1s instead of NaN.
        if peak > 0:
            delta = delta / peak
    return (1 - delta).squeeze(-1)

def hyperedge_divergence(G, teacher_probs, student_probs):
    """
    Per-hyperedge discrepancy between teacher and student predictions.

    Both probability matrices are averaged over the vertices of each hyperedge
    (``G.v2e(..., aggr='mean')``). The score for a hyperedge is the Euclidean
    (L2) distance between the teacher and student averages. This is an L2
    distance, not a KL divergence. A ``dhg.Graph`` is first converted to a
    hypergraph.

    Args:
        G: ``dhg.Hypergraph`` or ``dhg.Graph``.
        teacher_probs, student_probs: per-vertex probability tensors with the
            same shape (presumably ``(num_v, num_classes)``; confirm at caller).

    Returns:
        numpy array with one distance per row of the ``v2e`` output
        (one per hyperedge).
    """
    if isinstance(G, dhg.Graph):
        G = dhg.Hypergraph.from_graph(G)
    teacher_e = G.v2e(teacher_probs, aggr='mean')
    student_e = G.v2e(student_probs, aggr='mean')
    dist = torch.norm(teacher_e - student_e, p=2, dim=1)
    # detach() so inputs that track gradients can still be converted to numpy.
    return dist.detach().cpu().numpy()

def visualize_rho_vs_divergence(rhos, divergences, output="rho_div.png"):
    """Scatter hyperedge reliability scores against divergences and save to ``output``.

    NOTE(review): the y-axis label says "KL divergence", but
    ``hyperedge_divergence`` in this file computes an L2 distance. Relabel if
    that function's output is what gets plotted.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(rhos, divergences, alpha=0.6)
    ax.set_xlabel("Hyperedge reliability score (ρ)")
    ax.set_ylabel("KL divergence (teacher vs student)")
    ax.grid(True, linestyle="--", alpha=0.5)
    fig.savefig(output, bbox_inches="tight", pad_inches=0)
    plt.close(fig)





def pot(graph, label, m, num_iter, num_cls, temperature=4):
    """Iteratively rescaled Gibbs kernel between rows of ``graph`` and labels.

    Starts from ``P = exp(-(graph @ one_hot(label).T) / temperature)``, with
    shape ``(graph.shape[0], len(label))``. Each iteration does two things:

    1. Divide every row by ``max(row_sum, 1)``. Rows whose mass exceeds 1 are
       normalised to 1; lighter rows are left unchanged.
    2. Rescale the whole matrix so its total mass equals ``m``.

    Args:
        graph: score matrix whose last dim must equal ``num_cls``.
        label: integer class labels, one per target column.
        m: total mass after each iteration.
        num_iter: number of iterations (0 returns the raw kernel).
        num_cls: number of classes for the one-hot encoding.
        temperature: kernel temperature; default 4 matches the original
            hard-coded value.

    Returns:
        Tensor ``P`` on ``graph``'s device.
    """
    # Match graph's dtype so float64 inputs don't hit a matmul dtype mismatch.
    f_l = F.one_hot(label, num_classes=num_cls).to(graph.dtype)
    P = torch.exp(-graph @ f_l.T / temperature)
    for _ in range(num_iter):
        # clamp_min is device-agnostic (the old code built a CPU-only scalar).
        P = P / P.sum(dim=1).clamp_min(1.0).unsqueeze(1)
        P = P * m / P.sum()
    return P


def sinkhorn(graph, label, num_classes, mask_class, g1, num_iter=5, temperature=4):
    """Sinkhorn scaling of a label-cost Gibbs kernel with all-ones targets.

    The input is prepared as follows:

    * A list of row tensors is stacked into a matrix.
    * If the matrix has more than ``num_classes`` columns, the columns in
      ``mask_class`` are selected and softmax-normalised per row.

    The kernel is ``K = exp(-(graph @ one_hot(label).T) / temperature)``. It is
    scaled by ``num_iter`` alternating updates ``u = a / (K v)`` and
    ``v = b / (K^T u)``, where ``a`` and ``b`` are all-ones, with a 1e-6
    stabiliser. After the final update every column sums to ~1.

    Args:
        graph: tensor (or list of 1-D row tensors) of per-class scores.
        label: integer labels, one per output column.
        num_classes: number of classes for the one-hot encoding.
        mask_class: column indices kept when ``graph`` is wider than
            ``num_classes``.
        g1: unused; kept for call-site compatibility.
        num_iter: Sinkhorn iterations (default 5, the original hard-coded value).
        temperature: kernel temperature (default 4, the original value).

    Returns:
        ``diag(u) @ K @ diag(v)``, shape ``(num_rows, len(label))``, on
        ``K``'s device and dtype.
    """
    if isinstance(graph, list):
        graph = torch.stack(graph)
    if len(graph[0]) > num_classes:
        graph = torch.softmax(graph[:, mask_class], dim=1)
    # Match graph's dtype so float64 inputs don't hit a matmul dtype mismatch.
    f_l = F.one_hot(label, num_classes=num_classes).to(graph.dtype)
    K = torch.exp(-graph @ f_l.T / temperature)

    # Create the scaling vectors alongside K (was hard-coded to CPU/float32).
    like_k = {"device": K.device, "dtype": K.dtype}
    a = torch.ones(K.shape[0], **like_k)
    b = torch.ones(K.shape[1], **like_k)
    u = torch.ones(K.shape[0], **like_k)
    v = torch.ones(K.shape[1], **like_k)

    for _ in range(num_iter):
        u = torch.div(a, K @ v + 1e-6)
        v = torch.div(b, K.T @ u + 1e-6)
    # Broadcasting equals diag(u) @ K @ diag(v) without dense diagonal matmuls.
    return u[:, None] * K * v[None, :]
 
def draw_HeatMap(cm, labels=None, output_path="./confusion_matrix_heatmap.png"):
    """Render ``cm`` as a blue seaborn heatmap and save it to ``output_path`` at 400 dpi.

    Args:
        cm: 2-D matrix passed straight to ``sns.heatmap``. Annotation is off, so
            ``fmt='d'`` only matters if ``annot`` is switched on. It would then
            presumably need integer entries.
        labels: optional tick labels, applied to both axes.
        output_path: destination image path.
    """
    plt.figure(figsize=(10, 8))
    heat_ax = sns.heatmap(
        cm,
        annot=False,
        fmt='d',
        cmap='Blues',
        cbar=True,
        annot_kws={"size": 18},
    )
    # Enlarge the colour-bar tick labels to match the axis tick font size.
    colorbar = heat_ax.collections[0].colorbar
    colorbar.ax.tick_params(labelsize=18)

    plt.xlabel('Predicted Labels', fontsize=24, labelpad=24)
    plt.ylabel('True Labels', fontsize=24, labelpad=15)
    plt.title('Confusion Matrix Heatmap', fontsize=26, pad=20)

    if labels is None:
        plt.xticks(fontsize=18, rotation=45)
        plt.yticks(fontsize=18)
    else:
        # Ticks offset by 0.5, presumably to sit at the centre of each heatmap cell.
        centres = np.arange(len(labels)) + 0.5
        plt.xticks(ticks=centres, labels=labels, fontsize=18, rotation=45, ha="right")
        plt.yticks(ticks=centres, labels=labels, fontsize=18, rotation=0)

    plt.tight_layout()
    plt.savefig(output_path, dpi=400)
    plt.close()


def visualize(h, color, output_path="./TSNE.png"):
    """Embed ``h`` in 2-D with t-SNE and save a scatter plot coloured by ``color``.

    ``h`` is detached and moved to CPU before t-SNE. Axis ticks are hidden and
    the figure is saved with no padding.
    """
    features = h.detach().cpu().numpy()
    coords = TSNE(n_components=2).fit_transform(features)
    xs, ys = coords[:, 0], coords[:, 1]

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.scatter(xs, ys, s=70, c=color, cmap="Set2")
    fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)

def append_number(filename: str, number: int):
    """Append ``str(number)`` followed by a newline to ``filename``, creating it if missing."""
    with open(filename, mode="a") as handle:
        print(number, file=handle)

from typing import Optional, Literal, Tuple


