import torch
from typing import List, Tuple, Dict
import os
from ..data.dataset import GraphFlowSamplesDataset


# ---------- exact Gillespie sim on grid (for testing only; not for large M) ----------
@torch.no_grad()
def gillespie_on_grid(
    Q_seq: List[torch.Tensor],     # list length T of (n,n) generators
    t_grid: torch.Tensor,          # (T+1,), increasing times
    x0: torch.Tensor,              # (M,), initial states in {0,...,n-1}
    seed: int | None = None,
) -> torch.Tensor:
    """
    Returns states_at_grid: (T+1, M) with states at each grid time.
    Q_seq[k] is frozen on [t_k, t_{k+1}). Works on CPU or CUDA.
    """
    assert len(Q_seq) == t_grid.numel() - 1, "Need one Q per interval"
    device = Q_seq[0].device
    x0 = x0.to(device=device, dtype=torch.long)
    T = len(Q_seq)
    n = Q_seq[0].shape[0]
    M = x0.numel()

    states_at_grid = torch.empty((T+1, M), dtype=torch.long, device=device)
    states_at_grid[0] = x0.clone()

    gen = torch.Generator(device=device)
    if seed is not None:
        gen.manual_seed(seed)

    for m in range(M):                      # loop particles (simple & clear)
        x = int(x0[m].item())
        for k in range(T):                  # loop time intervals
            t = float(t_grid[k].item())
            t_end = float(t_grid[k+1].item())
            Q = Q_seq[k]                    # freeze generator on this interval

            # Precompute off-diagonal row sums if diag not set exactly
            # (but prefer supplying Q with rows summing to 0)
            while t < t_end:
                # total rate out of current state
                lam = (-Q[x, x]).item()
                if lam <= 0.0:
                    # no jump possible in this interval
                    break

                # waiting time ~ Exp(lam) via inverse-CDF
                u = torch.rand((), generator=gen, device=device).item()
                dt = -torch.log(torch.tensor(1.0 - u)).item() / lam
                if t + dt >= t_end:
                    # no jump before grid boundary; carry state forward
                    break

                # pick destination proportional to off-diagonal rates
                row = Q[x].clone()
                row[x] = 0.0
                row = row.clamp_min(0)
                s = row.sum().item()
                if s <= 0.0:
                    # numerically no valid destination; treat as no jump
                    break
                probs = (row / s)
                j = torch.multinomial(probs, 1, replacement=True, generator=gen).item()

                # perform jump
                t += dt
                x = int(j)

            states_at_grid[k+1, m] = x

    return states_at_grid

# ---------- helpers for particle sim & eval ----------
@torch.no_grad()
def sample_initial_states(rho0: torch.Tensor, n_particles: int, generator: torch.Generator | None = None) -> torch.Tensor:
    """
    Draw initial particle states from initial histogram rho0 (length n).
    Returns: states long tensor of shape (n_particles,)
    """
    rho0 = rho0 / rho0.sum()
    # torch.multinomial supports batched draws; here we draw n_particles from a 1D probs vector. :contentReference[oaicite:1]{index=1}
    idx = torch.multinomial(rho0, n_particles, replacement=True, generator=generator)
    return idx.to(torch.long)

