import copy
from collections import Counter
from typing import List

import torch.nn as nn
import networkx as nx
import numpy as np
import scipy.sparse as sp
import wandb
from torchmetrics import MeanMetric, MaxMetric, Metric, MeanAbsoluteError
import torch
from torch import Tensor

from ConStruct.utils import PlaceHolder
from ConStruct.metrics.metrics_utils import (
    counter_to_tensor,
    wasserstein1d,
    total_variation1d,
)
from ConStruct.metrics.spectre_utils import SpectreSamplingMetrics, is_planar_graph
from ConStruct.datasets.tls_dataset import CellGraph
from ConStruct.analysis.dist_helper import compute_mmd, emd, gaussian_tv

from z3 import Optimize, Bool, Int, And, Or, Not, If, Sum, PbEq, sat, PbLe, PbGe, Abs
import planarity
import random

def is_planar_fast(G):
    """Test planarity of ``G`` with the ``planarity`` package.

    ``planarity.is_planar`` accepts a NetworkX graph.

    Returns:
        ``(True, None)`` when ``G`` is planar, otherwise ``(False, kur)``
        where ``kur`` is the result of ``planarity.kuratowski_subgraph(G)``
        (a NetworkX subgraph witnessing non-planarity).
    """
    if planarity.is_planar(G):
        return True, None
    # Not planar: hand back the Kuratowski subgraph as the witness.
    return False, planarity.kuratowski_subgraph(G)

def smt_planarize(G, gamma, phenotype_attr='phenotype'):
    """Search for a planar edge set on ``G`` in which some node has degree >= ``gamma``.

    If ``G`` is already planar and its maximum degree is at least ``gamma`` it
    is returned unchanged (the same object). Otherwise a z3 ``Optimize``
    problem is built in one of two modes:

    * delete mode (max degree >= ``gamma``): one Boolean "keep" variable per
      existing edge, each with a soft constraint preferring to keep the edge.
    * add mode (max degree < ``gamma``): one "keep" variable per unordered
      pair of distinct nodes whose ``phenotype_attr`` is ``'B'``; the soft
      constraint prefers the pair's current status in ``G`` (existing edges
      stay, missing edges stay missing).

    A hard constraint requires at least one node to have ``gamma`` or more
    selected incident edges. Planarity is enforced lazily: each non-planar
    candidate yields a Kuratowski subgraph, and a clause is added forcing at
    least one of its controllable edges to be dropped.

    Args:
        G: NetworkX-style graph. In add mode every node must carry
            ``phenotype_attr``.
        gamma: Minimum degree that at least one node must reach.
        phenotype_attr: Node attribute holding the phenotype label.

    Returns:
        ``G`` itself when it already qualifies; otherwise a deep copy of ``G``
        with the selected edge set; or ``None`` when the solver does not
        report ``sat`` or the search gives up.
    """
    from collections import defaultdict

    # default=0 keeps a node-less graph from raising in max().
    max_degree = max((degree for _, degree in G.degree()), default=0)
    if nx.check_planarity(G)[0] and max_degree >= gamma:
        return G

    opt = Optimize()
    opt.set('timeout', 1000 * 1)  # z3 "timeout" parameter
    edge_vars = {}
    add_edge = max_degree < gamma

    if add_edge:
        for u in G.nodes():
            for v in G.nodes():
                if u == v:
                    continue
                e = tuple(sorted((u, v)))
                if e in edge_vars:
                    # Pair already handled in the opposite orientation; adding
                    # it again would only duplicate the soft constraint.
                    continue
                node_type_u = G.nodes[u][phenotype_attr]
                node_type_v = G.nodes[v][phenotype_attr]
                if node_type_u == 'B' and node_type_v == 'B':
                    edge_vars[e] = Bool(f"keep_{e[0]}_{e[1]}")
                    if G.has_edge(u, v):
                        opt.add_soft(edge_vars[e])
                    else:
                        opt.add_soft(Not(edge_vars[e]))
    else:
        for u, v in G.edges():
            e = tuple(sorted((u, v)))
            edge_vars[e] = Bool(f"keep_{e[0]}_{e[1]}")
            opt.add_soft(edge_vars[e])

    # Degree requirement: at least one node keeps >= gamma incident edges.
    incident_map = defaultdict(list)
    for (u, v), b in edge_vars.items():
        incident_map[u].append(b)
        incident_map[v].append(b)
    maxdeg_matches = [
        If(Sum([If(b, 1, 0) for b in incident]) >= gamma, 1, 0)
        for incident in incident_map.values()
    ]
    opt.add(Sum(maxdeg_matches) >= 1)

    edges = list(edge_vars)
    count = 0
    while True:
        if opt.check() != sat:
            return None
        m = opt.model()
        # model_completion=True guarantees a concrete True/False even for a
        # variable the solver left out of the model.
        keep_set = {
            e for e, var in edge_vars.items()
            if m.evaluate(var, model_completion=True)
        }
        delete_set = set(edges) - keep_set
        H = copy.deepcopy(G)
        if add_edge:
            H.add_edges_from(set(edges))
        H.remove_edges_from(delete_set)
        is_planar, kur = is_planar_fast(H)
        if is_planar:
            return H
        if count > 200:
            if not add_edge:
                # Greedy fallback: re-insert edges one at a time and keep only
                # those that preserve planarity.
                H = copy.deepcopy(G)
                H.clear_edges()
                for e in edges:
                    H.add_edge(*e)
                    if not nx.check_planarity(H)[0]:
                        H.remove_edge(*e)
                if max((degree for _, degree in H.degree()), default=0) >= gamma:
                    return H
            return None
        count += 1
        kur_edges = {tuple(sorted(e)) for e in kur.edges()}
        # Only edges that have a decision variable can be dropped by the
        # solver; fixed edges of G (possible in add mode) are skipped.
        removable = [(Not(edge_vars[e]), 1) for e in kur_edges if e in edge_vars]
        if not removable:
            # The obstruction consists solely of fixed edges: unsolvable.
            return None
        opt.add(PbGe(removable, 1))

