from typing import Callable, Union
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torch import Tensor

# Type alias for a module specification: either a string or a callable that
# returns an nn.Module. NOTE(review): not referenced elsewhere in this chunk;
# presumably the string form names a module to look up — confirm with callers.
ModuleType = Union[str, Callable[..., nn.Module]]


class SiLU(nn.Module):
    """Sigmoid-weighted linear unit: returns ``x * sigmoid(x)`` elementwise."""

    def forward(self, x):
        gate = x.sigmoid()
        return x * gate


class MeanConsistencyLoss:
    """
    Implementation of Mean Consistency Loss Function.

    Loss Function Paradigm (Student-Teacher):
    Let u = g_theta(x_t, t, s).
    Construct Teacher Target:
    u_tgt = T(t) * x_0 - (x_t * (t - s)) / s - lambda * T(t) * (t - s) * G(x_t, t, s)
    Where G = d(g_theta)/dt + (d(g_theta)/dx_t - I)(x_1 - x_0), evaluated with the
    teacher network and optionally L2-norm-clipped per sample (clip_grad_term).
    Training uses:
    L(theta) = E[(1/T(t))^2 * W(t) * || u - stopgrad(u_tgt) ||^2]

    Convention: W(t) = 1 and T(t) = t / s. Both t and s are clamped to at least
    s_min before use, which keeps 1/s finite and preserves t >= s. lambda equals
    lambda_term3 times a linear warm-up factor min(1, global_step / warmup_steps)
    (the factor is 1 when warmup_steps <= 0).

    Where:
    - t, s: Time parameters, using MeanFlow time sampling strategy.
      * For flow_ratio proportion of samples: s = t
      * For remaining samples: s < t
      * Always maintain t >= s
    - x_0 ~ p(x_0): Original data distribution.
    - x_1 ~ N(0, I): Pure noise (Standard Normal Distribution).
    - x_t = (1 - t)x_0 + t*x_1: Noisy data constructed using MeanFlow noise formula.
    - g_theta(x_t, t, s): Parameterized model, dependent on input x_t, time t, and parameter s.
    - The loss function averages the consistency function over the interval [s, t].

    Note: sigma_data, hid_dim, opts, use_ema_teacher and ema_decay are stored on the
    instance but are not read by the methods of this class. A teacher network can be
    supplied per call (``teacher_model=``) or through the ``teacher_model`` attribute;
    otherwise the student model itself is used as the teacher.
    """

    def __init__(self,
                 sigma_data=0.5,
                 hid_dim=100,
                 flow_ratio=0.5,
                 opts=None,
                 s_min: float = 0.002,
                 lambda_term3: float = -0.3,
                 warmup_steps: int = 2000,
                 clip_grad_term: float = 5.0,
                 use_ema_teacher: bool = False,
                 ema_decay: float = 0.999):
        self.sigma_data = sigma_data
        self.hid_dim = hid_dim
        # Fraction of each batch for which s is set equal to t.
        self.flow_ratio = flow_ratio
        self.opts = opts
        # Lower bound applied to sampled t and s inside sample_t_s.
        self.eps = 1e-6
        # Lower bound applied to t and s in __call__ (keeps t / s bounded).
        self.s_min = s_min
        self.lambda_term3 = float(lambda_term3)
        self.warmup_steps = int(warmup_steps)
        # 0.0 disables clipping of the teacher gradient term.
        self.clip_grad_term = float(clip_grad_term) if clip_grad_term is not None else 0.0
        self.use_ema_teacher = use_ema_teacher
        self.ema_decay = ema_decay
        # Incremented once per __call__; drives the lambda warm-up.
        self.global_step = 0
        # Optional default teacher network used when none is passed to __call__.
        self.teacher_model = None

    def sample_t_s(self, batch_size, device):
        """
        Sample time t and s, referencing MeanFlow time sampling strategy.

        t and s are the max/min of two uniform draws; for int(flow_ratio * batch_size)
        randomly chosen samples s is then set equal to t. Both are floored at self.eps.
        Randomness comes from NumPy's global RNG.

        Args:
            batch_size: Batch size
            device: Device for the returned tensors

        Returns:
            t, s: float32 tensors of shape (batch_size,), with t >= s
        """
        import numpy as np

        samples = np.random.rand(batch_size, 2).astype(np.float32)

        t_np = np.maximum(samples[:, 0], samples[:, 1])
        s_np = np.minimum(samples[:, 0], samples[:, 1])

        num_selected = int(self.flow_ratio * batch_size)
        indices = np.random.permutation(batch_size)[:num_selected]
        s_np[indices] = t_np[indices]

        t_np = np.maximum(t_np, self.eps)
        s_np = np.maximum(s_np, self.eps)

        s_np = np.minimum(s_np, t_np)

        t = torch.tensor(t_np, device=device)
        s = torch.tensor(s_np, device=device)

        return t, s

    def compute_gradients(self, model, x_t, t, s, labels, create_graph: bool = True):
        """
        Compute partial derivatives of model output with respect to time t and input x_t.

        Args:
            model: Model g_theta(x_t, t, s), called as model(x_t, t, s, labels)
            x_t: Noisy input of shape (batch, in_dim)
            t: Time step of shape (batch,)
            s: Parameter s (passed through unchanged, not differentiated)
            labels: Labels (passed through unchanged)
            create_graph: If True, the returned derivatives are differentiable.

        Returns:
            g_output: Model output g_theta(x_t, t, s), shape (batch, out_dim)
            dg_dt: shape (batch, out_dim); dg_dt[b, i] = d g[b, i] / d t[b]
            dg_dx: shape (batch, out_dim, in_dim); dg_dx[b, i, j] = d g[b, i] / d x_t[b, j]
        """
        x_t_grad = x_t.clone().detach().requires_grad_(True)
        t_grad = t.clone().detach().requires_grad_(True)

        g_output = model(x_t_grad, t_grad, s, labels)

        batch_size, out_dim = g_output.shape[0], g_output.shape[1]
        in_dim = x_t.shape[1]

        dg_dt = torch.zeros_like(g_output)
        dg_dx = g_output.new_zeros(batch_size, out_dim, in_dim)
        for i in range(out_dim):
            grad_outputs = torch.zeros_like(g_output)
            grad_outputs[:, i] = 1.0
            # One backward pass yields output row i w.r.t. both x_t and t.
            # NOTE(review): reading per-sample derivatives off a batch-summed
            # gradient assumes the model treats samples independently (no
            # cross-sample layers such as BatchNorm in train mode) — confirm.
            grad_x, grad_t = torch.autograd.grad(
                outputs=g_output,
                inputs=(x_t_grad, t_grad),
                grad_outputs=grad_outputs,
                create_graph=create_graph,
                retain_graph=True,
            )
            dg_dt[:, i] = grad_t
            dg_dx[:, i, :] = grad_x

        return g_output, dg_dt, dg_dx

    def __call__(self, model, data, labels, x1_data=None, teacher_model=None):
        """
        Compute Mean Consistency Loss.

        Args:
            model: Model g_theta(x_t, t, s) (the student)
            data: Original data x_0, shape (batch, features)
            labels: Data labels (passed through to the model)
            x1_data: Pure noise data x_1 (Automatically generated standard normal noise if None)
            teacher_model: Network used for the target's derivative term. Falls back
                to self.teacher_model, then to ``model`` itself.

        Returns:
            loss: Per-sample loss values, shape (batch,). Increments self.global_step.
        """
        batch_size = data.shape[0]
        device = data.device

        if x1_data is None:
            x1_data = torch.randn_like(data)

        t, s = self.sample_t_s(batch_size, device)
        # Clamp BOTH times: clamping s alone would push s above t whenever a
        # sampled t < s_min, breaking the t >= s invariant. Since t >= s before
        # clamping, max(t, s_min) >= max(s, s_min) still holds afterwards.
        t = torch.clamp(t, min=self.s_min)
        s = torch.clamp(s, min=self.s_min)

        t_expanded = t.unsqueeze(-1)
        s_expanded = s.unsqueeze(-1)

        x_t = (1 - t_expanded) * data + t_expanded * x1_data

        u_student = model(x_t, t, s, labels)

        if teacher_model is not None:
            teacher = teacher_model
        elif self.teacher_model is not None:
            teacher = self.teacher_model
        else:
            teacher = model
        _, dg_dt_teacher, dg_dx_teacher = self.compute_gradients(
            teacher, x_t, t, s, labels, create_graph=False)

        s_over_t = s_expanded / t_expanded
        t_over_s = t_expanded / s_expanded
        t_minus_s = t_expanded - s_expanded

        term2 = (x_t * t_minus_s) / s_expanded

        # (J - I) v computed as J v - v, avoiding a batch of identity matrices.
        x1_minus_x0 = x1_data - data
        jac_times_v = torch.bmm(dg_dx_teacher, x1_minus_x0.unsqueeze(-1)).squeeze(-1)
        grad_x_term = jac_times_v - x1_minus_x0

        grad_combined = dg_dt_teacher + grad_x_term
        if self.clip_grad_term and self.clip_grad_term > 0:
            # Rescale rows whose L2 norm exceeds clip_grad_term.
            gc_norm = grad_combined.norm(p=2, dim=1, keepdim=True) + 1e-12
            scale = torch.clamp(self.clip_grad_term / gc_norm, max=1.0)
            grad_combined = grad_combined * scale

        # t^2 / s * (1 - s / t) == T(t) * (t - s) with T(t) = t / s.
        t_squared_over_s = (t_expanded * t_expanded) / s_expanded
        term3_coeff = t_squared_over_s * (1 - s_over_t)
        term3_teacher = term3_coeff * grad_combined

        if self.warmup_steps and self.warmup_steps > 0:
            warm = min(1.0, float(self.global_step) / float(self.warmup_steps))
        else:
            warm = 1.0
        lambda_now = self.lambda_term3 * warm

        # Stop-gradient: the target is treated as a constant.
        u_target = ((t_over_s * data) - term2 - lambda_now * term3_teacher).detach()

        mse_loss_per_sample = F.mse_loss(u_student, u_target, reduction='none').mean(dim=-1)

        # Weight (1 / T(t))^2 = (s / t)^2.
        loss = (s / t).pow(2) * mse_loss_per_sample

        self.global_step += 1
        return loss


