import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import math

from diffusion_DiSA import create_diffusion

import torch.fft

def pivotal_token_selection(tokens: torch.Tensor) -> torch.Tensor:
    """Score tokens by frequency-weighted spectral amplitude (FreqTS-style).

    Each token's feature vector is transformed with an orthonormal real FFT
    along the feature axis. The DC bin is dropped and the remaining amplitudes
    are summed with weights that rise linearly from 1 to 2, so higher-frequency
    bins contribute more to the score.

    Args:
        tokens: Token features of shape [num_tokens, dim].

    Returns:
        Importance score for every token, shape [num_tokens], in the real
        dtype produced by the FFT of ``tokens``.
    """
    # Real-input FFT over the feature axis: dim // 2 + 1 complex bins per token.
    spectrum = torch.fft.rfft(tokens, dim=1, norm='ortho')

    # Drop bin 0 (DC); dim // 2 amplitude bins remain.
    amplitude_spectrum = torch.abs(spectrum[:, 1:])  # [num_tokens, dim // 2]

    # Build the weights in the amplitude's own dtype/device: einsum does not
    # promote mixed dtypes, so float32 weights would break non-float32 input.
    freq_weights = torch.linspace(
        1, 2, amplitude_spectrum.size(1),
        device=amplitude_spectrum.device,
        dtype=amplitude_spectrum.dtype,
    )  # [dim // 2]

    # Weighted sum over frequency bins for every token.
    return torch.einsum('ij,j->i', amplitude_spectrum, freq_weights)  # [num_tokens]

def pivotal_token_selection_(tokens: torch.Tensor) -> torch.Tensor:
    """Score tokens by their distance from the mean token.

    The mean over all tokens is treated as the shared low-frequency part;
    what remains after subtracting it is the per-token high-frequency part,
    and its L2 norm is the score.

    Args:
        tokens: Token features of shape [num_tokens, dim].

    Returns:
        Importance score for every token, shape [num_tokens].
    """
    # Residual of each token w.r.t. the mean across the token axis.
    centered = tokens - tokens.mean(dim=0, keepdim=True)  # [num_tokens, dim]

    # L2 norm over the feature axis.
    return centered.norm(p=2, dim=1)  # [num_tokens]


def random_token_selection(tokens: torch.Tensor) -> torch.Tensor:
    """Return random importance scores (baseline for comparison experiments).

    Only the number of tokens and the device of ``tokens`` are used; the
    feature values themselves are ignored.

    Args:
        tokens: Token features of shape [num_tokens, dim].

    Returns:
        Scores drawn with ``torch.rand``, shape [num_tokens], on the same
        device as ``tokens``.
    """
    return torch.rand(tokens.shape[0], device=tokens.device)