@torch.no_grad()
def tau_leap_step_batch(
    states: torch.Tensor, dt: float,
    K: torch.Tensor, pi: torch.Tensor, rho: torch.Tensor,
    V: torch.Tensor, beta: float,
    eps: float = 1e-16, generator: torch.Generator | None = None
) -> torch.Tensor:
    """
    One vectorized tau-leap step for a batch of particles (at most one jump each).

    Rates come from ``build_rates(K, rho, V, beta, eps=eps)``. That helper returns a
    generator whose diagonal is minus the row sum. Here the diagonal is zeroed and
    negative entries are clamped, leaving only the off-diagonal jump rates q_{i->j}.

    A particle in state i jumps with probability 1 - exp(-lam_i * dt), where
    lam_i = sum_j q_{i->j}. If it jumps, it moves to j with probability q_{i->j} / lam_i.

    Args:
        states: (M,) long tensor of current states.
        dt: step size.
        K, rho, V, beta, eps: forwarded to ``build_rates``.
        pi: accepted for backward compatibility; not used (build_rates takes no pi).
        generator: optional torch.Generator used for both random draws.

    Returns:
        (M,) long tensor of updated states. It is a new tensor if any particle jumped,
        otherwise the input tensor itself.
    """
    # Off-diagonal jump rates q_{i->j} >= 0
    Q = build_rates(K, rho, V, beta, eps=eps)              # (n,n), rows sum to 0
    jump_rates = Q.clone()
    jump_rates.fill_diagonal_(0.0)
    jump_rates = jump_rates.clamp_min(0)

    rates = jump_rates.index_select(0, states)             # (M,n)
    lam = rates.sum(dim=1)                                 # (M,)

    # Jump / no-jump
    p_jump = 1.0 - torch.exp(-lam * dt)                    # (M,)
    jump_draw = torch.bernoulli(p_jump, generator=generator).to(torch.bool)
    jump_mask = jump_draw & (lam > 0)

    if not jump_mask.any():
        return states

    idx = jump_mask.nonzero(as_tuple=False).squeeze(1)     # indices of jumpers
    rates_sub = rates.index_select(0, idx)                 # (m,n)
    lam_sub = lam.index_select(0, idx).unsqueeze(1)        # (m,1), > 0 by construction

    probs_sub = rates_sub / lam_sub                        # rows sum to 1

    # Safety net: non-finite rows (e.g. infinite rates) fall back to uniform.
    bad = ~torch.isfinite(probs_sub).all(dim=1, keepdim=True)
    if bad.any():
        probs_sub[bad.expand_as(probs_sub)] = 1.0 / probs_sub.shape[1]

    dest_sub = torch.multinomial(probs_sub, 1, replacement=True,
                                 generator=generator).squeeze(1)  # (m,)

    # Apply only to jumpers
    states = states.clone()
    states[idx] = dest_sub
    return states

