from typing import Literal, Optional, Tuple, Dict

import numpy as np
import torch


def simulate_random_dag(d: int, degree: int, rng: np.random.Generator) -> np.ndarray:
    """
    Sample the adjacency matrix of a random Directed Acyclic Graph (DAG).

    Edges are first drawn independently in the strictly upper triangle of a
    (d, d) matrix, which is acyclic by construction, and the node labels are
    then shuffled with a random permutation so the result is not trivially
    upper-triangular.

    Args:
        d (int): Number of nodes (variables) in the graph.
        degree (int): Target average degree; each candidate edge is kept with
            probability degree / max(1, d - 1).
        rng (np.random.Generator): Numpy random number generator for reproducibility.

    Returns:
        np.ndarray: float64 matrix of shape (d, d) whose entries are 0.0 or 1.0.
    """
    edge_prob = float(degree) / max(1, d - 1)

    # One uniform draw per ordered pair; only the strict upper triangle survives.
    draws = rng.uniform(size=(d, d))
    upper = np.triu((draws < edge_prob).astype(np.float32), 1)

    # Relabel node i as relabel[i]: entry (i, j) moves to (relabel[i], relabel[j]).
    relabel = rng.permutation(d)
    dag = np.zeros((d, d), dtype=np.float64)
    dag[np.ix_(relabel, relabel)] = upper
    return dag


def _signed_uniform(rng: np.random.Generator, shape: tuple) -> np.ndarray:
    """Draw float32 weights with magnitude in [0.5, 2.0) and a fair random sign."""
    magnitudes = rng.uniform(low=0.5, high=2.0, size=shape).astype(np.float32)
    signs = np.where(rng.uniform(size=shape) < 0.5, -1.0, 1.0).astype(np.float32)
    return magnitudes * signs


def nonlinear_parent_map(x_par: np.ndarray, sem_type: Literal['mlp', 'mim'], rng: np.random.Generator) -> np.ndarray:
    """
    Evaluate a freshly sampled nonlinear function of the parent variables.

    The weights of the function are drawn from ``rng`` on every call, so two
    calls with the same inputs but a different generator state give different
    mappings.

    Args:
        x_par (np.ndarray): Array of shape (n, k) with n samples and k parent variables.
        sem_type (Literal['mlp', 'mim']): 'mlp' applies a one-hidden-layer
            sigmoid network with 64 hidden units; 'mim' sums tanh, cos and sin
            of three independent random linear projections of the parents.
        rng (np.random.Generator): Numpy random number generator for reproducibility.

    Returns:
        np.ndarray: float32 array of shape (n,). All zeros when there are no
        parents (k == 0); in that case ``sem_type`` is not inspected.

    Raises:
        ValueError: If k > 0 and ``sem_type`` is neither 'mlp' nor 'mim'.
    """
    n_samples, n_parents = x_par.shape

    # A node without parents contributes nothing besides its own noise.
    if n_parents == 0:
        return np.zeros(n_samples, dtype=np.float32)

    if sem_type not in ('mlp', 'mim'):
        raise ValueError('Unknown sem_type')

    if sem_type == 'mlp':
        n_hidden = 64
        w_in = _signed_uniform(rng, (n_parents, n_hidden))
        w_out = _signed_uniform(rng, (n_hidden,))
        pre_activation = x_par @ w_in
        activation = 1.0 / (1.0 + np.exp(-pre_activation))  # logistic sigmoid
        return (activation @ w_out).astype(np.float32)

    # 'mim': three projections, drawn in tanh / cos / sin order.
    w_tanh, w_cos, w_sin = (_signed_uniform(rng, (n_parents,)) for _ in range(3))
    mixed = np.tanh(x_par @ w_tanh) + np.cos(x_par @ w_cos) + np.sin(x_par @ w_sin)
    return mixed.astype(np.float32)


def _topological_order(adjacency: np.ndarray) -> list:
    """
    Return node indices ordered so that every parent precedes its children.

    ``adjacency[i, j] != 0`` is read as an edge i -> j. Among the nodes that
    are ready (all parents already placed), the one ranked first by an
    in-degree argsort is chosen, so whenever the plain in-degree ordering is
    already a valid topological order this function returns exactly that
    ordering.

    Raises:
        ValueError: If the graph contains a cycle.
    """
    n_nodes = adjacency.shape[0]
    is_edge = adjacency != 0
    # Tie-break ranking: ascending in-degree.
    preferred = np.argsort(adjacency.sum(axis=0))
    pending_parents = is_edge.sum(axis=0)
    placed = np.zeros(n_nodes, dtype=bool)

    order = []
    for _ in range(n_nodes):
        ready = next(
            (j for j in preferred if not placed[j] and pending_parents[j] == 0),
            None,
        )
        if ready is None:
            raise ValueError('Adjacency matrix contains a cycle')
        placed[ready] = True
        order.append(int(ready))
        # Every child of `ready` now has one fewer unplaced parent.
        pending_parents[is_edge[ready]] -= 1
    return order


