from typing import Callable, Union
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torch import Tensor

# Type alias for a module specifier: either a string or a callable that
# returns an ``nn.Module``. It is not referenced in the visible part of this
# file.
ModuleType = Union[str, Callable[..., nn.Module]]


class SiLU(nn.Module):
    """Sigmoid-weighted linear unit, computed elementwise as ``x * sigmoid(x)``."""

    def forward(self, x):
        """Return ``x`` multiplied by its own sigmoid."""
        gate = torch.sigmoid(x)
        return x * gate


class ConsistencyLoss:
    """Per-element squared error between a model's prediction and clean data.

    Each call perturbs ``data`` with Gaussian noise whose variance is drawn
    uniformly from ``[0, 1)`` per sample, asks ``model`` to predict from the
    perturbed input, and compares that prediction with the clean ``data``.
    """

    def __init__(self, sigma_data=0.5, hid_dim=100, opts=None):
        # These settings are only stored; __call__ does not read them.
        self.sigma_data, self.hid_dim, self.opts = sigma_data, hid_dim, opts

    def __call__(self, model, model_ema, data, labels, time_steps):
        """Return the unreduced MSE between ``model``'s output and ``data``.

        Args:
            model: Callable invoked as ``model(noisy, t, labels)``.
            model_ema: Accepted for interface compatibility; not used here.
            data: Clean samples. The noise scale is broadcast through a
                trailing singleton axis, so ``(batch, features)`` is assumed.
            labels: Passed to ``model`` untouched.
            time_steps: Accepted for interface compatibility; not used here.

        Returns:
            Tensor with the same shape as the model output (``reduction='none'``).
        """
        n_samples = data.shape[0]

        # Draw order matters for reproducibility: noise levels first, then
        # the Gaussian perturbation.
        noise_level = torch.rand(n_samples, device=data.device)
        eps = torch.randn_like(data)

        # Standard deviation is sqrt(t), i.e. the variance equals t.
        sigma = torch.sqrt(noise_level.unsqueeze(-1))
        noisy = data + sigma * eps

        denoised = model(noisy, noise_level, labels)
        return F.mse_loss(denoised, data, reduction='none')


class PositionalEmbedding(torch.nn.Module):
    """Sinusoidal embedding of a 1-D tensor of positions or time values.

    The frequencies are ``(1 / max_positions) ** (k / denom)`` for
    ``k = 0 .. num_channels // 2 - 1``, where ``denom`` is
    ``num_channels // 2`` (or one less when ``endpoint`` is true, so that the
    last frequency is exactly ``1 / max_positions``).

    Args:
        num_channels: Requested embedding width. The output has
            ``2 * (num_channels // 2)`` columns.
        max_positions: Base that controls the lowest frequency.
        endpoint: Whether the frequency grid includes its end point.
    """

    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        """Embed ``x`` (shape ``(N,)``) as ``[cos | sin]`` of shape ``(N, 2 * half)``."""
        half = self.num_channels // 2
        exponents = torch.arange(start=0, end=half,
                                 dtype=torch.float32, device=x.device)
        # With endpoint=True and a single frequency the divisor would be 0,
        # turning the whole embedding into NaN (0 / 0). Clamp it to 1 so the
        # lone frequency is simply 1.
        denom = max(half - (1 if self.endpoint else 0), 1)
        exponents = exponents / denom
        freqs = (1 / self.max_positions) ** exponents
        # torch.outer replaces the deprecated Tensor.ger alias.
        angles = torch.outer(x, freqs.to(x.dtype))
        return torch.cat([angles.cos(), angles.sin()], dim=1)