class PositionalEmbedding(torch.nn.Module):
    """Sinusoidal embedding of a 1-D tensor of scalar positions.

    Each value p maps to [cos(p*f_0), ..., cos(p*f_{h-1}), sin(p*f_0), ..., sin(p*f_{h-1})]
    with h = ceil(num_channels / 2) and geometric frequencies
    f_k = (1 / max_positions) ** (k / denom), then truncated to num_channels columns.
    """

    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        """Embed positions ``x`` (1-D, length N) into an (N, num_channels) tensor."""
        half_channels = int(math.ceil(self.num_channels / 2))
        # With endpoint=True (and at least two frequencies) the last frequency
        # is exactly 1 / max_positions; max(1, ...) guards against division by 0.
        denom = max(1, half_channels - (1 if self.endpoint else 0))
        freqs = torch.arange(start=0, end=half_channels,
                             dtype=torch.float32, device=x.device)
        freqs = (1 / self.max_positions) ** (freqs / denom)
        # torch.outer replaces the deprecated Tensor.ger (same result, shape (N, h)).
        emb = torch.outer(x, freqs.to(x.dtype))
        emb = torch.cat([emb.cos(), emb.sin()], dim=1)
        # Odd num_channels yields one surplus column; drop it.
        return emb[:, :self.num_channels]


