import dhg
import torch
import numpy as np
from collections import defaultdict


class MultiExpMetric:
    """Accumulate named metric values over repeated runs and render a summary.

    Each ``update`` call appends one value per metric name. ``str()`` first
    lists, per metric, every run's deviation from that metric's mean, then
    lists each metric's mean and (population) standard deviation to 5 decimals.
    """

    def __init__(self):
        # metric name -> list of recorded values, one per ``update`` call
        self.res = defaultdict(list)

    def update(self, res):
        """Append each ``name -> value`` pair of the mapping *res* to the history."""
        for name, value in res.items():
            self.res[name].append(value)

    def __str__(self):
        # Convert each metric's history to an array once, keeping insertion order.
        arrays = {name: np.array(values) for name, values in self.res.items()}
        deviation_lines = [
            f"\t{name} -> {(arr - np.mean(arr)).tolist()}"
            for name, arr in arrays.items()
        ]
        summary_lines = [
            f"\t{name} -> {arr.mean():.5f} - {arr.std():.5f}"
            for name, arr in arrays.items()
        ]
        return '\n'.join(deviation_lines + summary_lines)


def product_split(train_mask, val_mask, test_mask, test_ind_ratio, generator=None):
    """Hold out a random fraction of the test vertices as unobserved (inductive).

    The vertices selected by ``test_mask`` are randomly permuted. The first
    ``int(num_test * test_ind_ratio)`` of them become inductive test vertices.
    The rest stay observed together with all train and validation vertices.

    Args:
        train_mask, val_mask, test_mask: boolean masks over all vertices
            (presumably 1-D and of equal length -- confirm with callers).
        test_ind_ratio: fraction in ``[0, 1]`` of test vertices to hold out.
        generator: optional CPU ``torch.Generator`` for a reproducible split.

    Returns:
        obs_idx: Python list of observed vertex ids, ordered as train ids,
            then val ids, then the remaining (transductive) test ids.
        obs_train_mask, obs_val_mask, obs_test_mask: CPU boolean masks of
            length ``len(obs_idx)`` that select positions within ``obs_idx``.
        test_ind_mask: boolean mask shaped like ``train_mask`` (same device),
            True at the held-out inductive test vertices.

    Raises:
        ValueError: if ``test_ind_ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= test_ind_ratio <= 1:
        raise ValueError(f"test_ind_ratio must be in [0, 1], got {test_ind_ratio}")

    train_idx = torch.where(train_mask)[0]
    val_idx = torch.where(val_mask)[0]
    test_idx = torch.where(test_mask)[0]

    # randperm runs on the CPU (where ``generator`` lives); move it to the
    # index device before using it to index.
    test_idx_shuffle = torch.randperm(len(test_idx), generator=generator).to(test_idx.device)
    num_ind = int(len(test_idx) * test_ind_ratio)
    test_ind_idx = test_idx[test_idx_shuffle[:num_ind]]
    test_tran_idx = test_idx[test_idx_shuffle[num_ind:]]
    # Tensor.tolist() works on any device; .numpy() would fail on CUDA tensors.
    obs_idx = torch.cat([train_idx, val_idx, test_tran_idx]).tolist()

    num_obs, num_train, num_val = len(obs_idx), len(train_idx), len(val_idx)
    test_ind_mask = torch.zeros_like(train_mask, dtype=torch.bool)
    obs_train_mask = torch.zeros(num_obs, dtype=torch.bool)
    obs_val_mask = torch.zeros(num_obs, dtype=torch.bool)
    obs_test_mask = torch.zeros(num_obs, dtype=torch.bool)

    test_ind_mask[test_ind_idx] = True
    # obs_idx is laid out as [train | val | transductive test].
    obs_train_mask[:num_train] = True
    obs_val_mask[num_train:num_train + num_val] = True
    obs_test_mask[num_train + num_val:] = True
    return obs_idx, obs_train_mask, obs_val_mask, obs_test_mask, test_ind_mask


def re_index(vec):
    """Relabel runs of equal consecutive values with consecutive ids from 0.

    A new id starts whenever an element differs from its predecessor, so
    ``[5, 5, 3, 3, 7]`` becomes ``[0, 0, 1, 1, 2]``. Only neighbours are
    compared: a value that reappears after a different one gets a fresh id
    (``[1, 2, 1]`` -> ``[0, 1, 2]``). Sort ``vec`` first if equal values
    must share one id.

    Args:
        vec: 1-D tensor.

    Returns:
        A new tensor with the same shape, dtype and device as ``vec``; ``vec``
        is not modified. An empty ``vec`` yields an empty tensor.
    """
    res = torch.zeros_like(vec)
    if len(vec) > 1:
        # Each position's id is the number of value changes seen so far.
        # Vectorised, so there is no per-element .item() host sync.
        changed = vec[1:] != vec[:-1]
        res[1:] = changed.cumsum(0)
    return res


def sub_hypergraph(hg: dhg.Hypergraph, v_idx):
    """Build the sub-hypergraph of ``hg`` induced by the vertices in ``v_idx``.

    Vertex ``v_idx[i]`` of ``hg`` becomes vertex ``i`` of the result. Every
    hyperedge of ``hg`` is restricted to the kept vertices. Restricted
    hyperedges with at least one vertex are kept with their original weight;
    empty ones are dropped.

    NOTE(review): restriction can turn distinct hyperedges into identical
    ones. How they are combined is left to ``dhg.Hypergraph``'s defaults --
    confirm that is intended.

    Args:
        hg: source hypergraph.
        v_idx: distinct vertex ids of ``hg`` to keep, as a sequence, a numpy
            array or a 1-D torch tensor.

    Returns:
        A new ``dhg.Hypergraph`` with ``len(v_idx)`` vertices.

    Raises:
        ValueError: if ``v_idx`` contains duplicate vertex ids.
    """
    # Tensors hash by identity, so membership tests on raw tensor elements
    # would never match; normalise array-likes to plain Python ints.
    if isinstance(v_idx, (torch.Tensor, np.ndarray)):
        v_idx = v_idx.tolist()
    v_map = {v: new_id for new_id, v in enumerate(v_idx)}
    if len(v_map) != len(v_idx):
        # Duplicates would yield new ids >= the vertex count passed below.
        raise ValueError("v_idx contains duplicate vertex ids")

    e_list, w_list = [], []
    for e, w in zip(*hg.e):
        new_e = [v_map[v] for v in e if v in v_map]
        if new_e:
            e_list.append(tuple(new_e))
            w_list.append(w)
    return dhg.Hypergraph(len(v_map), e_list, w_list)