class MLPConsistency(nn.Module):
    """MLP that maps a noisy input to an output of the same width.

    The input is projected to ``dim_t`` features, summed with a learned time
    embedding and a linear label embedding, and passed through a four-layer
    SiLU MLP that projects back to ``d_in``.

    Args:
        d_in: Width of the input (and output) feature vector.
        dim_t: Width of the internal conditioning/hidden representation.
    """

    def __init__(self, d_in, dim_t=512):
        super().__init__()
        self.dim_t = dim_t
        wide = dim_t * 2

        # Submodule creation order determines parameter order and the RNG
        # draws used for initialisation, so keep it stable.
        self.proj = nn.Linear(d_in, dim_t)

        trunk = [
            nn.Linear(dim_t, wide),
            nn.SiLU(),
            nn.Linear(wide, wide),
            nn.SiLU(),
            nn.Linear(wide, dim_t),
            nn.SiLU(),
            nn.Linear(dim_t, d_in),
        ]
        self.mlp = nn.Sequential(*trunk)

        self.map_time = PositionalEmbedding(num_channels=dim_t)
        # A single scalar per sample is expanded to dim_t features.
        self.map_label = nn.Linear(1, dim_t, bias=False)

        self.time_embed = nn.Sequential(
            nn.Linear(dim_t, dim_t),
            nn.SiLU(),
            nn.Linear(dim_t, dim_t),
        )

    def forward(self, x, time, class_labels):
        """Predict from ``x`` conditioned on ``time`` and ``class_labels``.

        Args:
            x: Input whose last dimension is ``d_in``.
            time: 1-D tensor of time values, one per sample.
            class_labels: Tensor whose last dimension is 1 (it feeds a
                ``Linear(1, dim_t)``); presumably floating point.

        Returns:
            Tensor whose last dimension is ``d_in``.
        """
        time_features = self.time_embed(self.map_time(time))
        label_features = self.map_label(class_labels)
        hidden = self.proj(x) + time_features + label_features
        return self.mlp(hidden)


class ConsistencyModel(nn.Module):
    """Training wrapper that pairs a denoiser with a frozen EMA copy.

    Args:
        denoise_fn: The trainable network, called as
            ``denoise_fn(x, t, labels)``.
        hid_dim: Forwarded to ``ConsistencyLoss``.
        sigma_data: Forwarded to ``ConsistencyLoss``.
        opts: Forwarded to ``ConsistencyLoss``.
    """

    def __init__(self, denoise_fn, hid_dim, sigma_data=0.5, opts=None):
        super().__init__()

        self.model = denoise_fn
        self.model_ema = self._create_ema_model(denoise_fn)
        self.ema_decay = 0.999
        self.loss_fn = ConsistencyLoss(sigma_data, hid_dim, opts)

        # The EMA copy is only ever updated by update_ema(), never by autograd.
        for param in self.model_ema.parameters():
            param.requires_grad_(False)

    def _create_ema_model(self, denoise_fn):
        """Return an independent copy of ``denoise_fn`` with identical weights.

        The copy is made with ``copy.deepcopy`` rather than by re-invoking the
        constructor. Rebuilding from ``denoise_fn.mlp[0].in_features`` used the
        hidden width (``dim_t``) as the input width, because the real input
        layer of ``MLPConsistency`` is ``proj``; the subsequent weight copy
        then failed with a shape mismatch whenever ``d_in != dim_t``.
        """
        import copy  # local import keeps this block self-contained

        return copy.deepcopy(denoise_fn)

    def forward(self, x, y):
        """Return the scalar training loss for inputs ``x`` with labels ``y``."""
        time_steps = torch.linspace(0, 1, 100, device=x.device)

        loss = self.loss_fn(self.model, self.model_ema, x, y, time_steps)
        # Average over the last axis first, then over what remains.
        return loss.mean(-1).mean()

    def update_ema(self):
        """Blend the current weights into the EMA copy in place.

        Each EMA parameter becomes
        ``ema_decay * ema + (1 - ema_decay) * param``.
        """
        with torch.no_grad():
            for ema_param, param in zip(self.model_ema.parameters(), self.model.parameters()):
                ema_param.data.mul_(self.ema_decay).add_(param.data, alpha=1 - self.ema_decay)