class MeanConsistencyMLP(nn.Module):
    """
    MLP model supporting parameter s, used for Mean Consistency Loss.
    Implements g_theta(x_t, t, s).

    The input x_t is projected to dim_t, summed with learned embeddings of t, s
    and the class label, then mapped back to d_in features by a SiLU MLP.
    """
    def __init__(self, d_in, dim_t=512):
        super().__init__()
        self.dim_t = dim_t

        # Input projection d_in -> dim_t.
        self.proj = nn.Linear(d_in, dim_t)

        # Trunk: dim_t -> 2*dim_t -> 2*dim_t -> dim_t -> d_in.
        self.mlp = nn.Sequential(
            nn.Linear(dim_t, dim_t * 2),
            nn.SiLU(),
            nn.Linear(dim_t * 2, dim_t * 2),
            nn.SiLU(),
            nn.Linear(dim_t * 2, dim_t),
            nn.SiLU(),
            nn.Linear(dim_t, d_in),
        )

        # Sinusoidal features for t and s, each refined by its own small MLP below.
        self.map_time = PositionalEmbedding(num_channels=dim_t)
        self.map_s = PositionalEmbedding(num_channels=dim_t)
        # The label enters as a single scalar feature per sample.
        self.map_label = nn.Linear(1, dim_t, bias=False)

        self.time_embed = nn.Sequential(
            nn.Linear(dim_t, dim_t),
            nn.SiLU(),
            nn.Linear(dim_t, dim_t)
        )

        self.s_embed = nn.Sequential(
            nn.Linear(dim_t, dim_t),
            nn.SiLU(),
            nn.Linear(dim_t, dim_t)
        )

    def forward(self, x, time, s, class_labels):
        """
        Args:
            x: Input features x_t; last dimension must be d_in.
            time: Time step t (presumably shape (batch,) — it is fed to
                PositionalEmbedding; confirm with callers).
            s: Time parameter s (dynamically sampled), same layout as time.
            class_labels: One scalar label per sample, shape (batch, 1) or (batch,).
                Non-float labels are cast to map_label's dtype.

        Returns:
            Tensor whose last dimension is d_in.
        """
        time_emb = self.time_embed(self.map_time(time))
        s_emb = self.s_embed(self.map_s(s))
        label_emb = self.map_label(self._label_column(class_labels))

        hidden = self.proj(x) + time_emb + s_emb + label_emb

        return self.mlp(hidden)

    def _label_column(self, class_labels):
        """Reshape/cast labels to the (batch, 1) floating layout map_label expects."""
        if class_labels.dim() == 1:
            class_labels = class_labels.unsqueeze(-1)
        return class_labels.to(self.map_label.weight.dtype)