def smt_planarize_by_k(G, original_G):
    """Planarize ``G`` by deleting as few non-original edges as the search can find.

    Edges of ``G`` that also exist in ``original_G`` are forced to stay. A
    greedy pass first builds a planar baseline (original edges, then every
    other edge that keeps the graph planar). The number of edges the greedy
    pass had to drop gives an initial bound ``k``; the solver is then asked
    repeatedly for a deletion set of size <= ``k - 1``, with planarity
    enforced lazily through Kuratowski-subgraph clauses, tightening ``k``
    after every success.

    Args:
        G: Graph to planarize.
        original_G: Graph over the same node set whose edges must be kept.
            The subgraph of ``G`` made of these edges is asserted to be planar.

    Returns:
        ``G`` itself when it is already planar, otherwise a deep copy of ``G``
        with the best deletion set found removed (the greedy set if the
        solver never improves on it).
    """
    if is_planar_graph(G):
        return G
    opt = Optimize()
    opt.set('timeout', 1000 * 1)  # z3 "timeout" parameter
    assert set(G.nodes()) == set(original_G.nodes())
    edge_vars = {}
    kept_edges = []
    for u, v in G.edges():
        e = tuple(sorted((u, v)))
        edge_vars[e] = Bool(f"keep_{e[0]}_{e[1]}")
        if original_G.has_edge(u, v):
            opt.add(edge_vars[e])
            # Store the normalized tuple so the set arithmetic against
            # `edges` below is exact regardless of the orientation yielded
            # by G.edges().
            kept_edges.append(e)

    edges = list(edge_vars)
    H = nx.Graph()
    H.add_nodes_from(G.nodes())
    H.add_edges_from(kept_edges)
    assert nx.check_planarity(H)[0]

    # Greedy baseline: add each remaining edge if planarity survives.
    for e in list(set(edges) - set(kept_edges)):
        H.add_edge(*e)
        if nx.check_planarity(H)[0]:
            kept_edges.append(e)
        else:
            H.remove_edge(*e)
    initial_k = len(edges) - len(kept_edges)
    best_solution = set(edges) - set(kept_edges)

    current_k = initial_k - 1
    while current_k >= 0:
        # Budget: at most current_k edges may be deleted.
        opt.add(PbLe([(Not(edge_vars[e]), 1) for e in edges], current_k))
        added_kur_constraints = set()
        count = 0
        no_solution = False
        while True:
            if opt.check() != sat:
                no_solution = True
                break
            m = opt.model()
            # model_completion=True guarantees a concrete True/False even for
            # a variable the solver left out of the model.
            keep_set = {
                e for e, var in edge_vars.items()
                if m.evaluate(var, model_completion=True)
            }
            delete_set = set(edges) - keep_set
            H = copy.deepcopy(G)
            H.remove_edges_from(delete_set)
            is_planar, kur = is_planar_fast(H)
            if is_planar:
                best_solution = delete_set
                current_k = len(delete_set) - 1
                break
            if count > 200:
                # Refinement budget exhausted for this k.
                no_solution = True
                break
            count += 1
            kur_edges = frozenset(tuple(sorted(e)) for e in kur.edges())
            if kur_edges not in added_kur_constraints:
                # At least one edge of this obstruction must be deleted.
                opt.add(PbGe([(Not(edge_vars[e]), 1) for e in kur_edges], 1))
                added_kur_constraints.add(kur_edges)
        if no_solution:
            break
    G_planar = copy.deepcopy(G)
    G_planar.remove_edges_from(best_solution)

    return G_planar

