from torch import nn
import torch
from torch.nn import functional as F
from functools import partial
from timm.layers import SwiGLU
from timm.models.vision_transformer import DropPath
from algos.components.ActionHead import MPScalseFlowhead

class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Args:
        dim: Length of the optional gain vector (the size of the last axis).
        eps: Constant added to the mean square before the reciprocal
            square root is taken.
        weight: When true, a learnable gain initialised to ones multiplies
            the normalised output; otherwise ``self.weight`` is ``None`` and
            no gain is applied.
    """

    def __init__(self, dim, eps: float = 1e-6, weight=False):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim)) if weight else None

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalise in float32, cast back to the input dtype, apply the gain."""
        normed = self._norm(x.float()).type_as(x)
        if self.weight is not None:
            normed = normed * self.weight
        return normed

def rotate_half(x):
    """Rotate interleaved feature pairs along the last axis.

    Every consecutive pair ``(a, b)`` taken from the last dimension becomes
    ``(-b, a)``; the result has the same shape as ``x``.
    """
    even = x[..., 0::2]
    odd = x[..., 1::2]
    rotated_pairs = torch.stack([odd.neg(), even], dim=-1)
    return rotated_pairs.reshape_as(x)

class ActionRotaryEmbeddingFast(nn.Module):
    """Rotary position embedding for ``[B, H, T, D]`` tensors.

    Position ``p`` along the ``T`` axis is mapped to the angle
    ``(p / pt_seq_len) * inv_freq`` and each interleaved feature pair is
    rotated by that angle. For a fixed (non-learnable) ``inv_freq`` the
    cos/sin tables are cached and rebuilt whenever a longer sequence, another
    dtype or another device is requested.

    Args:
        dim: Size of the last (feature) dimension; must be even because
            features are rotated in pairs.
        pt_seq_len: Divisor applied to the integer positions.
        base: Base of the geometric frequency progression.
        precision: dtype the frequency denominator is cast to.
        learnable: When true ``inv_freq`` is an ``nn.Parameter``; the tables
            are then recomputed on every call so gradients reach it and no
            stale autograd graph is kept alive.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim, pt_seq_len=16, base=10000,
                 precision=torch.float32, learnable=False):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"dim must be even for rotary embedding, got {dim}")
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)).to(precision)

        self.pt_seq_len = pt_seq_len
        self.learnable = learnable
        self.precision = precision

        if learnable:
            self.inv_freq = nn.Parameter(inv_freq)
        else:
            self.register_buffer('inv_freq', inv_freq)

        self.max_seq_len_cached = 0
        self.cos_cached = None
        self.sin_cached = None

    def _compute_cos_sin(self, seq_len, device, dtype):
        """Return ``(cos, sin)`` tables of shape ``[seq_len, dim]`` in ``dtype``."""
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) / self.pt_seq_len
        # Each frequency is repeated twice so that both members of an
        # interleaved pair share the same rotation angle.
        freqs = torch.einsum('i,j->ij', t, self.inv_freq).repeat_interleave(2, dim=-1)
        return freqs.cos().to(dtype), freqs.sin().to(dtype)

    def _update_cache(self, seq_len, device, dtype):
        """Rebuild the cached tables for ``seq_len`` positions."""
        self.cos_cached, self.sin_cached = self._compute_cos_sin(seq_len, device, dtype)
        self.max_seq_len_cached = seq_len

    def _invalidate_cache(self):
        """Drop the cached tables; they are rebuilt lazily by ``forward``."""
        self.max_seq_len_cached = 0
        self.cos_cached = None
        self.sin_cached = None

    def _cache_is_stale(self, seq_len, device, dtype):
        """True when the cache cannot serve ``seq_len`` positions on ``device``/``dtype``."""
        cached = self.cos_cached
        if cached is None or self.sin_cached is None:
            return True
        if self.max_seq_len_cached < seq_len:
            return True
        return cached.device != device or cached.dtype != dtype

    def forward(self, x):
        """Apply the rotation to ``x``.

        Args:
            x: 4-D tensor ``[B, H, T, D]``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 4-D (the message is in Chinese:
                "expected a 4D tensor [B, H, T, D] but got <n>D").
        """
        if x.dim() != 4:
            raise ValueError(f"期望 4D 张量 [B, H, T, D]，但得到 {x.dim()}D")

        T = x.shape[2]

        if self.learnable:
            # A cached table would hold on to an autograd graph that is freed
            # by the first backward pass and would not follow parameter
            # updates, so learnable frequencies are never cached.
            cos, sin = self._compute_cos_sin(T, x.device, x.dtype)
        else:
            if self._cache_is_stale(T, x.device, x.dtype):
                self._update_cache(T, x.device, x.dtype)
            cos = self.cos_cached[:T]
            sin = self.sin_cached[:T]

        # Broadcast the [T, D] tables over the batch and head axes.
        cos = cos.unsqueeze(0).unsqueeze(0)
        sin = sin.unsqueeze(0).unsqueeze(0)

        return x * cos + rotate_half(x) * sin

    def _apply(self, fn, *args, **kwargs):
        """Forward to ``nn.Module._apply`` after dropping the cached tables.

        The cache is derived data, so it is discarded rather than transformed
        (``fn`` may change device/dtype or even return uninitialised storage)
        and rebuilt on the next ``forward``. Extra arguments are passed
        through because recent PyTorch versions call ``_apply`` with an
        additional ``recurse`` argument.
        """
        self._invalidate_cache()
        return super()._apply(fn, *args, **kwargs)

class Attention(nn.Module):
    """Multi-head self-attention with per-scale rotary embedding and a KV cache.

    Two execution paths exist:

    * Full-sequence path (module in training mode, or a ``mask`` is given):
      the sequence is split into consecutive segments whose lengths are the
      entries of ``scale``; rotary positions restart at 0 in every segment and
      attention runs once over the whole sequence with ``mask``.
    * Incremental path (eval mode and ``mask is None``): the input is treated
      as a single new segment, its keys/values are appended to the cache held
      in ``self.k`` / ``self.v`` and the queries attend to everything cached.
      Call ``clear_cache()`` before starting a new sequence.

    Args:
        dim: Input and output feature size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused QKV projection has a bias.
        qk_norm: Accepted for signature compatibility; not used here.
        attn_drop: Attention dropout probability (applied in training mode).
        proj_drop: Dropout probability after the output projection.
        norm_layer: Accepted for signature compatibility; not used here.
        scale: Non-empty sequence of per-scale segment lengths.

    Raises:
        ValueError: If ``scale`` is ``None`` or empty.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
            scale=None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        if not scale:
            raise ValueError('scale must be a non-empty sequence of segment lengths')
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Kept as an attribute for compatibility; scaled_dot_product_attention
        # applies its own default scaling.
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = attn_drop
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.rope = ActionRotaryEmbeddingFast(
            dim=self.head_dim,
            pt_seq_len=max(scale),
        )
        self.scales = scale
        self.k, self.v = None, None

    def clear_cache(self):
        """Forget the keys/values accumulated by the incremental path."""
        self.k, self.v = None, None

    def _rope_per_scale(self, t: torch.Tensor) -> torch.Tensor:
        """Apply the rotary embedding to each scale segment of ``t`` separately.

        Segment boundaries are the running sums of ``self.scales`` along the
        token axis (dim 2); positions restart at 0 in every segment.
        """
        pieces = []
        begin = 0
        for length in self.scales:
            pieces.append(self.rope(t[:, :, begin:begin + length]))
            begin += length
        return torch.cat(pieces, dim=2)

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        """Run attention over ``x`` of shape ``[B, N, C]`` and return ``[B, N, C]``.

        Args:
            x: Input tokens.
            mask: Optional attention mask forwarded to
                ``F.scaled_dot_product_attention``. Supplying a mask selects
                the full-sequence path even in eval mode, so a masked
                evaluation pass neither ignores the mask nor touches the KV
                cache.
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        if self.training or mask is not None:
            q = self._rope_per_scale(q)
            k = self._rope_per_scale(k)
            x = F.scaled_dot_product_attention(
                q, k, v, attn_mask=mask,
                dropout_p=self.attn_drop if self.training else 0.,
            )
        else:
            # Incremental decoding: the whole input is one new scale segment.
            q = self.rope(q)
            k = self.rope(k)
            if self.k is None or self.v is None:
                self.k, self.v = k, v
            else:
                self.k = torch.cat([self.k, k], dim=2)
                self.v = torch.cat([self.v, v], dim=2)
            x = F.scaled_dot_product_attention(q, self.k, self.v)

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Block(nn.Module):
    """Transformer block with condition-driven (AdaLN-style) modulation.

    ``ada_lin`` maps the condition to six ``dim``-sized vectors: a gate, a
    scale and a shift for the attention branch and the same three for the MLP
    branch. Each branch computes
    ``x + drop_path(branch(norm(x) * (1 + scale) + shift) * gate)``.

    Args:
        dim: Token feature size.
        num_heads: Number of attention heads.
        mlp_ratio: The MLP hidden size is ``int(dim * mlp_ratio * 2 / 3)``.
        qkv_bias: Forwarded to ``Attention``.
        qk_norm: Forwarded to ``Attention``.
        proj_drop: Dropout for the attention projection and the MLP.
        attn_drop: Attention dropout probability.
        drop_path: Stochastic-depth rate; ``nn.Identity`` is used when 0.
        act_layer: Activation class forwarded to ``mlp_layer``.
        norm_layer: Forwarded to ``Attention``; the block itself uses ``RMSNorm``.
        mlp_layer: MLP class, constructed with ``in_features``,
            ``hidden_features``, ``act_layer`` and ``drop``.
        scale: Per-scale segment lengths forwarded to ``Attention``.
        use_checkpoint: Run the block under activation checkpointing.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            drop_path: float = 0.,
            act_layer: nn.Module = nn.GELU,
            norm_layer: nn.Module = nn.LayerNorm,
            mlp_layer: nn.Module = SwiGLU,
            scale=None,
            use_checkpoint=False
    ) -> None:
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
            scale=scale,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = RMSNorm(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio * 2 / 3.),
            act_layer=act_layer,
            drop=proj_drop
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        # Produces the six modulation vectors consumed by _forward.
        self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(dim, 6 * dim))

        self.dim = dim
        self.use_checkpoint = use_checkpoint

    def forward(self, x: torch.Tensor, condition, mask) -> torch.Tensor:
        """Run the block, optionally under activation checkpointing.

        Args:
            x: Token tensor; assumed ``[B, N, dim]`` — TODO confirm with callers.
            condition: Conditioning tensor fed to ``ada_lin``; its last
                dimension must be ``dim``.
            mask: Attention mask forwarded to ``Attention`` (may be ``None``).
        """
        if self.use_checkpoint:
            # Imported explicitly so the submodule is guaranteed to be loaded.
            from torch.utils.checkpoint import checkpoint

            # use_reentrant is passed explicitly: relying on the implicit
            # default is deprecated, and the non-reentrant variant is the one
            # PyTorch recommends.
            return checkpoint(self._forward, x, condition, mask, use_reentrant=False)
        return self._forward(x, condition, mask)

    def _forward(self, x: torch.Tensor, condition, mask) -> torch.Tensor:
        """Apply the modulated attention and MLP residual branches."""
        # [*, 6 * dim] -> six tensors of shape [*, 1, dim], broadcast over tokens.
        gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(condition).view(-1, 1, 6, self.dim).unbind(2)

        # The in-place ops act on freshly created intermediates, never on x.
        x = x + self.drop_path1(self.attn(self.norm1(x).mul(scale1.add(1)).add_(shift1), mask).mul_(gamma1))
        x = x + self.drop_path2(self.mlp(self.norm2(x).mul(scale2.add(1)).add_(shift2)).mul_(gamma2))
        return x