def _build_time_schedule(num_steps: int,
                         schedule: str = 'linear',
                         s_min: float = 0.002,
                         rho: float = 7.0,
                         device: torch.device = torch.device('cpu')):
    """Build the per-step start times, end times and step sizes for sampling.

    A grid of ``num_steps + 1`` decreasing time values starting at 1 is
    generated, then split into consecutive ``(start, end)`` pairs.

    Args:
        num_steps: Number of steps; values below 1 are treated as 1.
        schedule: ``'rho'`` interpolates in ``t ** (1 / rho)`` space from 1
            down to ``s_min``; any other value (including ``'linear'``) gives
            an evenly spaced grid from 1 down to 0.
        s_min: Lower bound applied to every end time.
        rho: Exponent of the ``'rho'`` schedule.
        device: Device on which the tensors are created.

    Returns:
        ``(t_list, s_list, delta)``, each of length ``num_steps``, where
        ``delta = max(t_list - s_list, 0)``.
    """
    steps = int(max(1, num_steps))
    position = torch.arange(0, steps + 1, device=device, dtype=torch.float32)
    fraction = position / float(steps)

    if schedule == 'rho':
        exponent = 1.0 / float(rho)
        start_root = 1.0 ** exponent
        stop_root = float(s_min) ** exponent
        grid = start_root + fraction * (stop_root - start_root)
        # Guard against tiny negative values before raising to the power rho.
        grid = torch.clamp(grid, min=0.0) ** float(rho)
    else:
        # 'linear' and every unrecognised schedule name share this grid.
        grid = 1.0 - fraction

    starts = grid[:-1]
    ends = torch.clamp(grid[1:], min=float(s_min))
    step_sizes = torch.clamp(starts - ends, min=0.0)
    return starts, ends, step_sizes


def multistep_consistency_sampling(model,
                                   noise,
                                   label,
                                   num_steps: int = 1,
                                   schedule: str = 'linear',
                                   eta: float = 0.0,
                                   s_min: float = 0.002,
                                   single_use_s_min: bool = False,
                                   step_clip: float = None,
                                   rho: float = 7.0,
                                   heun: bool = False):
    """Generate samples from ``noise`` by repeatedly applying ``model``.

    With ``num_steps <= 1`` the model is evaluated at ``t = 1`` and its
    output is fed through the model once more at ``t = 0`` (or ``t = s_min``
    when ``single_use_s_min`` is true).

    Otherwise the model output is used as an update direction: at every step
    ``x <- x - dt * model(x, t, label)``, optionally refined with a Heun
    (trapezoidal) correction on all but the last step.

    Args:
        model: Callable invoked as ``model(x, t, label)`` with ``t`` of shape
            ``(batch,)``.
        noise: Starting tensor; its first axis is the batch.
        label: Passed to ``model`` untouched.
        num_steps: Number of sampling steps.
        schedule: Time-grid name forwarded to ``_build_time_schedule``.
        eta: When positive, Gaussian noise scaled by ``eta`` is added after
            every step.
        s_min: Smallest time value used by the schedule.
        single_use_s_min: Selects the final time of the single-step path.
        step_clip: When positive, each update is rescaled so that its L2 norm
            along axis 1 does not exceed this value.
        rho: Exponent forwarded to ``_build_time_schedule``.
        heun: Enables the second-order correction.

    Returns:
        The generated tensor, computed without gradient tracking.
    """
    device = noise.device
    n_samples = noise.shape[0]
    n_steps = int(num_steps)

    def at_time(value):
        # One time value replicated for every sample in the batch.
        return torch.full((n_samples,), float(value), device=device)

    def clipped(step):
        # Per-sample rescaling so that ||step||_2 <= step_clip; a no-op when
        # clipping is disabled.
        if step_clip is None or not (step_clip > 0):
            return step
        magnitude = step.norm(p=2, dim=1, keepdim=True).clamp(min=1e-12)
        factor = (float(step_clip) / magnitude).clamp(max=1.0)
        return step * factor

    if n_steps <= 1:
        t_start = at_time(1.0)
        t_final = at_time(s_min if bool(single_use_s_min) else 0.0)
        with torch.no_grad():
            first_pass = model(noise, t_start, label)
            return model(first_pass, t_final, label)

    with torch.no_grad():
        t_seq, _, dt_seq = _build_time_schedule(
            num_steps=n_steps, schedule=schedule, s_min=s_min, rho=rho, device=device
        )
        inject_noise = bool(eta) and (eta > 0.0)
        last_index = n_steps - 1
        x = noise

        for i in range(n_steps):
            dt = dt_seq[i]
            direction = model(x, at_time(t_seq[i]), label)
            x_euler = x - clipped(dt * direction)

            if heun and i < last_index:
                # Re-evaluate at the Euler prediction and average both slopes.
                direction_next = model(x_euler, at_time(t_seq[i + 1]), label)
                x = x - clipped(dt * 0.5 * (direction + direction_next))
            else:
                x = x_euler

            if inject_noise:
                x = x + float(eta) * torch.randn_like(x)

        return x


