from typing import Callable, Union
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torch import Tensor

# Type alias: a module is specified either as a string or as a callable
# (accepting arbitrary arguments) that returns an ``nn.Module``.
ModuleType = Union[str, Callable[..., nn.Module]]


class SiLU(nn.Module):
    """Sigmoid-weighted linear unit: ``x * sigmoid(x)``, applied element-wise."""

    def forward(self, x):
        # Gate each element by its own sigmoid.
        gate = torch.sigmoid(x)
        return x * gate


class PositionalEmbedding(torch.nn.Module):
    """Positional Embedding Layer.

    Maps a 1-D tensor of positions to sinusoidal features. The layer has no
    learnable parameters.

    Args:
        num_channels: Requested feature width; ``num_channels // 2``
            frequencies are used, each producing one cosine and one sine
            channel.
        max_positions: Base of the geometric frequency progression
            (frequency ``i`` is ``(1 / max_positions) ** (i / denom)``).
        endpoint: If True, ``denom`` is ``num_channels // 2 - 1`` so the last
            frequency equals ``1 / max_positions``; otherwise ``denom`` is
            ``num_channels // 2``.
    """
    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        """Embed positions ``x``.

        Args:
            x: 1-D tensor of positions (``torch.outer`` requires 1-D inputs).

        Returns:
            Tensor of shape ``(len(x), 2 * (num_channels // 2))`` holding the
            cosine features followed by the sine features.
        """
        half = self.num_channels // 2
        freqs = torch.arange(start=0, end=half,
                             dtype=torch.float32, device=x.device)
        freqs = freqs / (half - (1 if self.endpoint else 0))
        freqs = (1 / self.max_positions) ** freqs
        # torch.outer replaces the deprecated Tensor.ger alias; same result.
        x = torch.outer(x, freqs.to(x.dtype))
        x = torch.cat([x.cos(), x.sin()], dim=1)
        return x