from typing import Dict, Tuple, List
def compute_b_gamma_indices(G: nx.Graph, phenotype_attr: str = "phenotype", a_max: int = 5) -> Dict:
    """Map every ``"B"`` node to its number of ``"B"`` neighbours, capped at ``a_max``.

    Args:
        G: Graph whose nodes may carry ``phenotype_attr``.
        phenotype_attr: Node attribute holding the phenotype label.
        a_max: Upper cap applied to the neighbour count.

    Returns:
        Dict ``node -> min(count of neighbours labelled "B", a_max)``,
        containing only nodes that are themselves labelled ``"B"``.
    """
    def is_b(node):
        return G.nodes[node].get(phenotype_attr) == "B"

    capped_counts = {}
    for node, attrs in G.nodes(data=True):
        if attrs.get(phenotype_attr) != "B":
            continue
        # Count adjacent nodes whose phenotype is "B", then apply the cap.
        capped_counts[node] = min(sum(map(is_b, G.neighbors(node))), a_max)
    return capped_counts

def smt_tls(G: CellGraph, original_G: CellGraph, tls_type, a_max=5, phenotype_attr="phenotype", add_edge=False,
            k_0=None, k_1=None, k_2=None, k_3=None, k_4=None):
    """Select B-T edges so the graph meets a TLS ratio constraint and stays planar.

    Decision variables (one Boolean "keep" per candidate edge):

    * ``add_edge=False``: every edge of ``G`` whose endpoints have different
      phenotypes. Such an edge is asserted to be absent from ``original_G``;
      same-phenotype edges are asserted to be present in it. Soft constraints
      prefer keeping each candidate.
    * ``add_edge=True``: every unordered B-T node pair (asserted absent from
      ``original_G``). Soft constraints prefer the pair's current status in
      ``G``.

    Each candidate edge is bucketed by the capped B-neighbour count of its B
    endpoint (see ``compute_b_gamma_indices``); ``Gamma_a`` counts the kept
    edges in bucket ``a`` and ``Gamma`` is their total. With
    ``tail(j) = Gamma - (Gamma_0 + ... + Gamma_j)``:

    * ``"high_tls"`` requires ``tail(2) / Gamma > 0.05``
    * ``"low_tls"`` requires ``tail(1) / Gamma < 0.05`` or ``Gamma == 0``

    Each supplied ``k_i`` adds an objective minimizing
    ``|tail(i) * SCALE - int(k_i * SCALE) * Gamma|``. Planarity is enforced
    lazily through Kuratowski-subgraph clauses.

    Args:
        G: Graph holding the candidate edges.
        original_G: Graph over the same node set holding the fixed edges.
        tls_type: ``"high_tls"`` or ``"low_tls"``.
        a_max: Cap for the B-neighbour bucket index.
        phenotype_attr: Node attribute holding the phenotype label.
        add_edge: Whether missing B-T pairs may be added.
        k_0, k_1, k_2, k_3, k_4: Optional target ratios.

    Returns:
        A deep copy of ``G`` with the chosen edge set, or ``None`` when no
        planar solution is found.

    Raises:
        ValueError: If ``tls_type`` is not recognised, or a candidate edge
            does not join a ``"B"`` node to a ``"T"`` node.
    """
    SCALE = 10000  # ratios are compared as integers after scaling

    opt = Optimize()
    opt.set('timeout', 1000 * 1)  # z3 "timeout" parameter
    assert set(G.nodes()) == set(original_G.nodes())
    edge_vars = {}
    candidate_edges = []
    if not add_edge:
        for u, v in G.edges():
            e = tuple(sorted((u, v)))
            node_type_u = G.nodes[u][phenotype_attr]
            node_type_v = G.nodes[v][phenotype_attr]
            if node_type_u != node_type_v:
                edge_vars[e] = Bool(f"keep_{e[0]}_{e[1]}")
                assert not original_G.has_edge(u, v)
                candidate_edges.append(e)
                opt.add_soft(edge_vars[e])
            else:
                assert original_G.has_edge(u, v)
    else:
        for u in G.nodes():
            for v in G.nodes():
                e = tuple(sorted((u, v)))
                if e in edge_vars:
                    # Pair already handled in the opposite orientation; adding
                    # it again would only duplicate the soft constraint.
                    continue
                node_type_u = G.nodes[u][phenotype_attr]
                node_type_v = G.nodes[v][phenotype_attr]
                if (node_type_u == 'B' and node_type_v == 'T') or (node_type_u == 'T' and node_type_v == 'B'):
                    edge_vars[e] = Bool(f"keep_{e[0]}_{e[1]}")
                    assert not original_G.has_edge(u, v)
                    candidate_edges.append(e)
                    if G.has_edge(u, v):
                        opt.add_soft(edge_vars[e])
                    else:
                        opt.add_soft(Not(edge_vars[e]))
    edges = list(edge_vars)

    # Bucket every candidate edge by the capped B-neighbour count of its B endpoint.
    b_gamma = compute_b_gamma_indices(G, phenotype_attr=phenotype_attr, a_max=a_max)
    edge_to_gamma = {}
    for u, v in candidate_edges:
        pu = G.nodes[u].get(phenotype_attr)
        pv = G.nodes[v].get(phenotype_attr)
        if pu == "B" and pv == "T":
            b, t = u, v
        elif pv == "B" and pu == "T":
            b, t = v, u
        else:
            raise ValueError('impossible condition')
        assert b in b_gamma
        edge_to_gamma[tuple(sorted((b, t)))] = b_gamma.get(b)

    # Gamma_a = number of kept candidate edges in bucket a.
    Gamma_a = {}
    for a in range(a_max + 1):
        members = [If(edge_vars[e], 1, 0) for e, bucket in edge_to_gamma.items() if bucket == a]
        Ga = Int(f"Gamma_{a}")
        opt.add(Ga == (Sum(members) if members else 0))
        Gamma_a[a] = Ga
    Gamma = Int("Gamma_total")
    opt.add(Gamma == Sum([Gamma_a[a] for a in range(a_max + 1)]))
    # Non-negativity only: this does NOT force a gamma edge to exist. For
    # high_tls the strict ratio inequality below already rules out Gamma == 0.
    opt.add(Gamma >= 0)

    def prefix_sum(last):
        # Gamma_0 + ... + Gamma_last
        return Sum([Gamma_a[i] for i in range(last + 1)])

    if tls_type == "high_tls":
        k2_threshold = 0.05
        # (Gamma - (G0+G1+G2)) / Gamma > k2_threshold, cross-multiplied and
        # scaled to integers.
        Gamma_other = Int("Gamma_other")
        opt.add(Gamma_other == Gamma - prefix_sum(2))
        opt.add(Gamma_other * SCALE > int(k2_threshold * SCALE) * Gamma)
    elif tls_type == "low_tls":
        k1_threshold = 0.05
        # (Gamma - (G0+G1)) / Gamma < k1_threshold, or no gamma edge at all.
        Gamma_other = Int("Gamma_other")
        opt.add(Gamma_other == Gamma - prefix_sum(1))
        opt.add(Or(Gamma_other * SCALE < int(k1_threshold * SCALE) * Gamma, Gamma == 0))
    else:
        raise ValueError(f"Invalid tls_type : {tls_type} (TLS Sampling Metrics)")

    # Optional objectives: pull each ratio tail(i)/Gamma towards its target k_i.
    for idx, target in enumerate((k_0, k_1, k_2, k_3, k_4)):
        if target is None:
            continue
        target = float(target)
        suffix = "".join(str(i) for i in range(idx + 1))  # "0", "01", "012", ...
        tail = Int(f"Gamma_other{suffix}")
        opt.add(tail == Gamma - prefix_sum(idx))
        opt.minimize(Abs(tail * SCALE - int(target * SCALE) * Gamma))

    count = 0
    while True:
        if opt.check() != sat:
            if not add_edge:
                # Retry once, this time allowing missing B-T pairs to be added.
                # NOTE(review): the retry does not forward k_0..k_4, so it runs
                # without the ratio objectives - confirm this is intended.
                return smt_tls(G, original_G, tls_type, a_max, phenotype_attr, add_edge=True)
            return None
        m = opt.model()
        # model_completion=True guarantees a concrete True/False even for a
        # variable the solver left out of the model.
        keep_set = {
            e for e, var in edge_vars.items()
            if m.evaluate(var, model_completion=True)
        }
        H = copy.deepcopy(G)
        if add_edge:
            H.add_edges_from(set(edges))
        H.remove_edges_from(set(edges) - keep_set)
        is_planar, kur = is_planar_fast(H)
        if is_planar:
            return H
        if count > 100:
            # Refinement budget exhausted.
            return None
        count += 1
        kur_edges = frozenset(tuple(sorted(e)) for e in kur.edges())
        # Only candidate edges can be dropped; fixed edges have no variable.
        opt.add(PbGe([(Not(edge_vars[e]), 1) for e in kur_edges if e in edge_vars], 1))

