import json
from typing import List
import random
from datetime import datetime
import pandas as pd
import numpy as np
import string
import networkx as nx
import torch
from enum import Enum
from tqdm import tqdm
from rt.synthetic.config import SCMParams


def set_random_seed(seed: int):
    """Seed Python's, NumPy's and PyTorch's global RNGs and make cuDNN deterministic.

    Args:
        seed: value handed to every seeding function below.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        # Seed the current CUDA device, then all visible CUDA devices.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Prefer reproducible cuDNN kernel selection over autotuned speed.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def assign_cluster_at_levels(num_nodes: int, hierarchy: List):
    """Assign every node a cluster id at each level of a nested hierarchy.

    Nodes are split into ``prod(hierarchy)`` contiguous "base" clusters of
    ``ceil(num_nodes / prod(hierarchy))`` nodes each (the last one takes the
    remainder). A base cluster index is then decomposed in mixed radix with
    digit sizes ``hierarchy`` to obtain its id at every level (coarsest first).

    Args:
        num_nodes: number of nodes to label.
        hierarchy: number of clusters per level, coarsest level first.

    Returns:
        Integer array of shape ``(num_nodes, len(hierarchy))``.
    """
    num_levels = len(hierarchy)
    num_base_clusters = np.prod(hierarchy)
    chunk = int(np.ceil(num_nodes / num_base_clusters))
    # Exclusive end row of each base cluster; the final one is clamped to num_nodes.
    ends = [chunk * (k + 1) for k in range(num_base_clusters - 1)]
    ends.append(num_nodes)
    # Mixed-radix place value of each level = product of all finer level sizes.
    place_values = [
        np.prod(hierarchy[lvl + 1 :]) if lvl + 1 < num_levels else 1
        for lvl in range(num_levels)
    ]
    labels = np.zeros((num_nodes, num_levels), dtype=int)
    start = 0
    for k, end in enumerate(ends):
        for lvl, (size, place) in enumerate(zip(hierarchy, place_values)):
            labels[start:end, lvl] = (k // place) % size
        start = end
    return labels


def get_probs_at_levels(hierarchy_a: List, hierarchy_b: List):
    """Draw one cluster-to-cluster connection probability matrix per level.

    Each level's matrix has shape ``(hierarchy_a[l], hierarchy_b[l])``. It is filled
    with small uniform values in ``[0.001, 0.002)``, and then a wrapped diagonal
    (entries ``(i % rows, i % cols)`` for ``i < max(rows, cols)``) is set to 0.9.
    This lets cluster ``i`` on one side strongly prefer cluster ``i`` on the other.

    Args:
        hierarchy_a: clusters per level on side ``a``.
        hierarchy_b: clusters per level on side ``b``; same number of levels.

    Returns:
        List of numpy arrays, one per level.
    """
    assert len(hierarchy_a) == len(
        hierarchy_b
    ), "only similar hierarchy levels are supported"

    level_matrices = []
    for rows, cols in zip(hierarchy_a, hierarchy_b):
        matrix = np.random.uniform(0.001, 0.002, size=(rows, cols))
        for k in range(max(rows, cols)):
            matrix[k % rows, k % cols] = 0.9
        level_matrices.append(matrix)
    return level_matrices


def get_nodes_connect_prob(
    node_idx_a: int,
    node_idx_b: int,
    probs_at_levels: List,
    cluster_at_levels_a: List,
    cluster_at_levels_b: List,
):
    num_levels = len(probs_at_levels)
    probs = [
        probs_at_levels[l_idx][
            cluster_at_levels_a[node_idx_a, l_idx],
            cluster_at_levels_b[node_idx_b, l_idx],
        ]
        for l_idx in range(num_levels)
    ]
    return np.prod(probs)


def get_bipartite_hsbm(size_a: int, size_b: int, hierarchy_a: List, hierarchy_b: List):
    """Sample a directed bipartite graph from a hierarchical stochastic block model.

    Every ``b`` node receives exactly one incoming edge from an ``a`` node. The
    parent is drawn with probability proportional to the product, over
    hierarchy levels, of the cluster-to-cluster probabilities produced by
    ``get_probs_at_levels`` (the same weight ``get_nodes_connect_prob`` computes).

    Args:
        size_a: number of ``a`` nodes, named ``a0 .. a{size_a-1}``.
        size_b: number of ``b`` nodes, named ``b0 .. b{size_b-1}``.
        hierarchy_a: clusters per level for side ``a``.
        hierarchy_b: clusters per level for side ``b``; must have as many
            levels as ``hierarchy_a``.

    Returns:
        ``nx.DiGraph`` with edges ``a -> b``. Every node carries ``node_idx``
        (its index within its side) and ``hierarchy`` (its cluster id per level).
    """
    assert len(hierarchy_a) == len(
        hierarchy_b
    ), "only similar hierarchy levels are supported"

    cluster_at_levels_a = assign_cluster_at_levels(
        num_nodes=size_a, hierarchy=hierarchy_a
    )
    cluster_at_levels_b = assign_cluster_at_levels(
        num_nodes=size_b, hierarchy=hierarchy_b
    )
    probs_at_levels = get_probs_at_levels(
        hierarchy_a=hierarchy_a, hierarchy_b=hierarchy_b
    )

    bi_hsbm = nx.DiGraph()

    nodes_a = [f"a{i}" for i in range(size_a)]
    nodes_b = [f"b{j}" for j in range(size_b)]

    for a_idx, a_node in enumerate(nodes_a):
        bi_hsbm.add_node(
            a_node,
            node_idx=a_idx,
            hierarchy=list(cluster_at_levels_a[a_idx]),
        )
    for b_idx, b_node in enumerate(nodes_b):
        bi_hsbm.add_node(
            b_node,
            node_idx=b_idx,
            hierarchy=list(cluster_at_levels_b[b_idx]),
        )

    for b_idx, b_node in tqdm(
        enumerate(nodes_b), total=size_b, desc="adding edges in bi_hsbm"
    ):
        # Connection weight of this `b` node to every `a` node at once: the
        # per-level product from get_nodes_connect_prob, vectorised over `a`.
        probs = np.ones(size_a)
        for l_idx, level_probs in enumerate(probs_at_levels):
            probs = probs * level_probs[
                cluster_at_levels_a[:, l_idx], cluster_at_levels_b[b_idx, l_idx]
            ]
        # numpy division by zero yields NaN instead of raising, so check the
        # total explicitly; degenerate weights fall back to uniform sampling.
        total = probs.sum()
        p = probs / total if np.isfinite(total) and total > 0 else None
        a_idx = np.random.choice(range(size_a), p=p)
        bi_hsbm.add_edge(nodes_a[a_idx], b_node)

    assert nx.is_bipartite(bi_hsbm)
    return bi_hsbm


def get_bipartite_pl(size_a: int, size_b: int, exponent: float):
    """Sample a directed bipartite graph with skewed (power-style) parent popularity.

    ``a`` nodes are given a random rank ``r`` in ``0 .. size_a-1``. Each ``b`` node
    receives one incoming edge from an ``a`` node chosen with probability
    proportional to ``1 - (r / size_a) ** exponent``. A positive ``exponent`` is
    expected. If the weights are degenerate (their sum is not a positive finite
    number, e.g. ``exponent == 0``), the parent is chosen uniformly.

    Args:
        size_a: number of ``a`` nodes, named ``a0 .. a{size_a-1}``, ``bipartite=0``.
        size_b: number of ``b`` nodes, named ``b0 .. b{size_b-1}``, ``bipartite=1``.
        exponent: shape parameter of the attachment weights.

    Returns:
        ``nx.DiGraph`` with edges ``a -> b``. Every node carries ``node_idx`` and ``bipartite``.
    """
    bi_pl = nx.DiGraph()

    nodes_a = [f"a{i}" for i in range(size_a)]
    nodes_b = [f"b{j}" for j in range(size_b)]

    for a_idx, a_node in enumerate(nodes_a):
        bi_pl.add_node(a_node, node_idx=a_idx, bipartite=0)
    for b_idx, b_node in enumerate(nodes_b):
        bi_pl.add_node(b_node, node_idx=b_idx, bipartite=1)

    shuffled_a_indxs = np.arange(size_a)
    np.random.shuffle(shuffled_a_indxs)

    # The weights do not depend on the `b` node, so compute and normalise them
    # once. `np.power` (unlike the `np.pow` alias) also works on NumPy < 2.0.
    weights = 1 - np.power(shuffled_a_indxs / size_a, exponent)
    total = weights.sum()
    # Division by zero in numpy yields NaN rather than raising, so check explicitly.
    probs = weights / total if np.isfinite(total) and total > 0 else None

    for b_node in tqdm(nodes_b, desc="adding edges in bi_pl"):
        a_idx = np.random.choice(range(size_a), p=probs)
        bi_pl.add_edge(nodes_a[a_idx], b_node)

    assert nx.is_bipartite(bi_pl)

    return bi_pl


class Snapshot:
    """Key/value store of captured values that renders itself as indented JSON.

    Values (and dict keys) are converted to JSON-compatible types by
    ``_serialize`` when the snapshot is turned into a string. Nested snapshots
    render as nested objects.
    """

    def __init__(self):
        self.data = {}

    def capture(self, k, v):
        """Store ``v`` under ``k``, replacing any earlier value."""
        self.data[k] = v

    def get(self, k):
        """Return the value stored under ``k``; raises ``KeyError`` if absent."""
        return self.data[k]

    def _serialize(self, value):
        """Recursively convert ``value`` into objects ``json.dumps`` accepts.

        Handles snapshots, dicts (keys included), datetimes/pandas timestamps
        (ISO strings), enums (their value), numpy arrays and scalars (Python
        equivalents), and lists/tuples. Anything else is returned unchanged.
        """
        if isinstance(value, Snapshot):
            return self._serialize(value.data)
        if isinstance(value, dict):
            return {self._serialize(k): self._serialize(v) for k, v in value.items()}
        if isinstance(value, (datetime, pd.Timestamp)):
            return value.isoformat()
        if isinstance(value, Enum):
            return self._serialize(value.value)
        if isinstance(value, np.ndarray):
            return self._serialize(value.tolist())
        if isinstance(value, np.generic):
            item = value.item()
            # Recurse so e.g. datetime64 -> datetime -> ISO string, but never
            # on another numpy scalar (that would recurse forever).
            return item if isinstance(item, np.generic) else self._serialize(item)
        if isinstance(value, (list, tuple)):
            return [self._serialize(v) for v in value]
        return value

    def __str__(self):
        return json.dumps(self._serialize(self), indent=2)


class TableType(Enum):
    """Kind of table; each member's value is its lowercase string tag."""

    Entity = "entity"
    Activity = "activity"