def consistency_sampling(model,
                        noise,
                        label,
                        num_steps: int = 1,
                        schedule: str = 'linear',
                        eta: float = 0.0,
                        s_min: float = 0.002,
                        single_use_s_min: bool = False,
                        step_clip: float = None,
                        rho: float = 7.0,
                        heun: bool = False):
    """Alias for ``multistep_consistency_sampling``.

    Every argument is forwarded unchanged and the callee's result is returned
    as-is.
    """
    sampler_options = dict(
        num_steps=num_steps,
        schedule=schedule,
        eta=eta,
        s_min=s_min,
        single_use_s_min=single_use_s_min,
        step_clip=step_clip,
        rho=rho,
        heun=heun,
    )
    return multistep_consistency_sampling(
        model=model, noise=noise, label=label, **sampler_options
    )


class ConsistencyLossCluster:
    """Unreduced denoising MSE evaluated at a random per-sample noise level.

    The clean ``data`` is corrupted with zero-mean Gaussian noise of variance
    ``t`` (``t`` uniform in ``[0, 1)``), and the model's prediction from the
    corrupted input is compared against the clean ``data``.
    """

    def __init__(self, sigma_data=0.5, hid_dim=100, opts=None):
        # Kept as attributes only; the loss computation does not read them.
        self.sigma_data = sigma_data
        self.hid_dim = hid_dim
        self.opts = opts

    def __call__(self, model, model_ema, data, labels, time_steps):
        """Compute the element-wise squared error for one batch.

        Args:
            model: Callable invoked as ``model(x_t, t, labels)``.
            model_ema: Unused; kept so callers can pass the EMA network.
            data: Clean batch. The noise scale gains one trailing axis, so a
                ``(batch, features)`` layout is assumed.
            labels: Forwarded to ``model``.
            time_steps: Unused; kept for signature compatibility.

        Returns:
            The squared error with ``reduction='none'``.
        """
        # Sample t before the Gaussian noise so the RNG stream is stable.
        t = torch.rand(data.shape[0], device=data.device)
        gaussian = torch.randn_like(data)

        scale = t.unsqueeze(-1).sqrt()
        corrupted = data + scale * gaussian

        prediction = model(corrupted, t, labels)
        squared_error = F.mse_loss(prediction, data, reduction='none')
        return squared_error


class ConsistencyCluster(nn.Module):
    """Self-contained trainer: an ``MLPConsistency`` plus a frozen EMA copy.

    Args:
        d_in: Input feature width; also passed to the loss as ``hid_dim``.
        dim_t: Hidden/conditioning width of the MLP.
        sigma_data: Forwarded to ``ConsistencyLossCluster``.
        opts: Forwarded to ``ConsistencyLossCluster``.
        device: Stored on the instance; the constructor does not move any
            module to it.
    """

    def __init__(self, d_in, dim_t=512, sigma_data=0.5, opts=None, device=None):
        super().__init__()

        # Trainable network first, then its EMA copy (registration order).
        self.model = MLPConsistency(d_in, dim_t)
        self.model_ema = self._create_ema_model(self.model)

        self.ema_decay = 0.999
        self.loss_fn = ConsistencyLossCluster(sigma_data, d_in, opts)
        self.device = device

        # The EMA weights are updated manually, never through autograd.
        self.model_ema.requires_grad_(False)

    def _create_ema_model(self, denoise_fn):
        """Build a fresh ``MLPConsistency`` and copy ``denoise_fn``'s weights into it."""
        clone = MLPConsistency(
            d_in=denoise_fn.proj.in_features,
            dim_t=denoise_fn.dim_t,
        )
        weight_pairs = zip(clone.parameters(), denoise_fn.parameters())
        for target, source in weight_pairs:
            target.data.copy_(source.data)
        return clone

    def forward(self, x, y):
        """Return the scalar training loss for inputs ``x`` and labels ``y``."""
        grid = torch.linspace(0, 1, 100, device=x.device)
        per_element = self.loss_fn(self.model, self.model_ema, x, y, grid)
        # Mean over the last axis, then over the remaining axes.
        return per_element.mean(-1).mean()

    def update_ema(self):
        """Move every EMA weight towards the live weight in place.

        ``ema <- ema_decay * ema + (1 - ema_decay) * param``.
        """
        decay = self.ema_decay
        with torch.no_grad():
            weight_pairs = zip(self.model_ema.parameters(), self.model.parameters())
            for shadow, live in weight_pairs:
                shadow.data.mul_(decay).add_(live.data, alpha=1 - decay)