def refine_tls(cell_graph : CellGraph, tls_type, a_max=5):
    """Try to turn ``cell_graph`` into a planar graph satisfying the ``tls_type`` criterion.

    Works on a deep copy in three stages:

    1. B-B edges only, planarized with ``smt_planarize`` (``gamma`` is 3 for
       ``"high_tls"`` and 0 otherwise).
    2. Edges whose TLS type starts with ``"gamma"`` are added back and
       selected with ``smt_tls``, using the input graph's ``k_0..k_4``
       features as targets.
    3. All remaining edges are added back and pruned with
       ``smt_planarize_by_k``.

    For ``"high_tls"`` with no gamma edge at all, each ``'B'`` node is first
    relabelled ``'T'`` when ``random.random() > 1/3``.

    Returns:
        The (copied) input graph if it already satisfies the criterion and is
        planar, the refined graph on success, or ``None`` when a stage fails.

    Raises:
        ValueError: If ``tls_type`` is neither ``"high_tls"`` nor ``"low_tls"``.
    """
    def meets_target(features):
        if tls_type == "high_tls":
            return 0.05 < features["k_2"]
        if tls_type == "low_tls":
            return features["k_1"] < 0.05
        raise ValueError(f"Invalid tls_type : {tls_type} (TLS Metrics)")

    def annotate_edge_types(graph):
        # Store each edge's TLS classification under the 'type' edge attribute.
        for edge in graph.edges():
            label = graph.classify_TLS_edge(edge)
            src, dst = edge
            graph[src][dst].update({'type': label})

    def is_gamma(graph, a, b):
        return graph.get_edge_data(a, b)['type'].startswith("gamma")

    def is_bb(graph, a, b):
        return graph.nodes[a]['phenotype'] == 'B' and graph.nodes[b]['phenotype'] == 'B'

    gamma = 3 if tls_type == "high_tls" else 0

    G = copy.deepcopy(cell_graph)
    tls_features = G.compute_tls_features(verbose=False)
    if meets_target(tls_features) and is_planar_graph(G):
        return G
    annotate_edge_types(G)

    has_gamma_edge = any(is_gamma(G, u, v) for u, v in G.edges())
    if tls_type == 'high_tls' and not has_gamma_edge:
        # No B-T interface yet: randomly relabel B nodes, then reclassify edges.
        for node in G.nodes():
            if G.nodes[node]['phenotype'] == 'B' and random.random() > 1/3:
                G.nodes[node]["phenotype"] = 'T'
        annotate_edge_types(G)

    # Stage 1: B-B backbone.
    H = copy.deepcopy(G)
    H.clear_edges()
    H.add_edges_from((u, v) for u, v in G.edges() if is_bb(G, u, v))

    H_temp = smt_planarize(H, gamma=gamma)
    if H_temp is None:
        return None

    # Stage 2: add gamma edges back and let the solver pick which survive.
    K = copy.deepcopy(H_temp)
    assert nx.check_planarity(K)[0]
    K.add_edges_from((u, v) for u, v in G.edges() if is_gamma(G, u, v))

    K_temp = smt_tls(
        K,
        H_temp,
        tls_type,
        a_max=a_max,
        k_0=tls_features['k_0'],
        k_1=tls_features['k_1'],
        k_2=tls_features['k_2'],
        k_3=tls_features['k_3'],
        k_4=tls_features['k_4'],
    )
    if K_temp is None:
        return None

    tls_features_new = K_temp.compute_tls_features(verbose=False)
    assert meets_target(tls_features_new) and nx.check_planarity(K_temp)[0]

    # Stage 3: add every remaining (non-gamma, non-B-B) edge and prune.
    T = copy.deepcopy(K_temp)
    T.add_edges_from(
        (u, v)
        for u, v in G.edges()
        if not is_gamma(G, u, v) and not is_bb(G, u, v)
    )

    T_temp = smt_planarize_by_k(T, K_temp)
    if T_temp is None:
        return None

    if tls_type == 'high_tls':
        assert T_temp.has_high_TLS() and is_planar_graph(T_temp)
    elif tls_type == 'low_tls':
        assert T_temp.has_low_TLS() and is_planar_graph(T_temp)
    else:
        raise ValueError(f"Invalid tls_type : {tls_type} (TLS Metrics)")
    return T_temp

