import torch
import torch.nn as nn
import torch.nn.functional as F


class SiLU(nn.Module):
    """Sigmoid-weighted linear unit: scales the input by its own sigmoid."""

    def forward(self, x):
        """Return ``x * sigmoid(x)``, computed element-wise."""
        gate = torch.sigmoid(x)
        return x * gate


class PositionalEmbedding(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of scalar positions.

    The input is multiplied (outer product) with ``num_channels // 2``
    geometrically spaced frequencies, and the cosines and sines of the result
    are concatenated along dim 1, giving ``2 * (num_channels // 2)`` columns.

    Args:
        num_channels: Requested embedding width; only ``num_channels // 2``
            frequencies are generated.
        max_positions: The frequencies are ``(1 / max_positions) ** e`` for
            exponents ``e`` starting at 0.
        endpoint: If true the exponents are spaced so the last one equals 1;
            otherwise the last one is ``(half - 1) / half``.
    """

    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        """Embed the 1-D tensor ``x``; returns a 2-D tensor with one row per entry."""
        half = self.num_channels // 2
        freqs = torch.arange(start=0, end=half,
                             dtype=torch.float32, device=x.device)
        # With endpoint=True and a single frequency the denominator would be 0
        # and the only exponent would become 0 / 0 = NaN; floor it at 1.
        denom = max(half - (1 if self.endpoint else 0), 1)
        freqs = freqs / denom
        freqs = (1 / self.max_positions) ** freqs
        # torch.outer is the supported name of the deprecated Tensor.ger alias.
        x = torch.outer(x, freqs.to(x.dtype))
        x = torch.cat([x.cos(), x.sin()], dim=1)
        return x


class MeanConsistencyLossHetero:
    def __init__(self,
                 sigma_data=0.5,
                 flow_ratio=0.5,
                 opts=None,
                 s_min: float = 0.002,
                 lambda_term3: float = -0.3,
                 warmup_steps: int = 2000,
                 clip_grad_term: float = 5.0):
        self.sigma_data = sigma_data
        self.flow_ratio = flow_ratio
        self.opts = opts
        self.eps = 1e-6
        self.s_min = s_min
        self.lambda_term3 = float(lambda_term3)
        self.warmup_steps = int(warmup_steps)
        self.clip_grad_term = float(clip_grad_term) if clip_grad_term is not None else 0.0
        # opts defaults
        if isinstance(self.opts, dict):
            self.T_type = self.opts.get('T_type', 'baseline')
            self.T_k = float(self.opts.get('T_k', 48.0))
            self.T_eps = float(self.opts.get('T_eps', self.s_min))
            self.W_type = self.opts.get('W_type', 'constant1')
        else:
            self.T_type = 'baseline'
            self.T_k = 48.0
            self.T_eps = float(self.s_min)
            self.W_type = 'constant1'
        self.global_step = 0

    def sample_t_s(self, batch_size, device):
        import numpy as np
        samples = np.random.rand(batch_size, 2).astype(np.float32)
        t_np = np.maximum(samples[:, 0], samples[:, 1])
        s_np = np.minimum(samples[:, 0], samples[:, 1])
        num_selected = int(self.flow_ratio * batch_size)
        indices = np.random.permutation(batch_size)[:num_selected]
        s_np[indices] = t_np[indices]
        t_np = np.maximum(t_np, self.eps)
        s_np = np.maximum(s_np, self.eps)
        s_np = np.minimum(s_np, t_np)
        t = torch.tensor(t_np, device=device)
        s = torch.tensor(s_np, device=device)
        return t, s

    def _compute_T(self, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        s_safe = torch.clamp(s, min=max(self.s_min, self.T_eps))
        r = t / s_safe
        if self.T_type == 'saturated':
            k = torch.tensor(self.T_k, device=t.device, dtype=t.dtype)
            return (k * r) / (k - 1.0 + r)
        return r

    def _compute_W(self, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        return torch.ones_like(t)

    def compute_gradients(self, model, x_t, t, s, labels, x_direction: torch.Tensor = None):
        from torch.autograd.functional import jvp
        g_output = model(x_t, t, s, labels)
        ones_t = torch.ones_like(t)
        _, dg_dt = jvp(lambda tt: model(x_t, tt, s, labels), (t,), (ones_t,))
        jvp_x = None
        if x_direction is not None:
            _, jvp_x = jvp(lambda xx: model(xx, t, s, labels), (x_t,), (x_direction,))
        return g_output, dg_dt, jvp_x

    def __call__(self, model_dict: dict, z_dict: dict, y_dict: dict, x1_dict: dict | None = None):
        # Per-ntype loss, average across types present
        losses = []
        for ntype, z in z_dict.items():
            model = model_dict[ntype] if ntype in model_dict else None
            if model is None:
                continue
            x = z
            y = y_dict.get(ntype, torch.zeros(z.size(0), device=z.device))
            if x1_dict is None or (x1_dict is not None and ntype not in x1_dict):
                x1 = torch.randn_like(x)
            else:
                x1 = x1_dict[ntype]
            batch_size = x.shape[0]
            device = x.device
            t, s = self.sample_t_s(batch_size, device)
            s = torch.clamp(s, min=self.s_min)
            x_t = (1 - t.unsqueeze(-1)) * x + t.unsqueeze(-1) * x1
            u_student = model(x_t, t, s, y.float().unsqueeze(1))
            teacher = model
            x1_minus_x0 = x1 - x
            _g_out, dg_dt_teacher, jvp_x_teacher = self.compute_gradients(
                teacher, x_t, t, s, y.float().unsqueeze(1), x_direction=x1_minus_x0)
            s_expanded = s.unsqueeze(-1)
            t_expanded = t.unsqueeze(-1)
            t_minus_s = t_expanded - s_expanded
            term2 = (x_t * t_minus_s) / s_expanded
            grad_combined = dg_dt_teacher + (jvp_x_teacher - x1_minus_x0)
            if self.clip_grad_term and self.clip_grad_term > 0:
                gc_norm = grad_combined.norm(p=2, dim=1, keepdim=True) + 1e-12
                scale = torch.clamp(self.clip_grad_term / gc_norm, max=1.0)
                grad_combined = grad_combined * scale
            T_vec = self._compute_T(t, s)
            W_vec = self._compute_W(t, s)
            T_expanded = T_vec.unsqueeze(-1)
            term3_coeff = T_expanded * t_minus_s
            term3_teacher = term3_coeff * grad_combined
            if self.warmup_steps and self.warmup_steps > 0:
                warm = min(1.0, float(self.global_step) / float(self.warmup_steps))
            else:
                warm = 1.0
            lambda_now = self.lambda_term3 * warm
            u_target = (T_expanded * x) - term2 - lambda_now * term3_teacher
            u_target = u_target.detach()
            mse_loss_per_sample = F.mse_loss(u_student, u_target, reduction='none').mean(dim=-1)
            invT2 = (1.0 / torch.clamp(T_vec, min=1e-12)).pow(2)
            weight = invT2 * W_vec
            loss_nt = weight * mse_loss_per_sample
            losses.append(loss_nt.mean())
            self.global_step += 1
        if len(losses) == 0:
            return torch.tensor(0., device=next(iter(model_dict.values())).proj.weight.device)
        return torch.stack(losses).mean()


class MeanConsistencyMLPHetero(nn.Module):
    """MLP for one node type, conditioned on two times and a scalar label.

    The input features, the embedded ``time``, the embedded ``s`` and the
    linearly mapped label are summed in a ``dim_t``-wide space and passed
    through a four-layer SiLU MLP that maps back to ``d_in`` features.
    """

    def __init__(self, d_in: int, dim_t: int = 512):
        super().__init__()
        self.dim_t = dim_t
        wide = dim_t * 2
        # Submodules are created in a fixed order so seeded initialisation and
        # state-dict layout stay stable.
        self.proj = nn.Linear(d_in, dim_t)
        self.mlp = nn.Sequential(
            nn.Linear(dim_t, wide),
            nn.SiLU(),
            nn.Linear(wide, wide),
            nn.SiLU(),
            nn.Linear(wide, dim_t),
            nn.SiLU(),
            nn.Linear(dim_t, d_in),
        )
        self.map_time = PositionalEmbedding(num_channels=dim_t)
        self.map_s = PositionalEmbedding(num_channels=dim_t)
        self.map_label = nn.Linear(1, dim_t, bias=False)
        self.time_embed = self._embedding_head(dim_t)
        self.s_embed = self._embedding_head(dim_t)

    @staticmethod
    def _embedding_head(width):
        """Build the Linear-SiLU-Linear head applied to a positional embedding."""
        return nn.Sequential(
            nn.Linear(width, width), nn.SiLU(), nn.Linear(width, width)
        )

    def forward(self, x, time, s, class_labels):
        """Return the network output for features ``x`` at times ``(time, s)``.

        ``class_labels`` is fed to a ``Linear(1, dim_t)`` layer, so its last
        dimension must be 1.
        """
        t_feat = self.time_embed(self.map_time(time))
        s_feat = self.s_embed(self.map_s(s))
        hidden = self.proj(x) + t_feat
        hidden = hidden + s_feat
        hidden = hidden + self.map_label(class_labels)
        return self.mlp(hidden)


class MeanConsistencyHetero(nn.Module):
    """
    Heterogeneous MeanConsistency: one MLP per node type plus a shared loss.

    ``forward`` returns the training loss; ``sample`` evaluates every per-type
    model at given times.
    """
    def __init__(self,
                 d_in_per_ntype: dict,
                 dim_t_per_ntype: dict,
                 sigma_data: float = 0.5,
                 flow_ratio: float = 0.5,
                 opts: dict | None = None,
                 device: str | torch.device = 'cpu'):
        super().__init__()
        # The device is only recorded here; the submodules are not moved to it.
        self.device = torch.device(device)
        self.models = nn.ModuleDict()
        for node_type, feat_dim in d_in_per_ntype.items():
            # Types without an explicit width fall back to max(d_in, 64).
            default_width = max(feat_dim, 64)
            width = int(dim_t_per_ntype.get(node_type, default_width))
            self.models[node_type] = MeanConsistencyMLPHetero(feat_dim, width)
        self.loss_fn = MeanConsistencyLossHetero(
            sigma_data=sigma_data,
            flow_ratio=flow_ratio,
            opts=opts,
        )

    def forward(self, z_dict: dict, y_dict: dict, x1_dict: dict | None = None):
        """Return the loss computed by ``self.loss_fn`` over the per-type models."""
        return self.loss_fn(self.models, z_dict, y_dict, x1_dict)

    def sample(self, x_t_dict: dict, t, s, labels_dict: dict):
        """Evaluate every per-type model on ``x_t_dict`` at times ``(t, s)``.

        Node types missing from ``labels_dict`` get an all-ones ``(N, 1)``
        label tensor. Returns a dict mapping node type to model output.
        """
        results = {}
        for node_type, net in self.models.items():
            x_t = x_t_dict[node_type]
            default_labels = torch.ones(x_t.size(0), 1, device=x_t.device)
            node_labels = labels_dict.get(node_type, default_labels)
            results[node_type] = net(x_t, t, s, node_labels)
        return results


def _build_time_schedule(num_steps: int,
                         schedule: str = 'linear',
                         s_min: float = 0.002,
                         rho: float = 7.0,
                         device: torch.device = torch.device('cpu')):
    num_steps = int(max(1, num_steps))
    idx = torch.arange(0, num_steps + 1, device=device, dtype=torch.float32)
    if schedule == 'linear':
        tau = 1.0 - idx / float(num_steps)
    elif schedule == 'rho':
        inv_rho = 1.0 / float(rho)
        tau = (1.0 ** inv_rho + (idx / float(num_steps)) * ((float(s_min) ** inv_rho) - (1.0 ** inv_rho)))
        tau = torch.clamp(tau, min=0.0)
        tau = tau ** float(rho)
    else:
        tau = 1.0 - idx / float(num_steps)
    t_list = tau[:-1]
    s_list = tau[1:]
    s_list = torch.clamp(s_list, min=float(s_min))
    delta = torch.clamp(t_list - s_list, min=0.0)
    return t_list, s_list, delta


@torch.no_grad()
def mean_consistency_sampling_hetero(model: MeanConsistencyHetero,
                                     noise_dict: dict,
                                     labels_dict: dict,
                                     num_steps: int = 1,
                                     schedule: str = 'linear',
                                     eta: float = 0.0,
                                     s_min: float = 0.002,
                                     step_clip: float | None = None,
                                     rho: float = 7.0,
                                     heun: bool = False) -> dict:
    """Generate samples for every node type in ``noise_dict``.

    With ``num_steps <= 1`` each per-type model is evaluated once at
    ``t = 1, s = 0`` and its raw output is returned. Otherwise the state is
    updated ``num_steps`` times with ``x <- x - delta_hat * model(x, t, s)``,
    where ``delta_hat = (t - s) / (t / s)`` comes from the time schedule.

    Args:
        model: Provides the per-type networks via ``model.models[ntype]``.
        noise_dict: Maps node type to the starting tensor (rows are samples).
            Node types may have different numbers of rows.
        labels_dict: Maps node type to ``(N, 1)`` labels; missing types use ones.
        num_steps: Number of update steps.
        schedule: Schedule name forwarded to ``_build_time_schedule``.
        eta: Accepted for interface compatibility; not used here.
        s_min: Lower bound on ``s`` in the multi-step schedule.
        step_clip: If positive, maximum per-row L2 norm of one update.
        rho: Exponent of the ``'rho'`` schedule.
        heun: Accepted for interface compatibility; not used here.

    Returns:
        Dict mapping node type to the generated tensor (empty if
        ``noise_dict`` is empty).
    """
    if not noise_dict:
        return {}

    def labels_for(ntype, x):
        # Node types without labels are conditioned on an all-ones column.
        if ntype in labels_dict:
            return labels_dict[ntype]
        return torch.ones(x.size(0), 1, device=x.device)

    # Time vectors are built per node type from that type's own row count and
    # device: node types generally differ in size, and frozen parameters must
    # not force the times onto the CPU.
    if int(num_steps) <= 1:
        out = {}
        for ntype, noise in noise_dict.items():
            t = torch.ones(noise.size(0), device=noise.device)
            s = torch.zeros(noise.size(0), device=noise.device)
            out[ntype] = model.models[ntype](noise, t, s, labels_for(ntype, noise))
        return out

    # The schedule is tiny; build it on the CPU once and read it as Python
    # floats so no per-step device synchronisation is needed.
    t_seq, s_seq, delta_seq = _build_time_schedule(
        num_steps=int(num_steps), schedule=schedule, s_min=s_min, rho=rho,
        device=torch.device('cpu')
    )
    ratio_seq = t_seq / torch.clamp(s_seq, min=float(s_min))
    step_seq = delta_seq / torch.clamp(ratio_seq, min=1e-12)
    t_vals = t_seq.tolist()
    s_vals = s_seq.tolist()
    step_vals = step_seq.tolist()

    x_dict = {nt: noise.clone() for nt, noise in noise_dict.items()}
    for t_k, s_k, delta_hat in zip(t_vals, s_vals, step_vals):
        for ntype in list(x_dict):
            x = x_dict[ntype]
            t_vec = torch.full((x.size(0),), t_k, device=x.device)
            s_vec = torch.full((x.size(0),), s_k, device=x.device)
            u_k = model.models[ntype](x, t_vec, s_vec, labels_for(ntype, x))
            update = delta_hat * u_k
            if (step_clip is not None) and (step_clip > 0):
                # Shrink rows whose update norm exceeds step_clip.
                norm = update.norm(p=2, dim=1, keepdim=True).clamp(min=1e-12)
                scale = (float(step_clip) / norm).clamp(max=1.0)
                update = update * scale
            x_dict[ntype] = x - update
    return x_dict


# ============== Cluster-compat layer (single-ntype shim) ==============

class MeanConsistencyCluster(nn.Module):
    """
    Compatibility wrapper to expose a single-ntype MC API compatible with cluster version,
    implemented on top of MeanConsistencyHetero.

    All data is routed through one internal node type named ``'__target__'``.
    """
    def __init__(self, d_in: int, dim_t: int = 512, sigma_data: float = 0.5, flow_ratio: float = 0.5, opts=None, device: str | torch.device = 'cpu'):
        super().__init__()
        self.nt = '__target__'
        self.mc = MeanConsistencyHetero(
            d_in_per_ntype={self.nt: d_in},
            dim_t_per_ntype={self.nt: dim_t},
            sigma_data=sigma_data,
            flow_ratio=flow_ratio,
            opts=opts,
            device=device
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor, x1: torch.Tensor | None = None):
        """Return the training loss for data ``x`` with labels ``y``.

        Args:
            x: Data tensor, one row per sample.
            y: Labels shaped ``(B,)`` or ``(B, 1)``.
            x1: Optional ``t = 1`` endpoint; when omitted the loss draws noise.
        """
        z_dict = {self.nt: x}
        y_vec = y.float()
        # The hetero loss adds the trailing label dimension itself. Passing
        # (B, 1) labels would turn them into (B, 1, 1) there and make the
        # model output broadcast to (B, B, D), so hand over 1-D labels.
        if y_vec.dim() == 2 and y_vec.size(-1) == 1:
            y_vec = y_vec.squeeze(-1)
        y_dict = {self.nt: y_vec}
        x1_dict = {self.nt: x1} if x1 is not None else None
        return self.mc(z_dict, y_dict, x1_dict)

    def sample(self, x_t: torch.Tensor, t: torch.Tensor, s: torch.Tensor, labels: torch.Tensor):
        """Evaluate the underlying network at ``(x_t, t, s)``; 1-D labels are reshaped to ``(B, 1)``."""
        labels_vec = labels.float()
        if labels_vec.dim() == 1:
            labels_vec = labels_vec.unsqueeze(1)
        out = self.mc.models[self.nt](x_t, t, s, labels_vec)
        return out


@torch.no_grad()
def mean_consistency_sampling(model: MeanConsistencyCluster,
                              noise: torch.Tensor,
                              labels: torch.Tensor,
                              num_steps: int = 1,
                              schedule: str = 'linear',
                              eta: float = 0.0,
                              s_min: float = 0.002,
                              step_clip: float | None = None,
                              rho: float = 7.0,
                              heun: bool = False) -> torch.Tensor:
    """Single-node-type front end to ``mean_consistency_sampling_hetero``.

    ``noise`` and ``labels`` are wrapped into one-entry dicts keyed by
    ``'__target__'`` (1-D labels are reshaped to ``(B, 1)``), all remaining
    arguments are forwarded unchanged, and the tensor generated for that key
    is returned.
    """
    key = '__target__'
    label_column = labels.float()
    if label_column.dim() == 1:
        label_column = label_column.unsqueeze(1)
    generated = mean_consistency_sampling_hetero(
        model.mc,
        {key: noise},
        {key: label_column},
        num_steps=num_steps,
        schedule=schedule,
        eta=eta,
        s_min=s_min,
        step_clip=step_clip,
        rho=rho,
        heun=heun,
    )
    return generated[key]