class MLP:
    """Randomly initialised, untrained MLP built from plain weight tensors.

    There are no biases, and the final layer is linear. Each weight matrix is
    allocated with ``torch.empty`` and handed to an initializer sampled from
    ``scm_params.initialization_choices``.
    """

    def __init__(
        self,
        scm_params: SCMParams,
        in_dim: int,
        hid_dim: int,
        out_dim: int,
        num_layers: int = 2,
    ):
        assert num_layers >= 1
        self.scm_params = scm_params
        # Widths: input, (num_layers - 1) hidden layers, output.
        widths = [in_dim, *([hid_dim] * (num_layers - 1)), out_dim]
        self.weights = [
            torch.empty(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ]
        for weight in self.weights:
            # One initializer per layer. Its return value is ignored, so it is
            # expected to fill the tensor in place.
            initializer = scm_params.initialization_choices.sample_uniform()
            initializer(weight)

    def __call__(self, x):
        """Apply the network to ``x``.

        ``x`` is right-multiplied by each weight, so its last dimension must be
        ``in_dim``.
        NOTE(review): an activation is re-sampled for every hidden layer on
        every call, so repeated calls may use different nonlinearities. Confirm
        this is intended.
        """
        *hidden, last = self.weights
        for weight in hidden:
            activation = self.scm_params.activation_choices.sample_uniform()
            x = activation(x @ weight)
        return x @ last


class CategoricalEncoder:
    """Map category indices to vectors: frozen embedding lookup followed by a random MLP."""

    def __init__(
        self,
        scm_params: SCMParams,
        num_embeddings: int,
        embedding_dim: int,
        num_layers: int = 2,
    ):
        self.E = torch.nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            _freeze=True,
        )
        # Overwrite the default table with a sampled initialisation scheme.
        scm_params.initialization_choices.sample_uniform()(self.E.weight)
        # Square MLP: every layer keeps the embedding width.
        width = embedding_dim
        self.mlp = MLP(
            scm_params=scm_params,
            in_dim=width,
            hid_dim=width,
            out_dim=width,
            num_layers=num_layers,
        )

    def __call__(self, x: torch.LongTensor):
        """Embed the index tensor ``x`` and transform it with the MLP."""
        embedded = self.E(x)
        return self.mlp(embedded)