class TLSSamplingMetrics(SpectreSamplingMetrics):
    """Sampling metrics for TLS cell graphs, layered on ``SpectreSamplingMetrics``.

    On top of the parent's metrics this adds a TLS validity ratio, per-feature
    TLS statistics against the validation set, and novelty/uniqueness
    fractions computed with the cell-graph isomorphism test.
    """

    def __init__(self, train_dataloader, val_dataloader, tls_type):
        """Convert both loaders to cell graphs and pick the validity predicate.

        Args:
            train_dataloader: Loader whose batches expose ``to_data_list()``.
            val_dataloader: Loader whose batches expose ``to_data_list()``.
            tls_type: ``"high_tls"`` or ``"low_tls"``.

        Raises:
            ValueError: If ``tls_type`` is not recognised.
        """
        super().__init__(
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            compute_emd=False,
            metrics_list=[
                "degree",
                "clustering",
                "orbit",
                "spectre",
                "wavelet",
                "planar",
            ],
        )
        self.train_cell_graphs = self.loader_to_cell_graphs(train_dataloader)
        self.val_cell_graphs = self.loader_to_cell_graphs(val_dataloader)

        self.cell_graph_valid_fn = self.is_cell_graph_valid(tls_type=tls_type)

        self.mean_tls_validity = MeanMetric()
        self.tls_type = tls_type
        self.a_max = 5

    def is_cell_graph_valid(self, tls_type):
        """Return a predicate: graph has the requested TLS level and is planar.

        Raises:
            ValueError: If ``tls_type`` is not recognised.
        """
        if tls_type == "high_tls":
            return lambda cg: cg.has_high_TLS() and is_planar_graph(cg)
        elif tls_type == "low_tls":
            return lambda cg: cg.has_low_TLS() and is_planar_graph(cg)
        else:
            raise ValueError(f"Invalid tls_type : {tls_type} (TLS Sampling Metrics)")

    def loader_to_cell_graphs(self, loader):
        """Flatten every batch of ``loader`` into a list of ``CellGraph`` objects."""
        return [
            CellGraph.from_torch_geometric(tg_graph)
            for batch in loader
            for tg_graph in batch.to_data_list()
        ]

    def forward(self, generated_graphs: list, current_epoch, local_rank):
        """Compute parent metrics plus TLS metrics for ``generated_graphs``.

        Args:
            generated_graphs: Batches exposing ``.X`` and ``.split()``; each
                split element is converted with ``CellGraph.from_placeholder``.
            current_epoch: Forwarded to the parent ``forward``.
            local_rank: Process rank; progress printing happens on rank 0 only.

        Returns:
            The parent's log dict extended with ``tls_metrics/*`` entries.
        """
        to_log = super().forward(generated_graphs, current_epoch, local_rank)
        if local_rank == 0:
            num_generated = sum(
                placeholder.X.shape[0] for placeholder in generated_graphs
            )
            print(
                f"Computing TLS sampling metrics between {num_generated} generated graphs and {len(self.val_graphs)}"
            )

        generated_cell_graphs = [
            CellGraph.from_placeholder(placeholder)
            for batch in generated_graphs
            for placeholder in batch.split()
        ]

        # NOTE: post-hoc refinement is currently disabled. When enabled, each
        # graph would be passed through
        # refine_tls(cell_graph, tls_type=self.tls_type, a_max=self.a_max),
        # falling back to the raw graph whenever refinement returns None.

        # TLS features
        if local_rank == 0:
            print("Computing TLS features stats...")

        device = self.mean_tls_validity.device
        self.mean_tls_validity(
            tls_validity_ratio(generated_cell_graphs, self.cell_graph_valid_fn).to(
                device
            )
        )
        to_log["tls_metrics/mean_tls_validity"] = (
            self.mean_tls_validity.compute().item()
        )
        tls_stats = compute_tls_stats(
            generated_cell_graphs,
            self.val_cell_graphs,
            bins=100,
            compute_emd=self.compute_emd,
        )
        for key, value in tls_stats.items():
            to_log[f"tls_metrics/{key}"] = value
            if wandb.run:
                wandb.run.summary[f"tls_metrics/{key}"] = value

        # Isomorphic vs unique?
        # Only the progress message is rank-gated. The fractions are computed
        # on every rank, because the update below reads them unconditionally
        # and would otherwise hit unbound locals on ranks other than 0.
        if local_rank == 0:
            print("Computing uniqueness and isomorphic for cell graphs...")
        frac_novel = eval_fraction_novel_cell_graphs(
            generated_cell_graphs=generated_cell_graphs,
            train_cell_graphs=self.train_cell_graphs,
        )
        (
            frac_unique,
            frac_unique_and_novel,
            frac_unique_and_novel_valid,
        ) = eval_fraction_unique_novel_valid_cell_graphs(
            generated_cell_graphs=generated_cell_graphs,
            train_cell_graphs=self.train_cell_graphs,
            valid_cg_fn=self.cell_graph_valid_fn,
        )
        to_log.update(
            {
                "tls_metrics/frac_novel": frac_novel,
                "tls_metrics/frac_unique": frac_unique,
                "tls_metrics/frac_unique_and_novel": frac_unique_and_novel,
                "tls_metrics/frac_unique_and_novel_valid": frac_unique_and_novel_valid,
            }
        )

        if local_rank == 0:
            tls_sampling_metrics_log = {
                metric: value
                for metric, value in to_log.items()
                if "tls_metrics" in metric
            }
            print(f"TLS sampling statistics: {tls_sampling_metrics_log}")

        return to_log

    def reset(self):
        """Reset the TLS validity accumulator, then the parent's state."""
        self.mean_tls_validity.reset()
        super().reset()