class DiffLoss(nn.Module):
    """Diffusion loss head with step-annealed and pivot-token sampling.

    ``forward_`` computes the training loss. ``forward`` and ``sample`` are
    equivalent sampling entry points: before ``pivot_step_threshold`` every
    token is sampled with the sampler picked by the annealing schedule; from
    that step on, the highest-scoring tokens are sampled with the dedicated
    ``pivot_diffusion`` sampler and the rest with the scheduled one.
    """

    def __init__(self, target_channels, z_channels, depth, width, grad_checkpointing=False,
                 diff_upper_steps=50, diff_lower_steps=5, diff_annealing_strategy="linear", diff_sampler="default",
                 pivot_step_threshold=15, pivot_diffusion_steps=50, token_selection_strategy="pivotal",
                 pivot_token_percentage=0.1):
        """Build the denoising MLP and the diffusion samplers.

        Args:
            target_channels: Channel count of the token latents being modelled.
            z_channels: Channel count of the conditioning vectors ``z``.
            depth: Number of residual blocks in the denoising MLP.
            width: Hidden width of the denoising MLP.
            grad_checkpointing: Forwarded to the denoising MLP.
            diff_upper_steps: Largest ``timestep_respacing`` value for which a
                generation sampler is created (inclusive).
            diff_lower_steps: Smallest ``timestep_respacing`` value for which a
                generation sampler is created.
            diff_annealing_strategy: Stored on the instance; not read anywhere
                in this class.
            diff_sampler: Stored as ``self.sampler``; not read anywhere in
                this class.
            pivot_step_threshold: First AR step at which the pivot strategy
                is used.
            pivot_diffusion_steps: ``timestep_respacing`` value of the sampler
                used for pivot tokens.
            token_selection_strategy: ``"pivotal"`` or ``"random"``.
            pivot_token_percentage: Fraction of tokens treated as pivot tokens.
        """
        super().__init__()
        self.in_channels = target_channels
        self.net = SimpleMLPAdaLN(
            in_channels=target_channels,
            model_channels=width,
            out_channels=target_channels * 2,  # for vlb loss
            z_channels=z_channels,
            num_res_blocks=depth,
            grad_checkpointing=grad_checkpointing
        )

        self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine")
        self.sampler = diff_sampler
        print(f"diff_lower_steps: {diff_lower_steps}, diff_upper_steps: {diff_upper_steps}")
        # One sampler per respacing value in [diff_lower_steps, diff_upper_steps].
        self.gen_diffusion = [
            create_diffusion(timestep_respacing=str(step), noise_schedule="cosine")
            for step in range(diff_lower_steps, diff_upper_steps + 1)
        ]
        self.diff_annealing_strategy = diff_annealing_strategy
        # Settings for the mixed (pivot / non-pivot) strategy.
        self.pivot_step_threshold = pivot_step_threshold
        self.pivot_diffusion_steps = pivot_diffusion_steps
        self.token_selection_strategy = token_selection_strategy
        self.pivot_token_percentage = pivot_token_percentage
        # Dedicated sampler for the pivot tokens.
        self.pivot_diffusion = create_diffusion(timestep_respacing=str(pivot_diffusion_steps), noise_schedule="cosine")

    def forward_(self, target, z, mask=None):
        """Compute the diffusion training loss.

        Args:
            target: Ground-truth token latents; one random timestep is drawn
                per row.
            z: Conditioning passed to the network as ``c``.
            mask: Optional weights; when given, the loss is the mask-weighted
                average ``(loss * mask).sum() / mask.sum()``.

        Returns:
            Scalar loss tensor.
        """
        t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
        model_kwargs = dict(c=z)
        loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs)
        loss = loss_dict["loss"]
        if mask is not None:
            loss = (loss * mask).sum() / mask.sum()
        return loss.mean()

    def forward(self, z, temperature=1.0, cfg=1.0, step=0, ar_num_iter=64, bsz=None):
        """Alias of :meth:`sample` so that calling the module samples tokens."""
        return self.sample(z, temperature, cfg, step, ar_num_iter, bsz)

    def sample(self, z, temperature=1.0, cfg=1.0, step=0, ar_num_iter=64, bsz=None):
        """Sample token latents conditioned on ``z``.

        Args:
            z: Conditioning vectors. When ``cfg != 1.0`` the first half of the
                rows is the conditional batch and the second half the
                unconditional one.
            temperature: Forwarded to ``p_sample_loop``.
            cfg: Classifier-free guidance scale; ``1.0`` disables guidance.
            step: Current autoregressive step.
            ar_num_iter: Total number of autoregressive steps.
            bsz: Optional number of samples the tokens belong to; used by the
                pivot strategy to pick pivot tokens per sample.

        Returns:
            Tensor with one row of ``in_channels`` values per row of ``z``.
        """
        if step >= self.pivot_step_threshold:
            return self._sample_with_pivot_strategy(z, temperature, cfg, step, ar_num_iter, bsz)
        return self._sample_original_strategy(z, temperature, cfg, step, ar_num_iter)

    def _schedule_index(self, step, ar_num_iter):
        """Map an AR step to an index into ``self.gen_diffusion``.

        Step 0 maps to the last entry and step ``ar_num_iter - 1`` to the
        first. A single-iteration run uses the last entry instead of dividing
        by zero, and out-of-range steps are clamped instead of wrapping around
        through negative indexing.
        """
        last = len(self.gen_diffusion) - 1
        upper_step = ar_num_iter - 1
        if upper_step <= 0:
            return last
        index = int((upper_step - step) / upper_step * last)
        return min(max(index, 0), last)

    def _score_tokens(self, tokens):
        """Return importance scores for ``tokens`` using the configured strategy."""
        if self.token_selection_strategy == "pivotal":
            return pivotal_token_selection(tokens)
        if self.token_selection_strategy == "random":
            return random_token_selection(tokens)
        raise ValueError(f"Unknown token selection strategy: {self.token_selection_strategy}")

    def _run_sampler(self, diffusion, z, temperature, cfg):
        """Draw start noise on ``z.device`` and run ``diffusion.p_sample_loop``."""
        if cfg != 1.0:
            # Conditional and unconditional halves must start from the same noise.
            noise = torch.randn(z.shape[0] // 2, self.in_channels, device=z.device)
            noise = torch.cat([noise, noise], dim=0)
            model_kwargs = dict(c=z, cfg_scale=cfg)
            sample_fn = self.net.forward_with_cfg
        else:
            noise = torch.randn(z.shape[0], self.in_channels, device=z.device)
            model_kwargs = dict(c=z)
            sample_fn = self.net.forward
        return diffusion.p_sample_loop(
            sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs,
            progress=False, temperature=temperature
        )

    @staticmethod
    def _select_rows(z_cond, z_uncond, mask):
        """Pick the masked rows, keeping the [cond; uncond] layout when CFG is on."""
        if z_uncond is None:
            return z_cond[mask]
        return torch.cat([z_cond[mask], z_uncond[mask]], dim=0)

    def _sample_original_strategy(self, z, temperature=1.0, cfg=1.0, step=0, ar_num_iter=64):
        """Sample every token with the sampler chosen by the annealing schedule."""
        diffusion = self.gen_diffusion[self._schedule_index(step, ar_num_iter)]
        return self._run_sampler(diffusion, z, temperature, cfg)

    def _sample_with_pivot_strategy(self, z, temperature=1.0, cfg=1.0, step=0, ar_num_iter=64, bsz=None):
        """Sample pivot tokens with ``pivot_diffusion`` and the rest as scheduled.

        Tokens are scored on the conditional half of ``z`` (or on all of ``z``
        without CFG). With ``bsz > 1`` the top fraction is selected inside each
        sample's contiguous group of tokens; otherwise over all tokens at once.
        """
        device = z.device
        use_cfg = cfg != 1.0

        if use_cfg:
            # First half of z is conditional, second half unconditional.
            half = z.shape[0] // 2
            z_cond, z_uncond = z[:half], z[half:]
        else:
            z_cond, z_uncond = z, None

        importance_scores = self._score_tokens(z_cond)
        num_tokens = z_cond.shape[0]

        if bsz is not None and bsz > 1:
            token_nums_per_data = num_tokens // bsz
            scores_per_data = importance_scores.view(bsz, token_nums_per_data)
            # At least one pivot token, but never more than are available.
            num_pivot_per_data = min(
                token_nums_per_data,
                max(1, int(token_nums_per_data * self.pivot_token_percentage)),
            )
            _, local_indices = torch.topk(scores_per_data, num_pivot_per_data, dim=1)
            # Shift per-sample indices into the flat token index space.
            offsets = torch.arange(bsz, device=local_indices.device).unsqueeze(1) * token_nums_per_data
            pivot_indices = (local_indices + offsets).reshape(-1)
        else:
            num_pivot_tokens = min(num_tokens, max(1, int(num_tokens * self.pivot_token_percentage)))
            _, pivot_indices = torch.topk(importance_scores, num_pivot_tokens, dim=0)

        pivot_mask = torch.zeros(num_tokens, dtype=torch.bool, device=device)
        pivot_mask[pivot_indices] = True
        non_pivot_mask = ~pivot_mask

        # top-k indices are distinct, so the counts are known on the host
        # without reducing the masks on the device.
        num_pivot = pivot_indices.numel()
        num_non_pivot = num_tokens - num_pivot

        # Pivot tokens first, then the rest: keeps the random draw order stable.
        sampled_pivot = None
        if num_pivot > 0:
            sampled_pivot = self._run_sampler(
                self.pivot_diffusion, self._select_rows(z_cond, z_uncond, pivot_mask), temperature, cfg
            )

        sampled_non_pivot = None
        if num_non_pivot > 0:
            diffusion = self.gen_diffusion[self._schedule_index(step, ar_num_iter)]
            sampled_non_pivot = self._run_sampler(
                diffusion, self._select_rows(z_cond, z_uncond, non_pivot_mask), temperature, cfg
            )

        # Scatter both groups back to their original token positions.
        parts = [
            (mask, sampled)
            for mask, sampled in ((pivot_mask, sampled_pivot), (non_pivot_mask, sampled_non_pivot))
            if sampled is not None
        ]
        # Allocate in the sampler's dtype: indexed assignment needs matching dtypes.
        out_dtype = parts[0][1].dtype if parts else None
        num_halves = 2 if use_cfg else 1
        buffers = [
            torch.zeros(num_tokens, self.in_channels, device=device, dtype=out_dtype)
            for _ in range(num_halves)
        ]
        for mask, sampled in parts:
            # With CFG the sampler output is [cond; uncond]; split it back.
            for buffer, chunk in zip(buffers, sampled.chunk(num_halves, dim=0)):
                buffer[mask] = chunk

        return torch.cat(buffers, dim=0) if use_cfg else buffers[0]
    



def modulate(x, shift, scale):
    """Apply an affine modulation: scale ``x`` by ``1 + scale`` and add ``shift``."""
    scaled = x * (1 + scale)
    return scaled + shift


class TimestepEmbedder(nn.Module):
    """Turn scalar timesteps into vectors: sinusoidal features, then an MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        """
        :param hidden_size: width of the MLP's hidden and output layers.
        :param frequency_embedding_size: size of the sinusoidal feature vector
                                         fed to the MLP.
        """
        super().__init__()
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings, cosines first,
                 then sines, zero-padded by one column when dim is odd.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        # Geometric frequency ladder from 1 down towards 1 / max_period.
        exponents = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        freqs = torch.exp(exponents).to(device=t.device)
        angles = t[:, None].float() * freqs[None]
        pieces = [torch.cos(angles), torch.sin(angles)]
        if dim % 2:
            # Odd dim: pad with a single zero column.
            pieces.append(torch.zeros_like(angles[:, :1]))
        return torch.cat(pieces, dim=-1)

    def forward(self, t):
        """Embed timesteps ``t`` and project them through the MLP."""
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))


class ResBlock(nn.Module):
    """
    A gated residual MLP block with adaptive layer-norm conditioning.
    The channel count is the same on input and output.
    :param channels: the number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels

        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )

        # Produces shift, scale and gate (channels each) from the condition.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True),
        )

    def forward(self, x, y):
        """
        :param x: block input.
        :param y: conditioning vector that drives the modulation.
        :return: ``x`` plus the gated output of the modulated MLP branch.
        """
        shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
        branch = self.mlp(modulate(self.in_ln(x), shift, scale))
        return x + gate * branch