class MeanConsistencyModel(nn.Module):
    """
    End-to-end wrapper pairing a MeanConsistencyMLP with a MeanConsistencyLoss.

    ``forward`` returns the batch-averaged training loss; ``sample`` evaluates
    the underlying network once.
    """
    def __init__(self, d_in, dim_t=512, sigma_data=0.5, flow_ratio=0.5, opts=None):
        super().__init__()

        # The trainable network g_theta(x_t, t, s), registered as a submodule.
        self.model = MeanConsistencyMLP(d_in, dim_t)

        # Loss callable; d_in is forwarded as its hid_dim argument.
        self.loss_fn = MeanConsistencyLoss(
            sigma_data=sigma_data,
            hid_dim=d_in,
            flow_ratio=flow_ratio,
            opts=opts,
        )

    def forward(self, x, y, x1=None):
        """
        Return the mean of the per-sample Mean Consistency losses.

        Args:
            x: Input data x_0
            y: Labels
            x1: Pure noise x_1; generated inside the loss when None.
        """
        per_sample = self.loss_fn(self.model, x, y, x1)
        return per_sample.mean()

    def sample(self, x_t, t, s, labels):
        """
        Run the network once on (x_t, t, s, labels) and return its output.

        Args:
            x_t: Noisy input
            t: Time step
            s: Parameter s
            labels: Labels
        """
        return self.model(x_t, t, s, labels)