def tls_validity_ratio(generated_graphs: List[PlaceHolder], valid_cg_fn):
    """Return a tensor of 0/1 flags, one per graph, from ``valid_cg_fn``.

    Despite the name, no ratio is computed here: each element is
    ``int(valid_cg_fn(graph))`` and any averaging is left to the caller.
    """
    flags = list(map(int, map(valid_cg_fn, generated_graphs)))
    return torch.tensor(flags)


# specific for cell graphs (isomorphism function is of cell graphs)
def eval_fraction_novel_cell_graphs(generated_cell_graphs, train_cell_graphs):
    """Return the fraction of generated cell graphs not isomorphic to any training graph.

    ``nx.faster_could_be_isomorphic`` is used as a cheap pre-filter before the
    cell-graph ``is_isomorphic`` test; the scan over training graphs stops at
    the first match.
    """
    def seen_in_training(gen_cg):
        return any(
            nx.faster_could_be_isomorphic(train_cg, gen_cg)
            and gen_cg.is_isomorphic(train_cg)
            for train_cg in train_cell_graphs
        )

    num_known = sum(1 for gen_cg in generated_cell_graphs if seen_in_training(gen_cg))
    return 1 - num_known / len(generated_cell_graphs)


# specific for cell graphs (isomorphism function is of cell graphs)
def eval_fraction_unique_novel_valid_cell_graphs(
    generated_cell_graphs,
    train_cell_graphs,
    valid_cg_fn,
):
    """Return nested fractions: unique, unique-and-novel, unique-and-novel-and-valid.

    A generated graph is non-unique when it is isomorphic to an earlier
    generated graph. Novelty is checked only for unique graphs, and validity
    (``valid_cg_fn``) only for unique and novel ones. Isomorphism uses the
    cell-graph ``is_isomorphic`` (which also considers node phenotypes),
    pre-filtered by ``nx.faster_could_be_isomorphic``.

    Returns:
        Tuple ``(frac_unique, frac_unique_novel, frac_unique_novel_valid)``.
    """
    def same_graph(reference, candidate):
        return nx.faster_could_be_isomorphic(
            reference, candidate
        ) and candidate.is_isomorphic(reference)

    duplicates = 0
    known_from_training = 0
    invalid = 0
    for position, candidate in enumerate(generated_cell_graphs):
        earlier = generated_cell_graphs[:position]
        if any(same_graph(previous, candidate) for previous in earlier):
            duplicates += 1
            continue
        if any(same_graph(train_cg, candidate) for train_cg in train_cell_graphs):
            known_from_training += 1
            continue
        if not valid_cg_fn(candidate):
            invalid += 1

    total = len(generated_cell_graphs)
    frac_unique = 1 - duplicates / total
    frac_unique_non_isomorphic = frac_unique - known_from_training / total
    frac_unique_non_isomorphic_valid = frac_unique_non_isomorphic - invalid / total

    return (
        frac_unique,
        frac_unique_non_isomorphic,
        frac_unique_non_isomorphic_valid,
    )