class FlowAR(nn.Module):
    """Multi-scale autoregressive model over action sequences with a flow head.

    An action sequence is represented at several temporal resolutions
    (``scale``). A conditioning token derived from the observation embedding
    is followed by re-sampled versions of the coarser scales; encoder and
    decoder blocks process this sequence under a block-causal mask, and
    ``flownet`` scores / generates the actions of every scale from the
    decoder output.

    Args:
        encoder_embed_dim, encoder_depth, encoder_num_heads: Encoder width,
            number of blocks and heads.
        decoder_embed_dim, decoder_depth, decoder_num_heads: Same for the decoder.
        mlp_ratio: Forwarded to every ``Block``.
        norm_layer: Forwarded to every ``Block``.
        action_dim: Feature size of one action token.
        attn_dropout, proj_dropout: Dropout rates forwarded to every ``Block``.
        scale: Segment length of every resolution, coarse to fine.
        obs_in_features: Flattened size of the observation embedding.

    NOTE(review): ``fusion_obs`` output (encoder width) is also used as the
    decoder condition and ``flownet`` is built with the encoder width, so
    encoder and decoder widths presumably have to match — confirm.
    """

    def __init__(self,
                 encoder_embed_dim=256, encoder_depth=6, encoder_num_heads=4,
                 decoder_embed_dim=256, decoder_depth=6, decoder_num_heads=4,
                 mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 action_dim=32,
                 attn_dropout=0.,
                 proj_dropout=0.,
                 scale=(1, 2, 4, 8),
                 obs_in_features=1024,
                 ):
        super().__init__()

        self.scale = list(scale)
        print(f'latent action scale :{scale}')

        self.z_proj = nn.Linear(action_dim, encoder_embed_dim)
        self.seq_len = sum(self.scale)
        self.encoder_pos_embed_learned = nn.Parameter(
            torch.zeros(1, self.seq_len, encoder_embed_dim))
        self.z_proj_ln = RMSNorm(encoder_embed_dim, weight=True)

        self.token_embed_dim = action_dim

        self.register_buffer('mask', self._build_block_causal_mask(self.scale).contiguous())

        self.encoder_blocks = nn.ModuleList([
            Block(dim=encoder_embed_dim, num_heads=encoder_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True,
                  norm_layer=norm_layer, proj_drop=proj_dropout, attn_drop=attn_dropout, scale=scale,
                  ) for _ in range(encoder_depth)])
        self.encoder_norm = RMSNorm(encoder_embed_dim, weight=True)

        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.decoder_pos_embed_learned = nn.Parameter(
            torch.zeros(1, self.seq_len, decoder_embed_dim))

        self.decoder_blocks = nn.ModuleList([
            Block(dim=decoder_embed_dim, num_heads=decoder_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True,
                  norm_layer=norm_layer, proj_drop=proj_dropout, attn_drop=attn_dropout, scale=scale,
                  ) for _ in range(decoder_depth)])

        self.decoder_norm = RMSNorm(decoder_embed_dim, weight=True)
        self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))

        self.fusion_obs = nn.Linear(in_features=obs_in_features, out_features=encoder_embed_dim)

        self.flownet = MPScalseFlowhead(
            action_size=action_dim,
            num_blocks=3,
            input_dim=encoder_embed_dim,
            hidden_dim=encoder_embed_dim,
            time_hidden_dim=encoder_embed_dim,
        )

        self.initialize_weights()

    @staticmethod
    def _build_block_causal_mask(scales):
        """Build the additive block-causal mask of shape ``[1, 1, L, L]``.

        ``L = sum(scales)``. Tokens of one scale may attend to every token of
        that scale and of all earlier scales (value 0); all later positions
        are set to ``-inf``.
        """
        total_length = sum(scales)
        rows = []
        visible = 0
        for length in scales:
            visible += length
            row = torch.full((length, total_length), float('-inf'))
            row[:, :visible] = 0
            rows.append(row)
        return torch.cat(rows, dim=0).unsqueeze(0).unsqueeze(0)

    def initialize_weights(self):
        """Initialise positional embeddings (normal, std 0.02) and all submodules."""
        torch.nn.init.normal_(self.encoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform for ``nn.Linear`` weights, zero biases, unit LayerNorm gains."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def interpolate_1d(self, x, target_length, mode='linear'):
        """Resample ``x`` of shape ``[B, T, C]`` along ``T`` to ``target_length``."""
        # F.interpolate works on the last axis, hence the permutes.
        return F.interpolate(x.permute(0, 2, 1), size=target_length, mode=mode).permute(0, 2, 1)

    def forward_mae_encoder(self, x, condition, mask, start=None, end=None, training=True):
        """Add positional embeddings, normalise and run the encoder blocks.

        With ``training=True`` the full positional table is added; otherwise
        only the ``[start:end]`` slice is used (one scale during sampling).
        """
        if training:
            encoder_pos_embed_learned = self.encoder_pos_embed_learned
        else:
            encoder_pos_embed_learned = self.encoder_pos_embed_learned[:, start:end]

        x = x + encoder_pos_embed_learned
        x = self.z_proj_ln(x)

        for blk in self.encoder_blocks:
            x = blk(x, condition, mask)
        x = self.encoder_norm(x)

        return x

    def forward_mae_decoder(self, x, condition, mask, start=None, end=None, training=True):
        """Project to the decoder width, run the decoder blocks, add the flow-head positions.

        ``start`` / ``end`` / ``training`` select the positional-embedding
        slices exactly as in ``forward_mae_encoder``.
        """
        x = self.decoder_embed(x)
        if training:
            decoder_pos_embed_learned = self.decoder_pos_embed_learned
        else:
            decoder_pos_embed_learned = self.decoder_pos_embed_learned[:, start:end]
        x = x + decoder_pos_embed_learned

        for blk in self.decoder_blocks:
            x = blk(x, condition, mask)
        x = self.decoder_norm(x)
        if training:
            diffusion_pos_embed_learned = self.diffusion_pos_embed_learned
        else:
            diffusion_pos_embed_learned = self.diffusion_pos_embed_learned[:, start:end]

        x = x + diffusion_pos_embed_learned
        return x

    def forward(self, act, obs_emb=None):
        """Compute the multi-scale flow loss.

        Args:
            act: 3-D action tensor; assumed ``[B, scale[-1], action_dim]`` so
                that the concatenated targets line up with the mask — TODO
                confirm with callers.
            obs_emb: Observation embedding with batch size ``B``; it is
                flattened to ``[B, -1]`` and must then have
                ``obs_in_features`` features.

        Returns:
            ``(flow_loss, l_dict)``: the sum over scales of the flow-head loss
            weighted by ``scale_len / max_scale``, and the second value
            returned by ``flownet`` for the last scale only.

        Raises:
            ValueError: If ``obs_emb`` is ``None``.
        """
        if obs_emb is None:
            raise ValueError('obs_emb is required')
        B, _, _ = act.shape

        # Targets for every scale, coarse to fine: the finest one is `act`
        # itself, the others are down-sampled copies.
        gt_latents = [act.detach()]
        for i in self.scale[::-1][1:]:
            gt_latents.append(self.interpolate_1d(act.detach(), i))
        gt_latents = gt_latents[::-1]
        next_scale = self.scale[1:]

        # Input for scale k+1: scale k's target, up-sampled to the finest
        # resolution and then re-sampled to the length of scale k+1.
        x_input = []
        for idx, target_T in enumerate(next_scale):
            latent = gt_latents[idx].detach()
            latent_upsampled = self.interpolate_1d(latent, target_length=self.scale[-1], mode='linear')
            latent_downsampled = self.interpolate_1d(latent_upsampled, target_length=target_T, mode='linear')
            x_input.append(latent_downsampled)

        x_input = torch.cat(x_input, dim=1)
        gt_latents = torch.cat(gt_latents, dim=1)

        x_input = self.z_proj(x_input)

        context = obs_emb.reshape(B, -1)
        cls = self.fusion_obs(context).unsqueeze(1)

        # The conditioning token occupies the first position of the sequence.
        x_input = torch.cat([cls, x_input], dim=1)

        x = self.forward_mae_encoder(x_input, cls, self.mask, training=True)

        z = self.forward_mae_decoder(x, cls, self.mask, training=True)

        max_scale = self.scale[-1]
        loss = []
        l_dict = {}
        start = 0
        for i in self.scale:
            l, l_dict = self.flownet(gt_latents[:, start:start + i], z[:, start:start + i, ...])
            start += i
            loss.append(l * i / max_scale)

        flow_loss = sum(loss)

        return flow_loss, l_dict

    def ema_update(self):
        """Step the EMA helper with the flow-head parameters.

        ``self.ema`` is not created by this class and must be attached by the
        caller beforehand.

        Raises:
            AttributeError: If no ``ema`` attribute has been attached.
        """
        ema = getattr(self, 'ema', None)
        if ema is None:
            raise AttributeError(
                "FlowAR has no 'ema' helper attached; assign one before calling ema_update()")
        ema.step(self.flownet.parameters())

    def _clear_kv_caches(self):
        """Reset the attention KV cache of every encoder and decoder block."""
        for blk in self.encoder_blocks:
            blk.attn.clear_cache()
        for blk in self.decoder_blocks:
            blk.attn.clear_cache()

    def sample_tokens(self, obs_emb=None, training=False):
        """Generate an action sequence scale by scale.

        Each step feeds one scale through encoder and decoder with
        ``mask=None``, draws Gaussian noise and lets ``flownet.generate``
        produce that scale's actions; these are re-sampled to form the input
        of the next scale.

        Args:
            obs_emb: Observation embedding; flattened to ``[B, -1]``.
            training: Forwarded to the encoder/decoder helpers, where it
                selects full versus sliced positional embeddings.

        Returns:
            The flow head's output for the last scale.

        Raises:
            ValueError: If ``obs_emb`` is ``None``.
        """
        if obs_emb is None:
            raise ValueError('obs_emb is required')
        B = obs_emb.shape[0]

        context = obs_emb.reshape(B, -1)
        x = self.fusion_obs(context).unsqueeze(1)
        class_embedding = x.clone()

        last_step = len(self.scale) - 1

        self._clear_kv_caches()
        try:
            start = 0
            for step, length in enumerate(self.scale):
                end = start + length
                z = self.forward_mae_encoder(x, class_embedding, None, start, end, training)
                z = self.forward_mae_decoder(z, class_embedding, None, start, end, training)

                # Draw the noise on the model's device instead of assuming CUDA.
                sample = torch.randn([z.shape[0], z.shape[1], self.token_embed_dim], device=z.device)

                z_sample = self.flownet.generate(z, sample)

                if step == last_step:
                    break

                # Same up-/down-sampling as in forward() for the next input.
                x_ = z_sample.detach()
                x_ = self.interpolate_1d(x_, target_length=self.scale[-1], mode='linear')
                x_ = self.interpolate_1d(x_, target_length=self.scale[step + 1], mode='linear')

                x = self.z_proj(x_)
                start = end
        finally:
            # Do not keep cached keys/values alive after sampling.
            self._clear_kv_caches()

        tokens = z_sample

        return tokens