class MeanConsistencyLossCluster:
    """
    Mean Consistency Loss for Graph Clustering.
    Adapted for batch training scenarios in graph clustering.

    Instances are callable: ``loss_fn(model, data, labels)`` returns a
    per-sample loss vector and increments ``global_step``, which drives the
    linear warm-up of the gradient-correction term.
    """

    def __init__(self,
                 sigma_data=0.5,
                 hid_dim=100,
                 flow_ratio=0.5,
                 opts=None,
                 s_min: float = 0.002,
                 lambda_term3: float =  -0.3,
                 warmup_steps: int = 2000,
                 clip_grad_term: float = 5.0,
                 use_ema_teacher: bool = False,
                 ema_decay: float = 0.999):
        """
        Args:
            sigma_data: Stored on the instance; not read by this class.
            hid_dim: Stored on the instance; not read by this class.
            flow_ratio: Fraction of each batch for which ``s`` is set equal
                to ``t`` when sampling time pairs.
            opts: Optional dict overriding ``T_type`` ('baseline' or
                'saturated'), ``T_k``, ``T_eps`` and ``W_type``. Anything
                that is not a dict selects the defaults.
            s_min: Lower bound applied to ``s`` before it is used as a
                divisor.
            lambda_term3: Final coefficient of the gradient-correction term
                in the regression target.
            warmup_steps: Number of calls over which the coefficient ramps
                linearly from 0 to ``lambda_term3``; 0 disables the ramp.
            clip_grad_term: Maximum per-row L2 norm of the combined gradient
                term; ``None`` or 0 disables clipping.
            use_ema_teacher: Stored on the instance; not read by this class.
            ema_decay: Stored on the instance; not read by this class.
        """
        self.sigma_data = sigma_data
        self.hid_dim = hid_dim
        self.flow_ratio = flow_ratio
        self.opts = opts
        self.eps = 1e-6
        self.s_min = s_min
        self.lambda_term3 = float(lambda_term3)
        self.warmup_steps = int(warmup_steps)
        self.clip_grad_term = float(clip_grad_term) if clip_grad_term is not None else 0.0
        self.use_ema_teacher = use_ema_teacher
        self.ema_decay = ema_decay
        self.global_step = 0
        self.teacher_model = None

        # A non-dict ``opts`` behaves like an empty dict: every key falls back
        # to its default.
        cfg = self.opts if isinstance(self.opts, dict) else {}
        self.T_type = cfg.get('T_type', 'baseline')
        self.T_k = float(cfg.get('T_k', 48.0))
        self.T_eps = float(cfg.get('T_eps', self.s_min))
        self.W_type = cfg.get('W_type', 'constant1')

    def sample_t_s(self, batch_size, device):
        """Sample time steps t and s.

        Two uniform numbers are drawn per item with NumPy's global RNG;
        ``t`` is the larger and ``s`` the smaller. For a random subset of
        ``int(flow_ratio * batch_size)`` items ``s`` is set equal to ``t``.
        Both are floored at ``self.eps`` and ``s <= t`` holds on return.

        Returns:
            Tuple ``(t, s)`` of 1-D tensors of length ``batch_size`` on
            ``device``.
        """
        import numpy as np

        samples = np.random.rand(batch_size, 2).astype(np.float32)

        t_np = np.maximum(samples[:, 0], samples[:, 1])
        s_np = np.minimum(samples[:, 0], samples[:, 1])

        # Collapse s onto t for the selected fraction of the batch.
        num_selected = int(self.flow_ratio * batch_size)
        indices = np.random.permutation(batch_size)[:num_selected]
        s_np[indices] = t_np[indices]

        t_np = np.maximum(t_np, self.eps)
        s_np = np.maximum(s_np, self.eps)

        # Re-establish s <= t after the flooring above.
        s_np = np.minimum(s_np, t_np)

        t = torch.tensor(t_np, device=device)
        s = torch.tensor(s_np, device=device)

        return t, s

    def _compute_T(self, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        """Compute stabilized T(t,s).

        ``s`` is clamped from below by ``max(s_min, T_eps)`` and the ratio
        ``r = t / s`` is formed. With ``T_type == 'saturated'`` the result is
        ``k * r / (k - 1 + r)`` with ``k = T_k``; otherwise ``r`` is returned.
        """
        s_safe = torch.clamp(s, min=max(self.s_min, self.T_eps))
        r = t / s_safe
        if self.T_type == 'saturated':
            k = torch.tensor(self.T_k, device=t.device, dtype=t.dtype)
            return (k * r) / (k - 1.0 + r)
        return r

    def _compute_W(self, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        """Compute weight W(t). Default constant 1.

        Every ``W_type`` currently yields a tensor of ones shaped like ``t``;
        ``s`` is accepted for signature symmetry but not used.
        """
        return torch.ones_like(t)

    def compute_gradients(self, model, x_t, t, s, labels, x_direction: torch.Tensor = None):
        """
        Compute gradients using JVP:
        - dg_dt = ∂g/∂t
        - jvp_x = (∂g/∂x) · v

        Args:
            model: Callable invoked as ``model(x_t, t, s, labels)``.
            x_t, t, s, labels: Inputs forwarded to ``model``.
            x_direction: Tangent ``v`` for the x-direction JVP; when ``None``
                that JVP is skipped.

        Returns:
            Tuple ``(g_output, dg_dt, jvp_x)`` where ``g_output`` is a plain
            forward pass of ``model`` and ``jvp_x`` is ``None`` if
            ``x_direction`` was not given.
        """
        from torch.autograd.functional import jvp

        g_output = model(x_t, t, s, labels)

        # Tangent of all ones gives the per-sample derivative w.r.t. t.
        ones_t = torch.ones_like(t)
        _, dg_dt = jvp(lambda tt: model(x_t, tt, s, labels), (t,), (ones_t,))

        jvp_x = None
        if x_direction is not None:
            _, jvp_x = jvp(lambda xx: model(xx, t, s, labels), (x_t,), (x_direction,))

        return g_output, dg_dt, jvp_x

    def __call__(self, model, data, labels, x1_data=None, teacher_model=None):
        """Compute Mean Consistency Loss.

        Args:
            model: Student, called as ``model(x_t, t, s, labels)``.
            data: Clean samples x_0. Expected to be 2-D ``(batch, features)``:
                per-sample scalars are broadcast with ``unsqueeze(-1)`` and
                the clip norm is taken over ``dim=1``.
            labels: Forwarded unchanged to the student and teacher.
            x1_data: Noise endpoint x_1; standard normal noise shaped like
                ``data`` is drawn when ``None``.
            teacher_model: Model used for the gradient terms. Falls back to
                ``self.teacher_model`` and then to ``model``.

        Returns:
            1-D tensor with one weighted MSE value per sample.
        """
        batch_size = data.shape[0]
        device = data.device

        if x1_data is None:
            x1_data = torch.randn_like(data)

        t, s = self.sample_t_s(batch_size, device)
        # NOTE(review): raising s to s_min can leave s > t for samples whose
        # t is below s_min, making (t - s) negative — confirm this is intended.
        s = torch.clamp(s, min=self.s_min)

        s_expanded = s.unsqueeze(-1)
        t_expanded = t.unsqueeze(-1)

        # Linear interpolation between data (t=0) and noise (t=1).
        x_t = (1 - t_expanded) * data + t_expanded * x1_data

        u_student = model(x_t, t, s, labels)

        # Select the teacher with explicit None checks: truthiness of a module
        # is unreliable (a module defining __len__ is falsy when empty).
        if teacher_model is not None:
            teacher = teacher_model
        elif self.teacher_model is not None:
            teacher = self.teacher_model
        else:
            teacher = model

        x1_minus_x0 = x1_data - data
        _g_out, dg_dt_teacher, jvp_x_teacher = self.compute_gradients(
            teacher, x_t, t, s, labels, x_direction=x1_minus_x0)

        t_minus_s = t_expanded - s_expanded

        term2 = (x_t * t_minus_s) / s_expanded

        grad_t_term = dg_dt_teacher
        grad_x_term = jvp_x_teacher - x1_minus_x0

        grad_combined = grad_t_term + grad_x_term
        if self.clip_grad_term and self.clip_grad_term > 0:
            # Rescale rows whose L2 norm exceeds the threshold; never scale up.
            gc_norm = grad_combined.norm(p=2, dim=1, keepdim=True) + 1e-12
            scale = torch.clamp(self.clip_grad_term / gc_norm, max=1.0)
            grad_combined = grad_combined * scale

        T_vec = self._compute_T(t, s)
        W_vec = self._compute_W(t, s)
        T_expanded = T_vec.unsqueeze(-1)

        term3_coeff = T_expanded * t_minus_s
        term3_teacher = term3_coeff * grad_combined

        # Linear warm-up of the term-3 coefficient over the first calls.
        if self.warmup_steps and self.warmup_steps > 0:
            warm = min(1.0, float(self.global_step) / float(self.warmup_steps))
        else:
            warm = 1.0
        lambda_now = self.lambda_term3 * warm

        # The target is treated as a constant: no gradient flows through it.
        u_target = (T_expanded * data) - term2 - lambda_now * term3_teacher
        u_target = u_target.detach()

        mse_loss_per_sample = F.mse_loss(u_student, u_target, reduction='none').mean(dim=-1)

        # Weight each sample by W / T^2.
        invT2 = (1.0 / torch.clamp(T_vec, min=1e-12)).pow(2)
        weight = invT2 * W_vec
        loss = weight * mse_loss_per_sample

        self.global_step += 1
        return loss


class MeanConsistencyMLPCluster(nn.Module):
    """
    MLP denoiser for the graph-clustering Mean Consistency model.

    The input is projected to width ``dim_t``, summed with learned embeddings
    of the two time arguments and of the label, and passed through a
    four-layer SiLU MLP that maps back to ``d_in`` features.
    """
    def __init__(self, d_in, dim_t=512):
        super().__init__()
        self.dim_t = dim_t
        wide = dim_t * 2

        def _embed_head():
            # Two-layer refinement applied on top of a positional embedding.
            return nn.Sequential(
                nn.Linear(dim_t, dim_t),
                nn.SiLU(),
                nn.Linear(dim_t, dim_t),
            )

        # Sub-modules are created in a fixed order so that seeded
        # initialisation and state-dict keys stay stable.
        self.proj = nn.Linear(d_in, dim_t)

        trunk = [
            nn.Linear(dim_t, wide),
            nn.SiLU(),
            nn.Linear(wide, wide),
            nn.SiLU(),
            nn.Linear(wide, dim_t),
            nn.SiLU(),
            nn.Linear(dim_t, d_in),
        ]
        self.mlp = nn.Sequential(*trunk)

        self.map_time = PositionalEmbedding(num_channels=dim_t)
        self.map_s = PositionalEmbedding(num_channels=dim_t)
        self.map_label = nn.Linear(1, dim_t, bias=False)

        self.time_embed = _embed_head()
        self.s_embed = _embed_head()

    def forward(self, x, time, s, class_labels):
        """
        Predict the network output for a batch.

        Args:
            x: Input features x_t
            time: Time step t
            s: Time parameter s
            class_labels: Class labels; fed to a ``Linear(1, dim_t)`` layer,
                so the last dimension must have size 1.

        Returns:
            Tensor with ``d_in`` features in the last dimension.
        """
        t_feat = self.time_embed(self.map_time(time))
        s_feat = self.s_embed(self.map_s(s))

        hidden = self.proj(x) + t_feat + s_feat + self.map_label(class_labels)
        return self.mlp(hidden)


class MeanConsistencyCluster(nn.Module):
    """
    Mean Consistency Model for Graph Clustering.

    Bundles the MLP network with its training loss so that calling the module
    yields a scalar loss, while ``sample`` exposes the raw network.
    """
    def __init__(self, d_in, dim_t=512, sigma_data=0.5, flow_ratio=0.5, opts=None, device='cuda'):
        super().__init__()

        # Trainable network (registered as a sub-module).
        self.model = MeanConsistencyMLPCluster(d_in, dim_t)

        # Loss helper; the feature width is passed as its ``hid_dim``.
        self.loss_fn = MeanConsistencyLossCluster(
            sigma_data=sigma_data,
            hid_dim=d_in,
            flow_ratio=flow_ratio,
            opts=opts,
        )

        self.last_gen_time = 0.0
        self.device = device

    def forward(self, x, y, x1=None):
        """
        Return the batch-averaged training loss.

        Args:
            x: Input data x_0 (Latent representation from VGAE)
            y: Labels
            x1: Pure noise data x_1 (Optional, automatically generated if None)
        """
        per_sample = self.loss_fn(self.model, x, y, x1)
        return per_sample.mean()

    def sample(self, x_t, t, s, labels):
        """
        Run the underlying network once, without any loss computation.

        Args:
            x_t: Noise input
            t: Time step
            s: Parameter s
            labels: Labels
        """
        net = self.model
        return net(x_t, t, s, labels)


def _build_time_schedule(num_steps: int,
                         schedule: str = 'linear',
                         s_min: float = 0.002,
                         rho: float = 7.0,
                         device: torch.device = torch.device('cpu')):
    """
    Generate multi-step time schedule.
    """
    num_steps = int(max(1, num_steps))
    idx = torch.arange(0, num_steps + 1, device=device, dtype=torch.float32)
    if schedule == 'linear':
        tau = 1.0 - idx / float(num_steps)
    elif schedule == 'rho':
        inv_rho = 1.0 / float(rho)
        tau = (1.0 ** inv_rho + (idx / float(num_steps)) * ((float(s_min) ** inv_rho) - (1.0 ** inv_rho)))
        tau = torch.clamp(tau, min=0.0)
        tau = tau ** float(rho)
    else:
        tau = 1.0 - idx / float(num_steps)
    t_list = tau[:-1]
    s_list = tau[1:]
    s_list = torch.clamp(s_list, min=float(s_min))
    delta = torch.clamp(t_list - s_list, min=0.0)
    return t_list, s_list, delta


def mean_consistency_sampling(model,
                              noise,
                              labels,
                              num_steps: int = 1,
                              schedule: str = 'linear',
                              eta: float = 0.0,
                              s_min: float = 0.002,
                              single_use_s_min: bool = False,
                              step_clip: float = None,
                              rho: float = 7.0,
                              heun: bool = False):
    """
    Generate samples from ``noise`` with one or several model evaluations.

    Args:
        model: Object exposing ``sample(x, t, s, labels)``.
        noise: Starting tensor; its first dimension is the batch size.
        labels: Passed through to ``model.sample``.
        num_steps: With a value <= 1 the model is evaluated once at ``t = 1``
            and its output is returned directly; otherwise an iterative
            update ``x <- x - step`` is run over a time schedule.
        schedule, rho: Forwarded to the schedule builder.
        eta: If positive, Gaussian noise scaled by ``eta`` is added after
            every step.
        s_min: Lower bound on ``s``; also the single-step ``s`` when
            ``single_use_s_min`` is set (otherwise ``s = 0`` is used).
        single_use_s_min: See ``s_min``.
        step_clip: If positive, caps the per-row L2 norm of each step.
        heun: If True, every step but the last averages the model output at
            the current and the next schedule point.

    Returns:
        The generated tensor. All model calls run under ``torch.no_grad()``.
    """
    device = noise.device
    batch_size = noise.shape[0]
    n_steps = int(num_steps)

    def _per_sample(value):
        # Broadcast one scalar time value to a vector over the batch.
        return torch.full((batch_size,), float(value), device=device)

    def _limit(step):
        # Shrink rows whose L2 norm exceeds ``step_clip``; never enlarge.
        if (step_clip is None) or not (step_clip > 0):
            return step
        row_norm = step.norm(p=2, dim=1, keepdim=True).clamp(min=1e-12)
        factor = (float(step_clip) / row_norm).clamp(max=1.0)
        return step * factor

    # Single-evaluation path: the model output is the sample itself.
    if n_steps <= 1:
        t = torch.ones(batch_size, device=device)
        if bool(single_use_s_min):
            s = torch.full((batch_size,), float(s_min), device=device)
        else:
            s = torch.zeros(batch_size, device=device)
        with torch.no_grad():
            return model.sample(noise, t, s, labels)

    with torch.no_grad():
        t_seq, s_seq, delta_seq = _build_time_schedule(
            num_steps=n_steps, schedule=schedule, s_min=s_min, rho=rho, device=device
        )
        x = noise
        last_index = n_steps - 1

        for i in range(n_steps):
            t_k = t_seq[i]
            s_k = s_seq[i]

            # Effective step size: (t - s) divided by the ratio T = t / s.
            T_k = t_k / torch.clamp(s_k, min=float(s_min))
            delta_hat = delta_seq[i] / torch.clamp(T_k, min=1e-12)

            t_vec = _per_sample(t_k)
            s_vec = _per_sample(s_k)
            u_k = model.sample(x, t_vec, s_vec, labels)

            x_euler = x - _limit(delta_hat * u_k)

            if heun and (i < last_index):
                # Re-evaluate at the Euler estimate and average both outputs.
                t_vec_next = _per_sample(t_seq[i + 1])
                s_vec_next = _per_sample(s_seq[i + 1])
                u_next = model.sample(x_euler, t_vec_next, s_vec_next, labels)
                x = x - _limit(delta_hat * 0.5 * (u_k + u_next))
            else:
                x = x_euler

            if eta and (eta > 0.0):
                x = x + float(eta) * torch.randn_like(x)

        return x