def compute_tls_stats(generated_cell_graphs, val_cell_graphs, bins, compute_emd):
    """Compare TLS features of generated and validation cell graphs, key by key.

    Args:
        generated_cell_graphs (list): CellGraphs produced by the model.
        val_cell_graphs (list): Reference CellGraphs.
        bins: Forwarded to ``cell_graphs_to_TLS_features_hists``.
        compute_emd: If truthy, ``compute_mmd`` uses the ``emd`` kernel,
            otherwise ``gaussian_tv``.

    Returns:
        dict: For every key of the generated graphs' per-key features, the
        ``compute_mmd`` result between the validation entry and the generated
        entry (each wrapped as a single-element sample).
    """
    generated_by_key = cell_graphs_to_TLS_features_hists(generated_cell_graphs, bins)
    val_by_key = cell_graphs_to_TLS_features_hists(val_cell_graphs, bins)

    kernel = emd if compute_emd else gaussian_tv
    return {
        key: compute_mmd([val_by_key[key]], [generated_entry], kernel=kernel)
        for key, generated_entry in generated_by_key.items()
    }


def cell_graphs_to_TLS_features_hists(cell_graphs: List[CellGraph], bins):
    """Build one histogram per TLS feature over a collection of cell graphs.

    Args:
        cell_graphs: Objects exposing ``compute_tls_features()``, which
            returns a dict of feature name to value. The keys of the first
            graph's dict define the features that are histogrammed.
        bins: Number of bins (or bin specification) passed to ``np.histogram``.

    Returns:
        Dict mapping each feature name to the count histogram of its values
        over ``cell_graphs``, computed on the range ``(0, 1)``. Empty when
        ``cell_graphs`` is empty.
    """
    if not cell_graphs:
        # Nothing to histogram; avoids indexing the first element below.
        return {}

    tls_features_list = [cell_graph.compute_tls_features() for cell_graph in cell_graphs]

    tls_hists = {}
    for key in tls_features_list[0]:
        # Gather this feature across all graphs, then bin it.
        values_list = [tls_features[key] for tls_features in tls_features_list]
        tls_hists[key], _ = np.histogram(
            values_list, bins=bins, range=(0, 1), density=False
        )

    # Return the histograms themselves (previously they were computed and
    # then discarded in favour of the raw grouped values).
    return tls_hists