@torch.no_grad()
def simulate_particles_tau_leap(
    K: torch.Tensor, 
    pi: torch.Tensor, 
    V: torch.Tensor, 
    beta: float,
    T,
    n,           # (T+1, n) density snapshots from your density simulator
    t_grid: torch.Tensor,
    p0,# (T,) times, increasing
    n_particles: int,
    seed: int | None = None,
    eps: float = 1e-16,
    device: str = "cpu",
    dtype = torch.float64,
    rho_seq = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Simulate M particles in lockstep over t_grid using tau-leaping driven by rho_seq[t].
    Returns:
      states_over_time: (T, M) particle states at each time (t=0..T)
      emp_rho_seq:      (T, n) empirical histograms at each time (normalized)
    """

    gen = torch.Generator(device=device)
    if seed is not None:
        gen.manual_seed(seed)

    # t=0: sample from rho_0
    states = sample_initial_states(p0, n_particles, generator=gen)   # (M,)
    states_over_time = torch.empty((T, n_particles), dtype=dtype, device=device)
    emp_rho_seq = torch.empty((T, n), device=device)

    # record t=0
    states_over_time[0] = states
    emp_rho_seq[0] = torch.bincount(states, minlength=n).to(dtype) / n_particles

    # loop over time steps (lockstep tau-leap using provided rho_k)
    for k in range(T-1):
        dt = float(t_grid[k+1] - t_grid[k])
        rho_k = rho_seq[k] * pi           
        # rho_k = emp_rho_seq[k]           # only current density is used for the step
        states = tau_leap_step_batch(states, dt, K, pi, rho_k, V, beta, eps=eps, generator=gen)

        # record at t_{k+1}
        states_over_time[k+1] = states
        emp_rho_seq[k+1] = torch.bincount(states, minlength=n).to(dtype) / n_particles

    return states_over_time, emp_rho_seq

@torch.no_grad()
def compare_empirical_to_density(
    emp_rho_seq: torch.Tensor, rho_seq: torch.Tensor, eps: float = 1e-12
) -> Dict[str, torch.Tensor]:
    """
    Per-time discrepancy between empirical histograms and reference densities.

    Both inputs are clamped from below at ``eps``, which keeps the logs in the KL term
    finite. They are then compared row by row (one row per time).

    Returns:
        dict with one value per row:
          'TV':           0.5 * sum_j |emp - true|
          'KL_true||emp': sum_j true * (log true - log emp)
    """
    p_emp, p_true = emp_rho_seq.clamp_min(eps), rho_seq.clamp_min(eps)
    total_variation = 0.5 * torch.sum(torch.abs(p_emp - p_true), dim=1)
    log_ratio = torch.log(p_true) - torch.log(p_emp)
    kl = torch.sum(p_true * log_ratio, dim=1)
    return {"TV": total_variation, "KL_true||emp": kl}


# ---------- rate construction and grid-based (matrix-exponential) simulation ----------
def logmean(a, b, eps=1e-16, tol=1e-16):
    """
    Logarithmic mean θ(a,b) = (a-b)/(log a - log b), with θ(a,a) = a.

    Inputs are clamped from below at ``eps`` before use and broadcast against each other.
    Where |log a - log b| is not greater than ``tol`` (including NaN), ``a`` is returned
    instead of the ratio.

    The denominator is replaced by 1 at those positions *before* dividing. torch.where
    back-propagates through both branches, so a 0/0 in the discarded branch would
    otherwise turn the gradient into NaN even though the forward value is fine.

    Args:
        a, b: tensors (broadcastable).
        eps: lower clamp applied to a and b.
        tol: threshold on |log a - log b| below which θ falls back to a.

    Returns:
        Tensor of log-means with the broadcast shape of a and b.
    """
    a = a.clamp_min(eps)
    b = b.clamp_min(eps)
    num = a - b
    den = a.log() - b.log()
    use_ratio = den.abs() > tol
    safe_den = torch.where(use_ratio, den, torch.ones_like(den))
    return torch.where(use_ratio, num / safe_den, a)

def build_rates(K, rho, V, beta, eps=1e-16):
    """
    Build the rate generator Q from a density/probability vector rho.

    With ψ = V + β log ρ and θ the logarithmic mean (see logmean):
        Q[i, j] = K[i, j] * θ(ρ_i, ρ_j) / ρ_i * [ψ_i - ψ_j]_+     (i != j)
        Q[i, i] = -Σ_{j != i} Q[i, j]
    so Q[i, j] is the rate from i to j and every row sums to 0.

    ρ is clamped from below at ``eps`` wherever it is used (log, log-mean, and the
    division by ρ_i). States with zero mass would otherwise give inf/NaN rows, and
    matrix_exp would spread those to the whole transition matrix.

    Inputs: K (n,n), rho (n,), V (n,), beta (scalar).
    Returns: Q (n,n) generator: off-diagonal rates, diagonal = minus the row sum.
    """
    rho_c = rho.clamp_min(eps)

    # log-mean θ(ρ_i, ρ_j)
    rho_i = rho.unsqueeze(1)            # (n,1)
    rho_j = rho.unsqueeze(0)            # (1,n)
    theta = logmean(rho_i, rho_j, eps)  # (n,n), clamps internally

    # potentials ψ = V + β log ρ
    psi = V + beta * torch.log(rho_c)

    # drive [ψ_i - ψ_j]_+ : mass flows from higher to lower potential
    drive = psi.unsqueeze(1) - psi.unsqueeze(0)  # (i,j): ψ_i - ψ_j
    pos = torch.relu(drive)

    w = K * theta / rho_c.unsqueeze(1)

    # rates q_{i->j}; the (i,i) term is w[i,i] * 0 = 0
    Q = w * pos
    Q = Q - torch.diag(Q.sum(dim=1))  # set diagonal so rows sum to 0
    return Q

def build_rates_from_network(K, gradV, beta, rho, eps=1e-16):
    """
    Build the rate generator Q from a learned potential gradient and a probability vector.

    With θ the logarithmic mean (see logmean):
        drive[i, j] = gradV[i, j] + β (log ρ_i - log ρ_j)
        Q[i, j]     = K[i, j] * θ(ρ_i, ρ_j) / ρ_i * [drive[i, j]]_+     (i != j)
        Q[i, i]     = -Σ_{j != i} Q[i, j]
    so Q[i, j] is the rate from i to j and every row sums to 0.

    ρ is clamped at ``eps`` *before* taking the log. Clamping the log itself would map
    every log-probability (always <= 0) to eps and silently cancel the entropy term.
    The division by ρ_i uses the same clamped value, which keeps zero-mass states finite.

    Inputs:
        K (n,n); beta (scalar); rho (n,) probabilities (e.g. an empirical histogram).
        gradV: added to an (n,n) tensor, so it must broadcast to (n,n). Presumably
            the edge gradient V_i - V_j (matching build_rates' sign) — TODO confirm.
    Returns: Q (n,n) generator: off-diagonal rates, diagonal = minus the row sum.
    """
    rho_c = rho.clamp_min(eps)

    # log-mean θ(ρ_i, ρ_j)
    rho_i = rho.unsqueeze(1)            # (n,1)
    rho_j = rho.unsqueeze(0)            # (1,n)
    theta = logmean(rho_i, rho_j, eps)  # (n,n), clamps internally

    log_rho = torch.log(rho_c)
    drive = gradV + beta * (log_rho.unsqueeze(1) - log_rho.unsqueeze(0))
    pos = torch.relu(drive)

    w = K * theta / rho_c.unsqueeze(1)

    # rates q_{i->j}; the (i,i) term is w[i,i] * relu(gradV[i,i])
    Q = w * pos
    Q = Q - torch.diag(Q.sum(dim=1))  # set diagonal so rows sum to 0
    return Q

@torch.no_grad()
def transition_matrix_from_Q(Q, dt):
    """
    Compute the transition matrix P = exp(dt * Q) of a generator frozen over dt.

    For a generator Q (rows sum to 0), P is row-stochastic up to round-off. Negative
    entries are clamped to 0 and rows are renormalized for safety.

    P is returned on Q's device, so it can be indexed by state tensors that live
    alongside Q.

    Args:
        Q: (n, n) generator.
        dt: time step (float).

    Returns:
        (n, n) row-stochastic matrix with Q's dtype and device.
    """
    P = torch.linalg.matrix_exp(dt * Q)
    P = P.clamp_min(0)
    P = P / P.sum(dim=1, keepdim=True).clamp_min(1e-16)
    return P

@torch.no_grad()
def sample_categorical_rows(probs, generator=None):
    """
    Draw one categorical sample per row by inverse-CDF lookup.

    One uniform u_b is drawn for each row b. The sample is the number of cumulative
    probabilities strictly below u_b, clipped to the last index so that rows summing
    to slightly less than 1 still yield a valid index.

    probs: (B, n), one distribution per row.
    returns: (B,) int64 indices.
    """
    n_rows, n_cols = probs.size(0), probs.size(1)
    uniforms = torch.rand(n_rows, generator=generator, device=probs.device)[:, None]
    cumulative = torch.cumsum(probs, dim=1)
    n_below = (cumulative < uniforms).sum(dim=1)
    return torch.clamp(n_below, max=n_cols - 1).long()

# --- single step: forecast from t0 -> t1 ------------------------------------

@torch.no_grad()
def forecast_step(Q, t0, t1, states, *, eval_point="mid", dtype=torch.float64, generator=None):
    """
    Advance all samples from t0 to t1 with Q frozen on the interval.

    A sample in state i moves to j with probability P[i, j], where
    P = exp((t1 - t0) * Q) (see transition_matrix_from_Q). If t1 <= t0 the input
    states are returned unchanged.

    Args:
        Q: (n, n) generator, already evaluated for this interval.
        t0, t1: interval endpoints (floats or 0-d tensors).
        states: (B,) int64 in {0..n-1}.
        eval_point: 'left'|'mid'|'right'. Kept for API compatibility; Q is passed in
            already evaluated, so this has no effect here.
        dtype: kept for API compatibility; unused.
        generator: optional torch.Generator forwarded to the categorical sampler.

    Returns:
        (B,) int64 tensor of new states.
    """
    dt = float(t1 - t0)
    if dt <= 0:
        return states

    P = transition_matrix_from_Q(Q, dt)          # (n, n)
    row_probs = P.index_select(0, states)        # (B, n)
    return sample_categorical_rows(row_probs, generator=generator)

# --- main loop over the time grid -------------------------------------------

@torch.no_grad()
def simulate_on_grid(K, V, beta, pi, rho_seq, times, init_states, *, eval_point="left", dtype=torch.float64, seed=None, gt_prob=False):
    """
    Simulate a batch of samples on a time grid, rebuilding the rates from their histogram.

    On each interval [times[k], times[k+1]], the normalized histogram of the current
    states is passed to build_rates(K, ., V, beta) and the samples are advanced with
    forecast_step.

    With gt_prob=True (debug mode), the states are first resampled from
    rho_seq[k] * pi at every step.

    Args:
        K, V, beta: rate parameters for build_rates.
        pi, rho_seq: only used when gt_prob=True.
        times: (T,) increasing grid (converted to a tensor of `dtype` on the states' device).
        init_states: (B,) initial states (cast to int64).
        eval_point, dtype: forwarded to forecast_step.
        seed: if given, seeds the global torch RNG.
        gt_prob: resample from the ground-truth distribution before each step.

    Returns:
        S: (T, B) int64 tensor; S[k] holds the states at times[k].
    """
    grid = torch.as_tensor(times, dtype=dtype, device=init_states.device)

    if seed is not None:
        torch.manual_seed(int(seed))

    current = init_states.to(torch.int64).clone()
    n_times = grid.numel()
    n_samples = current.numel()
    n_states = K.size(0)

    S = torch.empty((n_times, n_samples), dtype=torch.int64, device=current.device)
    S[0] = current

    for k in range(n_times - 1):
        if gt_prob:
            current = torch.multinomial(rho_seq[k] * pi, n_samples, replacement=True)

        # rates come from the samples' own empirical distribution
        counts = torch.bincount(current, minlength=n_states)
        emp_prob = counts.to(dtype) / n_samples
        Q = build_rates(K, emp_prob, V, beta, eps=1e-16)          # (n,n)

        t_left = float(grid[k].item())
        t_right = float(grid[k + 1].item())
        current = forecast_step(Q, t_left, t_right, current, eval_point=eval_point, dtype=dtype)
        S[k + 1] = current

    return S


@torch.no_grad()
def simulate_model_on_grid(model, K, pi, rho_seq, times, init_states, *, eval_point="left", dtype=torch.float64, seed=None, gt_prob=False):
    """
    Simulate samples on a time grid using rates built from a trained model's potential.

    At each step, ``model.get_potential()`` is called and must return (gradV, beta).
    The rates are build_rates_from_network(K, gradV, beta, emp_prob), where emp_prob
    is the (n,) normalized histogram of the current states.

    States and histograms are moved to CUDA when it is available, otherwise to CPU.
    With gt_prob=True (debug mode), the states are first resampled from
    rho_seq[k] * pi at every step.

    Args:
        model: object exposing get_potential() -> (gradV, beta).
        K: (n, n) edge weights.
        pi, rho_seq: only used when gt_prob=True.
        times: (T,) increasing grid (converted to a tensor of `dtype` on init_states' device).
        init_states: (B,) initial states (cast to int64).
        eval_point, dtype: forwarded to forecast_step.
        seed: if given, seeds the global torch RNG.
        gt_prob: resample from the ground-truth distribution before each step.

    Returns:
        S: (T, B) int64 tensor; S[k] holds the states at times[k].
    """
    target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    grid = torch.as_tensor(times, dtype=dtype, device=init_states.device)

    if seed is not None:
        torch.manual_seed(int(seed))

    current = init_states.to(device=target_device, dtype=torch.int64).clone()
    n_times, n_samples = grid.numel(), current.numel()

    S = torch.empty((n_times, n_samples), dtype=torch.int64, device=current.device)
    S[0] = current

    for k in range(n_times - 1):
        if gt_prob:
            current = torch.multinomial(rho_seq[k] * pi, n_samples, replacement=True)

        # (n,) empirical distribution of the current states
        counts = torch.bincount(current, minlength=K.size(0))
        emp_prob = counts.to(device=target_device, dtype=dtype) / n_samples

        gradV, beta = model.get_potential()
        Q = build_rates_from_network(K, gradV, beta, emp_prob, eps=1e-16)

        t_left = float(grid[k].item())
        t_right = float(grid[k + 1].item())
        current = forecast_step(Q, t_left, t_right, current, eval_point=eval_point, dtype=dtype)
        S[k + 1] = current

    return S

# --- (optional) diagnostics: empirical distribution from samples -------------

@torch.no_grad()
def empirical_distributions(S, n, dtype=torch.float64):
    """
    Convert a (T, B) array of integer states into (T, n) empirical probabilities.

    Row t of the result counts how many of the B samples in S[t] sit in each of the
    n states, divided by B.
    """
    n_times, n_samples = S.shape
    counts = torch.zeros((n_times, n), device=S.device, dtype=dtype)
    weights = torch.ones((n_times, n_samples), device=S.device, dtype=dtype)
    # one scatter over the state axis handles every time row at once
    counts.scatter_add_(1, S, weights)
    return counts / n_samples


# ---------- example usage ----------
if __name__ == "__main__":

    data_folder = './data/synthetic_example/'
    # expanduser only rewrites a leading '~', so it is safe to call unconditionally.
    data_folder = os.path.expanduser(data_folder)
    data_tr = GraphFlowSamplesDataset.from_folder(os.path.join(data_folder, 'train'), dtype=torch.float64)

    n_particles = 100000
    n = data_tr.n
    T = data_tr.T
    device = ("cuda" if torch.cuda.is_available() else "cpu")
    torch.set_default_dtype(torch.float64)

    # Model parameters stored with the dataset.
    K = torch.tensor(data_tr.metadata.get('K', None), dtype=data_tr.dtype, device=device)  # (n,n)
    pi = torch.tensor(data_tr.metadata.get('pi', None), dtype=data_tr.dtype, device=device)  # (n,)
    V = torch.tensor(data_tr.metadata.get('V', None), dtype=data_tr.dtype, device=device)  # (n,)
    beta = data_tr.metadata.get('beta', None)
    dtype = data_tr.dtype

    rho_seq = torch.tensor(data_tr.rho_gt_seq, dtype=dtype, device=device)  # (T+1, n)

    times = torch.linspace(0.0, 1.0, T+1, device=device)

    # Ground-truth probabilities are rho * pi.
    prob_seq = rho_seq * pi

    debug_gt_prob = False

    # Initial particles are drawn from the ground-truth distribution at t=0.
    p0 = prob_seq[0]
    initial_states = torch.multinomial(p0, n_particles, replacement=True)

    S = simulate_on_grid(K, V, beta, pi, rho_seq, times, initial_states, eval_point="left", seed=123, gt_prob=debug_gt_prob)

    # Spot-check a few random times: empirical vs ground-truth probabilities.
    tests = 10
    for _ in range(tests):
        index = torch.randint(0, T, (1,)).item()
        print(f"Empyrical prob at time {index}", torch.bincount(S[index], minlength=n).to(torch.float64) / n_particles)
        print(f"GroundTruth prob at time {index}", prob_seq[index])
        print('\n')


    p_emp = empirical_distributions(S, n)  # optional