def _build_time_schedule(num_steps: int,
                         schedule: str = 'linear',
                         s_min: float = 0.002,
                         rho: float = 7.0,
                         device: torch.device = torch.device('cpu')):
    """
    Constructs time pairs (t_k, s_k) of length K, and step intervals delta_k = t_k - s_k.
    Convention uses K+1 monotonically decreasing anchor points tau_i (i=0..K), where t_k=tau_{k-1}, s_k=tau_k.
    - linear:    tau_i = 1 - i/K
    - cosine:    tau_i = cos^2(pi/2 * i/K)
    - rho(Karras): tau_i = ((1)^(1/rho) + (i/K) * ((s_min)^(1/rho) - (1)^(1/rho)))^rho
      Note: rho schedule does not reach 0 at the end, but s_min, for numerical stability (consistent with clamp).
    """
    num_steps = int(max(1, num_steps))
    idx = torch.arange(0, num_steps + 1, device=device, dtype=torch.float32)
    if schedule == 'linear':
        tau = 1.0 - idx / float(num_steps)
    elif schedule == 'rho':
        inv_rho = 1.0 / float(rho)
        tau = (1.0 ** inv_rho + (idx / float(num_steps)) * ((float(s_min) ** inv_rho) - (1.0 ** inv_rho)))
        tau = torch.clamp(tau, min=0.0)
        tau = tau ** float(rho)
    else:
        tau = 1.0 - idx / float(num_steps)

    t_list = tau[:-1]
    s_list = tau[1:]
    s_list = torch.clamp(s_list, min=float(s_min))
    delta = torch.clamp(t_list - s_list, min=0.0)
    return t_list, s_list, delta


def mean_consistency_sampling(model,
                              noise,
                              labels,
                              num_steps: int = 1,
                              schedule: str = 'linear',
                              eta: float = 0.0,
                              s_min: float = 0.002,
                              step_clip: float = None,
                              rho: float = 7.0,
                              heun: bool = False):
    """
    Generate samples starting from ``noise`` using ``model.sample``.

    With num_steps <= 1 the model is queried once at t = 1, s = 0 and its output
    is returned directly. Otherwise, for each (t_k, s_k, delta_k) pair from
    _build_time_schedule, the update is
        x <- x - (delta_k / (t_k / max(s_k, s_min))) * u_k
    where u_k = model.sample(x, t_k, s_k, labels). When ``heun`` is set (all but
    the last step), u_k is averaged with a second evaluation at the Euler-predicted
    point using the next interval's times. Each update's rows are rescaled to L2
    norm <= step_clip when step_clip > 0. If eta > 0, Gaussian noise scaled by eta
    is added after every step. Everything runs under torch.no_grad().

    NOTE(review): the one-step path uses s = 0, while the visible training loss
    clamps s to at least s_min — confirm this is intended.
    """
    n_steps = int(num_steps)
    dev = noise.device
    bsz = noise.shape[0]

    def _clip(step):
        # Per-row L2 rescaling to at most step_clip; no-op when clipping is disabled.
        if step_clip is None or not step_clip > 0:
            return step
        row_norm = step.norm(p=2, dim=1, keepdim=True).clamp(min=1e-12)
        return step * (float(step_clip) / row_norm).clamp(max=1.0)

    def _times(value):
        # Broadcast a scalar time to one entry per sample.
        return torch.full((bsz,), float(value), device=dev)

    if n_steps <= 1:
        with torch.no_grad():
            return model.sample(noise,
                                torch.ones(bsz, device=dev),
                                torch.zeros(bsz, device=dev),
                                labels)

    with torch.no_grad():
        t_seq, s_seq, delta_seq = _build_time_schedule(
            num_steps=n_steps, schedule=schedule, s_min=s_min, rho=rho, device=dev
        )
        x = noise
        last = n_steps - 1

        for k in range(n_steps):
            t_k, s_k, dt_k = t_seq[k], s_seq[k], delta_seq[k]

            ratio = t_k / torch.clamp(s_k, min=float(s_min))
            step_size = dt_k / torch.clamp(ratio, min=1e-12)

            u_k = model.sample(x, _times(t_k), _times(s_k), labels)
            x_pred = x - _clip(step_size * u_k)

            if heun and k < last:
                u_pred = model.sample(x_pred, _times(t_seq[k + 1]), _times(s_seq[k + 1]), labels)
                x = x - _clip(step_size * 0.5 * (u_k + u_pred))
            else:
                x = x_pred

            if eta and eta > 0.0:
                x = x + float(eta) * torch.randn_like(x)

        return x