class FinalLayer(nn.Module):
    """
    Output head adopted from DiT: a parameter-free LayerNorm, a conditioning
    driven shift/scale, and a linear projection to ``out_channels``.
    """

    def __init__(self, model_channels, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        # Produces shift and scale (model_channels each) from the condition.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(model_channels, 2 * model_channels, bias=True),
        )

    def forward(self, x, c):
        """
        :param x: hidden features.
        :param c: conditioning vector that drives the modulation.
        :return: the projected, modulated features.
        """
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        normed = modulate(self.norm_final(x), shift, scale)
        return self.linear(normed)


class SimpleMLPAdaLN(nn.Module):
    """
    The MLP denoiser for Diffusion Loss.
    :param in_channels: channels in the input Tensor.
    :param model_channels: hidden width used by every layer of the model.
    :param out_channels: channels in the output Tensor.
    :param z_channels: channels in the condition.
    :param num_res_blocks: number of residual blocks.
    :param grad_checkpointing: run the residual blocks under activation
                               checkpointing when True.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        z_channels,
        num_res_blocks,
        grad_checkpointing=False
    ):
        super().__init__()

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.grad_checkpointing = grad_checkpointing

        # Embeddings for the timestep and the condition, plus the input lift.
        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(z_channels, model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)

        self.res_blocks = nn.ModuleList(
            [ResBlock(model_channels) for _ in range(num_res_blocks)]
        )
        self.final_layer = FinalLayer(model_channels, out_channels)

        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init all linear layers, then apply the special-case inits."""
        def _init_linear(module):
            if not isinstance(module, nn.Linear):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_init_linear)

        # Timestep-embedding MLP: small normal weights.
        for index in (0, 2):
            nn.init.normal_(self.time_embed.mlp[index].weight, std=0.02)

        # Zero the adaLN modulation projections of every residual block, the
        # final layer's modulation projection, and the output projection.
        zeroed = [block.adaLN_modulation[-1] for block in self.res_blocks]
        zeroed.append(self.final_layer.adaLN_modulation[-1])
        zeroed.append(self.final_layer.linear)
        for layer in zeroed:
            nn.init.constant_(layer.weight, 0)
            nn.init.constant_(layer.bias, 0)

    def forward(self, x, t, c):
        """
        Apply the model to an input batch.
        :param x: an [N x C] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :param c: conditioning from AR transformer.
        :return: the output of the final layer for every input row.
        """
        hidden = self.input_proj(x)
        # Timestep and condition embeddings are summed into one modulation signal.
        cond = self.time_embed(t) + self.cond_embed(c)

        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.res_blocks:
                hidden = checkpoint(block, hidden, cond)
        else:
            for block in self.res_blocks:
                hidden = block(hidden, cond)

        return self.final_layer(hidden, cond)

    def forward_with_cfg(self, x, t, c, cfg_scale):
        """
        Run the model with classifier-free guidance.
        The first half of ``x`` is duplicated so both halves share the same
        input; the first ``in_channels`` output columns of the two halves are
        combined with ``cfg_scale`` and written back to both halves, while the
        remaining columns are passed through unchanged.
        """
        first_half = x[: len(x) // 2]
        combined = torch.cat([first_half, first_half], dim=0)
        model_out = self.forward(combined, t, c)

        eps = model_out[:, :self.in_channels]
        rest = model_out[:, self.in_channels:]

        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        guided = uncond_eps + cfg_scale * (cond_eps - uncond_eps)

        return torch.cat([torch.cat([guided, guided], dim=0), rest], dim=1)