def generate_sem_with_domains(
    n_per_domain: int,
    d: int,
    degree: int,
    *,
    sem_type: Literal['linear', 'mlp', 'mim'] = 'mlp',
    noise_type: Literal['gaussian'] = 'gaussian',
    domain_scales: Optional[np.ndarray] = None,
    seed: Optional[int] = 1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, Dict, Dict]:
    """
    Generate synthetic data from a Structural Equation Model (SEM) with multiple domains for causal discovery.

    A single random DAG and a single vector of per-variable base noise
    standard deviations are drawn; each domain then multiplies the base noise
    standard deviations by its own scale. Variables are generated
    parents-first, each as a function of its parents plus additive noise.
    For the nonlinear SEM types the parent function is re-sampled for every
    node in every domain (see ``nonlinear_parent_map``).

    Args:
        n_per_domain (int): Number of samples per domain.
        d (int): Number of variables (nodes).
        degree (int): Expected average degree of the DAG.
        sem_type (Literal['linear', 'mlp', 'mim'], optional): Type of SEM ('linear', 'mlp', or 'mim'). Default is 'mlp'.
            'linear' sums the parents with unit weights.
        noise_type (Literal['gaussian'], optional): Type of noise to use. Default is 'gaussian'.
        domain_scales (Optional[np.ndarray], optional): 1-D array of noise scales, one per domain; its length
            determines the number of domains. If None, defaults to [1.0, 3.0, 0.5].
        seed (Optional[int], optional): Random seed for reproducibility. Default is 1. When not None it also seeds
            the global ``np.random`` and ``torch`` generators.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, Dict, Dict]:
            - X (np.ndarray): Data of shape (num_domains, n_per_domain, d).
            - U (np.ndarray): One-hot domain indicators of shape (num_domains, n_per_domain, num_domains).
            - B_bin (np.ndarray): Adjacency matrix of the DAG, shape (d, d); B_bin[i, j] != 0 means i is a parent of j.
            - N_eps (np.ndarray): Additive noise of shape (num_domains, n_per_domain, d).
            - meta (Dict): Dictionary with metadata about the generated SEM and domains.
            - info (Dict): Maps each domain index to {"mu": 0, "Sigma": diagonal noise covariance of that domain}.

    Raises:
        ValueError: If ``sem_type`` or ``noise_type`` is not one of the supported values.
    """
    # Validate up front so a bad argument fails regardless of graph shape.
    if sem_type not in ('linear', 'mlp', 'mim'):
        raise ValueError('Unknown sem_type')
    if noise_type != 'gaussian':
        raise ValueError('Unsupported noise_type')

    if seed is not None:
        rng = np.random.default_rng(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
    else:
        rng = np.random.default_rng()

    B_bin = simulate_random_dag(d, degree, rng)
    if domain_scales is None:
        domain_scales = np.array([1.0, 3.0, 0.5], dtype=np.float32)
    else:
        domain_scales = np.atleast_1d(np.asarray(domain_scales))
    num_domains = len(domain_scales)

    # Sorting by in-degree alone is not a topological order (a child with one
    # parent can precede a parent that itself has several parents), so use a
    # real topological order to guarantee parents are generated first.
    order = _topological_order(B_bin)

    # NOTE(review): the base noise std is drawn from the global np.random state
    # (seeded above), not from `rng`. Kept as-is so seeded outputs stay stable.
    std = np.random.uniform(low=.5, high=2, size=d)

    X_all, U_all, N_all = [], [], []
    info = dict()
    for e_id, noise_scale in enumerate(domain_scales):
        X = np.zeros((n_per_domain, d), dtype=np.float32)
        noise = np.zeros((n_per_domain, d), dtype=np.float32)
        for j in order:
            parents = np.where(B_bin[:, j] != 0)[0]
            # Standard deviation (not variance) of this node's noise in this domain.
            noise_std = std[j] * noise_scale
            eps = rng.normal(loc=0.0, scale=noise_std, size=n_per_domain).astype(np.float32)
            if sem_type == 'linear':
                X[:, j] = (X[:, parents] @ (B_bin[parents, j] * 1.0)).astype(np.float32) + eps
            else:
                X[:, j] = nonlinear_parent_map(X[:, parents], sem_type, rng) + eps
            noise[:, j] = eps
        X_all.append(X)

        u = np.zeros((n_per_domain, num_domains), dtype=np.float32)
        u[:, e_id] = 1.0
        U_all.append(u)
        N_all.append(noise)
        info[e_id] = {"mu": 0, "Sigma": np.diag((std * noise_scale) ** 2)}  # store true params (oracle)

    X = np.array(X_all)
    U = np.array(U_all)
    N_eps = np.array(N_all)

    meta = {'d': d, 'degree': degree, 'sem_type': sem_type, 'noise_type': noise_type, 'domain_scales': domain_scales.tolist(), 'seed': seed}
    return X, U, B_bin, N_eps, meta, info