class CategoricalDecoder:
    """Map vectors back to category indices.

    The input goes through a random MLP and is then assigned to the row of a
    frozen embedding table with the largest dot product.
    """

    def __init__(
        self,
        scm_params: SCMParams,
        num_embeddings: int,
        embedding_dim: int,
        num_layers: int = 2,
    ):
        self.E = torch.nn.Embedding(
            num_embeddings=num_embeddings, embedding_dim=embedding_dim, _freeze=True
        )
        init_fn = scm_params.initialization_choices.sample_uniform()
        init_fn(self.E.weight)
        self.mlp = MLP(
            scm_params=scm_params,
            in_dim=embedding_dim,
            hid_dim=embedding_dim,
            out_dim=embedding_dim,
            num_layers=num_layers,
        )

    def __call__(self, x: torch.Tensor):
        """Return the best-matching category index for every vector in ``x``.

        Args:
            x: tensor whose last dimension is ``embedding_dim``, e.g.
                ``(embedding_dim,)`` or ``(batch, embedding_dim)``.

        Returns:
            Index tensor of shape ``x.shape[:-1]``.
        """
        x = self.mlp(x)
        # Similarity of each input vector to each embedding row: (..., num_embeddings).
        # Transposing the 2-D weight instead of `x` avoids `x.T`, which torch
        # deprecates for non-2-D tensors. `x.T` also reverses *all* dims, which
        # breaks inputs with several leading dimensions.
        sims = x @ self.E.weight.T
        return torch.argmax(sims, dim=-1)
