import math
import torch
import einops
import torch.nn as nn
import torch.nn.functional as F
from diffusers.optimization import get_scheduler
from huggingface_hub import PyTorchModelHubMixin

from flare.factory import registry
from flare.policies import BasePolicy
from flare.policies.observers.resnet_observer import ResNetObserver
from flare.utils.normalize import Normalize, Unnormalize


@registry.register_policy("act")
class ActPolicy(BasePolicy):
    """Policy that predicts chunks of actions with an ``ACT`` transformer.

    Observations are encoded by a ``ResNetObserver``; the resulting features
    condition the ``ACT`` model, which predicts ``pred_horizon`` actions per
    forward pass. At inference time the first ``action_horizon`` predicted
    actions are served one by one from an internal queue.
    """

    def __init__(self, config, stats):
        """Build normalizers, the observer and the ACT model from ``config``.

        Args:
            config: Experiment configuration. This constructor reads
                ``config.policy``, ``config.task``, ``config.resize_shape`` and
                ``config.crop_shape``.
            stats: Dataset statistics handed to ``Normalize`` / ``Unnormalize``.
        """
        super().__init__(config, stats)

        self.config = config
        policy_config = config.policy
        self.stats = stats

        self.pred_horizon = policy_config.pred_horizon
        self.action_horizon = policy_config.action_horizon
        self.action_dim = self.config.task.action_dim

        self.normalize_inputs = Normalize(self.config.task.image_keys + [self.config.task.state_key], stats)
        self.normalize_targets = Normalize([self.config.task.action_key], stats)
        self.unnormalize_outputs = Unnormalize([self.config.task.action_key], stats)
        # Placeholder only. NOTE(review): ``self.reset()`` below is not defined in
        # this class; it is presumably provided by BasePolicy and expected to
        # install the actual queue (``select_action`` needs ``len``, ``extend``
        # and ``popleft``) -- confirm.
        self._action_queue = None

        # Initialize observer from config
        self.observer = ResNetObserver(
            state_key=config.task.state_key,
            image_keys=config.task.image_keys,
            resize_shape=config.resize_shape,
            crop_shape=config.crop_shape,
            state_dim=config.task.state_dim,
            tokenize=policy_config.observer.tokenize,
        )
        # NOTE(review): 512 is hard-coded; presumably the per-camera feature
        # width emitted by ResNetObserver -- confirm against the observer.
        self.obs_dim = len(self.config.task.image_keys) * 512 + self.config.task.state_dim

        # Initialize ACT model from transformer config
        self.act = ACT(
            **policy_config.transformer,
            chunk_size=self.pred_horizon,
            obs_dim=self.obs_dim,
            action_dim=self.action_dim,
        )

        self.reset()

    def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict]:
        """Compute the ACT loss: masked L1 plus an optional KL divergence term.

        Args:
            batch: Training batch. Must contain the target actions under
                ``config.task.action_key`` plus whatever the observer reads. May
                contain a boolean padding mask under ``"<action_key>_is_pad"``;
                when absent, no action is treated as padding.

        Returns:
            ``(total_loss, metrics)``. ``metrics`` maps ``"l1_loss"`` (and
            ``"kld_loss"`` when the KL term is active) to Python floats.
        """
        action_key = self.config.task.action_key
        obs_features = self.observer(batch)
        target_actions = batch[action_key]

        # Use the provided padding mask, or fall back to "nothing is padded".
        pad_key = f"{action_key}_is_pad"
        if pad_key in batch:
            action_is_pad = batch[pad_key]
        else:
            action_is_pad = torch.zeros(
                target_actions.shape[:2], dtype=torch.bool, device=target_actions.device
            )

        # Create ACT batch format
        act_batch = {
            "obs_features": obs_features,
            "action": target_actions,
            "action_is_pad": action_is_pad,
        }

        # Forward pass through ACT
        actions_hat, (mu_hat, log_sigma_x2_hat) = self.act(act_batch)

        # L1 loss with padded steps zeroed out. The mean is still taken over
        # every element (padded ones included), so padding lowers the magnitude.
        elementwise_l1 = F.l1_loss(target_actions, actions_hat, reduction="none")
        l1_loss = (elementwise_l1 * ~action_is_pad.unsqueeze(-1)).mean()

        metrics = {"l1_loss": l1_loss.item()}
        total_loss = l1_loss

        # Add KL divergence loss if using VAE. ``mu_hat`` is None whenever the
        # model skipped the VAE encoder.
        if self.config.policy.transformer.use_vae and mu_hat is not None:
            # KL(q || N(0, I)) summed over latent dims, averaged over the batch.
            mean_kld = (
                (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - log_sigma_x2_hat.exp()))
                .sum(-1).mean()
            )
            metrics["kld_loss"] = mean_kld.item()
            # Out-of-place add: ``total_loss`` aliases ``l1_loss`` at this point
            # and an in-place ``+=`` would silently overwrite that tensor.
            total_loss = total_loss + mean_kld * self.config.policy.act.kl_weight

        return total_loss, metrics

    def generate_actions(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """Generate action sequence from observations.

        Args:
            batch: Observation batch consumed by the observer.

        Returns:
            The action tensor predicted by the ACT model (first element of its
            output tuple).
        """
        obs_features = self.observer(batch)
        act_batch = {"obs_features": obs_features}
        return self.act(act_batch)[0]

    @torch.no_grad()
    def select_action(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the next action, predicting a new chunk when the queue is empty.

        Switches the module to eval mode as a side effect.

        Args:
            batch: Raw observations. Only the image keys and the state key from
                the task config are used; each tensor gets a new axis inserted at
                position 1 before normalization.

        Returns:
            The next queued (unnormalized) action.
        """
        self.eval()
        # Build the key list once instead of once per dict item.
        obs_keys = self.config.task.image_keys + [self.config.task.state_key]
        batch = {k: v.unsqueeze(1) for k, v in batch.items() if k in obs_keys}
        batch = self.normalize_inputs(batch)

        if len(self._action_queue) == 0:
            actions_pred_norm = self.generate_actions(batch)
            # Only the first ``action_horizon`` steps of the chunk are executed.
            actions_pred_norm = actions_pred_norm[:, :self.action_horizon]

            action_key = self.config.task.action_key
            actions_pred = self.unnormalize_outputs({action_key: actions_pred_norm})[action_key]

            # Swap the first two axes so that iterating yields one entry per step.
            self._action_queue.extend(actions_pred.transpose(0, 1))

        return self._action_queue.popleft()

    def get_optimizer(self) -> torch.optim.Optimizer:
        """Build an AdamW optimizer with a separately scaled observer learning rate.

        Parameters whose name contains ``"observer"`` use
        ``lr * backbone_lr_scale``; all other parameters use ``lr``.
        """
        opt_config = self.config.policy.optimizer
        observer_params = []
        other_params = []

        for name, param in self.named_parameters():
            if "observer" in name:
                observer_params.append(param)
            else:
                other_params.append(param)
        print(f"Observer params: {len(observer_params)}")
        print(f"Other policy params: {len(other_params)}")

        return torch.optim.AdamW([
            {"params": other_params, "lr": opt_config.lr},
            {"params": observer_params, "lr": opt_config.lr * opt_config.backbone_lr_scale}
        ], weight_decay=opt_config.weight_decay, betas=opt_config.betas, eps=opt_config.eps)

    def get_scheduler(self, optimizer: torch.optim.Optimizer, num_training_steps: int):
        """Create the learning-rate scheduler described by ``config.policy.scheduler``.

        Args:
            optimizer: Optimizer whose learning rate is scheduled.
            num_training_steps: Total number of training steps.

        Returns:
            The scheduler returned by ``diffusers.optimization.get_scheduler``.
        """
        sched_config = self.config.policy.scheduler
        return get_scheduler(
            name=sched_config.name,
            optimizer=optimizer,
            num_warmup_steps=sched_config.warmup_steps,
            num_training_steps=num_training_steps,
        )


class ACT(nn.Module):
    """Action Chunking Transformer.

    An optional VAE encoder compresses ``(observation, action chunk)`` into a
    latent during training. A transformer encoder/decoder then maps the latent
    and the observation features to ``chunk_size`` actions. When the VAE encoder
    is not used (disabled, eval mode, or no actions given) the latent is zeros.
    """

    def __init__(
        self,
        dim_model: int,
        n_encoder_layers: int,
        n_vae_encoder_layers: int,
        n_decoder_layers: int,
        n_heads: int,
        dim_feedforward: int,
        dropout: float,
        feedforward_activation: str,
        pre_norm: bool,
        use_vae: bool,
        latent_dim: int,
        chunk_size: int,
        obs_dim: int,
        action_dim: int,
    ):
        """
        Args:
            dim_model: Transformer hidden size.
            n_encoder_layers: Number of layers in the main encoder.
            n_vae_encoder_layers: Number of layers in the VAE encoder.
            n_decoder_layers: Number of decoder layers.
            n_heads: Attention heads per layer.
            dim_feedforward: Hidden size of the feed-forward blocks.
            dropout: Dropout probability.
            feedforward_activation: Activation name passed to the layers.
            pre_norm: Whether layers normalize before (True) or after (False)
                each sub-block.
            use_vae: Whether to build and use the VAE encoder.
            latent_dim: Size of the latent vector.
            chunk_size: Number of actions predicted per forward pass.
            obs_dim: Size of the observation feature vector.
            action_dim: Size of a single action.
        """
        super().__init__()
        self.use_vae = use_vae
        self.latent_dim = latent_dim
        self.chunk_size = chunk_size
        self.dim_model = dim_model
        self.obs_dim = obs_dim
        self.action_dim = action_dim

        # VAE encoder
        if self.use_vae:
            self.vae_encoder = ACTEncoder(
                dim_model, n_encoder_layers, n_vae_encoder_layers, n_heads,
                dim_feedforward, dropout, feedforward_activation, pre_norm, is_vae_encoder=True
            )
            self.vae_encoder_cls_embed = nn.Embedding(1, dim_model)
            self.vae_encoder_obs_input_proj = nn.Linear(obs_dim, dim_model)
            self.vae_encoder_action_input_proj = nn.Linear(action_dim, dim_model)
            # Produces mean and log-variance, hence twice the latent size.
            self.vae_encoder_latent_output_proj = nn.Linear(dim_model, latent_dim * 2)

            # Token layout of the VAE encoder input: [CLS, observation, *actions].
            num_input_tokens = 1 + 1 + chunk_size
            self.register_buffer(
                "vae_encoder_pos_enc",
                create_sinusoidal_pos_embedding(num_input_tokens, dim_model).unsqueeze(0),
            )

        # Transformer encoder/decoder
        self.encoder = ACTEncoder(
            dim_model, n_encoder_layers, n_vae_encoder_layers, n_heads,
            dim_feedforward, dropout, feedforward_activation, pre_norm
        )
        self.decoder = ACTDecoder(
            dim_model, n_decoder_layers, n_heads, dim_feedforward,
            dropout, feedforward_activation, pre_norm
        )

        self.encoder_latent_input_proj = nn.Linear(latent_dim, dim_model)
        self.encoder_obs_input_proj = nn.Linear(self.obs_dim, dim_model)

        # The main encoder sees exactly two tokens: latent and observation.
        n_tokens = 2
        self.encoder_pos_embed = nn.Embedding(n_tokens, dim_model)
        self.decoder_pos_embed = nn.Embedding(chunk_size, dim_model)
        self.action_head = nn.Linear(dim_model, action_dim)
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialize every ``nn.Linear`` and reset every ``nn.LayerNorm``."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.LayerNorm):
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
                if module.weight is not None:
                    nn.init.constant_(module.weight, 1.0)
        self.apply(_basic_init)

    def forward(
        self, batch: dict[str, torch.Tensor]
    ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, torch.Tensor | None]]:
        """Predict an action chunk.

        Args:
            batch: Must contain ``"obs_features"`` with the batch on axis 0. For
                the VAE path it must also contain ``"action"``; an optional
                boolean ``"action_is_pad"`` marks padded action steps.

        Returns:
            ``(actions, (mu, log_sigma_x2))``. ``mu`` and ``log_sigma_x2`` are
            ``None`` unless the VAE encoder ran (``use_vae``, training mode and
            ``"action"`` present in ``batch``).
        """
        obs_features = batch["obs_features"]
        batch_size = obs_features.shape[0]

        if self.use_vae and "action" in batch and self.training:
            mu, log_sigma_x2, latent_sample = self._vae_encode(batch, obs_features)
        else:
            mu = log_sigma_x2 = None
            # Match the dtype of the projection that consumes the latent, so the
            # module keeps working after being cast to half/bfloat16.
            latent_sample = torch.zeros(
                (batch_size, self.latent_dim),
                dtype=self.encoder_latent_input_proj.weight.dtype,
                device=obs_features.device,
            )

        actions = self._transformer_forward(obs_features, latent_sample)
        return actions, (mu, log_sigma_x2)

    def _vae_encode(self, batch: dict[str, torch.Tensor], obs_features: torch.Tensor):
        """Encode ``(observation, actions)`` into latent parameters and a sample.

        Returns:
            ``(mu, log_sigma_x2, latent_sample)`` where ``log_sigma_x2`` is the
            log-variance and the sample uses the reparameterization
            ``mu + sigma * eps``.
        """
        batch_size = obs_features.shape[0]
        cls_embed = einops.repeat(self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size)
        obs_embed = self.vae_encoder_obs_input_proj(obs_features).unsqueeze(1)
        action_embed = self.vae_encoder_action_input_proj(batch["action"])
        # Batch-first token sequence: [CLS, observation, *actions].
        vae_input = torch.cat([cls_embed, obs_embed, action_embed], dim=1)
        pos_embed = self.vae_encoder_pos_enc.clone().detach()

        # Without a mask, treat every action step as real.
        if "action_is_pad" in batch:
            action_is_pad = batch["action_is_pad"]
        else:
            action_is_pad = torch.zeros(
                action_embed.shape[:2], dtype=torch.bool, device=action_embed.device
            )
        # The CLS and observation tokens are never padding.
        cls_obs_is_pad = torch.full((batch_size, 2), False, device=obs_features.device, dtype=torch.bool)
        key_padding_mask = torch.cat([cls_obs_is_pad, action_is_pad], dim=1)

        # The encoder is sequence-first; index 0 selects the CLS token output.
        cls_output = self.vae_encoder(
            vae_input.transpose(0, 1),
            pos_embed=pos_embed.transpose(0, 1),
            key_padding_mask=key_padding_mask,
        )[0]
        latent_params = self.vae_encoder_latent_output_proj(cls_output)
        mu, log_sigma_x2 = latent_params.chunk(2, dim=-1)
        # exp(log_var / 2) is the standard deviation.
        latent_sample = mu + (log_sigma_x2 / 2).exp() * torch.randn_like(mu)
        return mu, log_sigma_x2, latent_sample

    def _transformer_forward(self, obs_features: torch.Tensor, latent_sample: torch.Tensor):
        """Run the encoder/decoder and project decoder outputs to actions."""
        batch_size = obs_features.shape[0]
        latent_token = self.encoder_latent_input_proj(latent_sample)
        obs_token = self.encoder_obs_input_proj(obs_features)
        # Sequence-first stack of the two encoder tokens.
        encoder_tokens = torch.stack([latent_token, obs_token], dim=0)
        encoder_pos_embed = self.encoder_pos_embed.weight.unsqueeze(1)
        encoder_out = self.encoder(encoder_tokens, pos_embed=encoder_pos_embed)
        # Decoder queries start as zeros; only the learned positional embedding
        # distinguishes the steps of the chunk.
        decoder_input = torch.zeros(
            (self.chunk_size, batch_size, self.dim_model),
            dtype=encoder_tokens.dtype, device=encoder_tokens.device,
        )
        decoder_pos_embed = self.decoder_pos_embed.weight.unsqueeze(1)
        decoder_out = self.decoder(
            decoder_input, encoder_out,
            decoder_pos_embed=decoder_pos_embed,
            encoder_pos_embed=encoder_pos_embed,
        )
        # Back to batch-first before the action head.
        decoder_out = decoder_out.transpose(0, 1)
        actions = self.action_head(decoder_out)
        return actions


class ACTEncoder(nn.Module):
    """Multi-layer transformer encoder."""

    def __init__(
        self,
        dim_model: int,
        n_encoder_layers: int,
        n_vae_encoder_layers: int,
        n_heads: int,
        dim_feedforward: int,
        dropout: float,
        feedforward_activation: str,
        pre_norm: bool,
        is_vae_encoder: bool = False
    ):
        super().__init__()
        self.is_vae_encoder = is_vae_encoder
        num_layers = n_vae_encoder_layers if is_vae_encoder else n_encoder_layers

        self.layers = nn.ModuleList([
            ACTEncoderLayer(dim_model, n_heads, dim_feedforward, dropout,
                            feedforward_activation, pre_norm)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(dim_model) if pre_norm else nn.Identity()

    def forward(
        self,
        x: torch.Tensor,
        pos_embed: torch.Tensor | None = None,
        key_padding_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
        return self.norm(x)


class ACTEncoderLayer(nn.Module):
    """Single transformer encoder layer."""

    def __init__(
        self,
        dim_model: int,
        n_heads: int,
        dim_feedforward: int,
        dropout: float,
        feedforward_activation: str,
        pre_norm: bool
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim_model, n_heads, dropout=dropout, batch_first=False)

        self.linear1 = nn.Linear(dim_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, dim_model)

        self.norm1 = nn.LayerNorm(dim_model)
        self.norm2 = nn.LayerNorm(dim_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = get_activation_fn(feedforward_activation)
        self.pre_norm = pre_norm

    def forward(
        self,
        x: torch.Tensor,
        pos_embed: torch.Tensor | None = None,
        key_padding_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        # Self-attention block
        skip = x
        if self.pre_norm:
            x = self.norm1(x)
        q = k = x if pos_embed is None else x + pos_embed
        x, _ = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask)
        x = skip + self.dropout1(x)

        # Feed-forward block
        if not self.pre_norm:
            x = self.norm1(x)

        skip = x
        if self.pre_norm:
            x = self.norm2(x)
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        x = skip + self.dropout2(x)

        if not self.pre_norm:
            x = self.norm2(x)

        return x


class ACTDecoder(nn.Module):
    def __init__(
        self,
        dim_model: int,
        n_decoder_layers: int,
        n_heads: int,
        dim_feedforward: int,
        dropout: float,
        feedforward_activation: str,
        pre_norm: bool
    ):
        super().__init__()
        self.layers = nn.ModuleList([ACTDecoderLayer(
            dim_model,
            n_heads,
            dim_feedforward,
            dropout,
            feedforward_activation,
            pre_norm
        ) for _ in range(n_decoder_layers)])
        self.norm = nn.LayerNorm(dim_model)

    def forward(
        self,
        x: torch.Tensor,
        encoder_out: torch.Tensor,
        decoder_pos_embed: torch.Tensor | None = None,
        encoder_pos_embed: torch.Tensor | None = None,
    ) -> torch.Tensor:
        for layer in self.layers:
            x = layer(
                x,
                encoder_out,
                decoder_pos_embed=decoder_pos_embed,
                encoder_pos_embed=encoder_pos_embed
            )
        if self.norm is not None:
            x = self.norm(x)
        return x


class ACTDecoderLayer(nn.Module):
    def __init__(
        self,
        dim_model: int,
        n_heads: int,
        dim_feedforward: int,
        dropout: float,
        feedforward_activation: str,
        pre_norm: bool
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim_model, n_heads, dropout=dropout, batch_first=False)
        self.multihead_attn = nn.MultiheadAttention(dim_model, n_heads, dropout=dropout, batch_first=False)

        self.linear1 = nn.Linear(dim_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, dim_model)

        self.norm1 = nn.LayerNorm(dim_model)
        self.norm2 = nn.LayerNorm(dim_model)
        self.norm3 = nn.LayerNorm(dim_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = get_activation_fn(feedforward_activation)
        self.pre_norm = pre_norm

    def maybe_add_pos_embed(self, tensor: torch.Tensor, pos_embed: torch.Tensor | None) -> torch.Tensor:
        return tensor if pos_embed is None else tensor + pos_embed

    def forward(
        self,
        x: torch.Tensor,
        encoder_out: torch.Tensor,
        decoder_pos_embed: torch.Tensor | None = None,
        encoder_pos_embed: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Self-attention
        skip = x
        if self.pre_norm:
            x = self.norm1(x)
        q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)
        x = self.self_attn(q, k, value=x)[0]  # select just the output, not the attention weights
        x = skip + self.dropout1(x)
        if self.pre_norm:
            skip = x
            x = self.norm2(x)
        else:
            x = self.norm1(x)
            skip = x
        x = self.multihead_attn(
            query=self.maybe_add_pos_embed(x, decoder_pos_embed),
            key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed),
            value=encoder_out,
        )[0]  # select just the output, not the attention weights
        x = skip + self.dropout2(x)
        if self.pre_norm:
            skip = x
            x = self.norm3(x)
        else:
            x = self.norm2(x)
            skip = x
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        x = skip + self.dropout3(x)
        if not self.pre_norm:
            x = self.norm3(x)
        return x


def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> torch.Tensor:
    """Create 1D sinusoidal positional embeddings.

    Column ``2i`` holds ``sin(pos * f_i)`` and column ``2i + 1`` holds
    ``cos(pos * f_i)``, with ``f_i = 10000 ** (-2i / dimension)``.

    Args:
        num_positions: Number of positions (rows) to generate.
        dimension: Embedding width. Odd values are supported: the final sine
            column then has no cosine partner.

    Returns:
        A ``(num_positions, dimension)`` tensor in torch's default dtype.
    """
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dimension, 2).float() / dimension))
    position = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)
    # One column per frequency: (num_positions, ceil(dimension / 2)).
    angles = position * inv_freq
    pos_embed = torch.zeros(num_positions, dimension)
    pos_embed[:, 0::2] = torch.sin(angles)
    # There are only dimension // 2 odd columns; for an odd ``dimension`` that is
    # one fewer than the number of frequencies, so trim the last one.
    pos_embed[:, 1::2] = torch.cos(angles[:, : dimension // 2])
    return pos_embed


class ACTSinusoidalPositionEmbedding2d(nn.Module):
    """2D sinusoidal positional embeddings in the spirit of Attention Is All You Need.

    Unlike the usual formulation, position indices are 1-based and rescaled so
    that the last row/column maps to (almost exactly) 2π; the first row maps to
    roughly 2π/H and the first column to roughly 2π/W.
    """

    def __init__(self, dimension: int):
        """
        Args:
            dimension: Number of embedding channels generated per spatial axis.
                The vertical and horizontal embeddings are concatenated, so the
                output has ``2 * dimension`` channels.
        """
        super().__init__()
        self.dimension = dimension
        self._two_pi = 2 * math.pi
        self._eps = 1e-6
        # Base of the geometric progression of sinusoid wavelengths.
        self._temperature = 10000

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: A (B, C, H, W) batch of 2D feature maps. Only its spatial size,
                dtype and device are used.
        Returns:
            A (1, 2 * dimension, H, W) tensor of positional embeddings, vertical
            channels first.
        """
        ones = torch.ones_like(x[0, :1])  # (1, H, W)
        # Cumulative sums of ones give 1-based indices (1..H and 1..W), not the
        # more common 0-based ones.
        rows = ones.cumsum(1, dtype=torch.float32)
        cols = ones.cumsum(2, dtype=torch.float32)

        # Rescale by the last index so positions span up to 2π. The epsilon is
        # not strictly needed because the last index is never zero.
        rows = rows / (rows[:, -1:, :] + self._eps) * self._two_pi
        cols = cols / (cols[:, :, -1:] + self._eps) * self._two_pi

        channel_index = torch.arange(self.dimension, dtype=torch.float32, device=x.device)
        # Consecutive channel pairs (2k, 2k + 1) share one divisor.
        divisor = self._temperature ** (2 * (channel_index // 2) / self.dimension)

        def interleaved_sin_cos(grid: torch.Tensor) -> torch.Tensor:
            # (1, H, W) -> (1, H, W, dimension), laid out as sin, cos, sin, cos, ...
            scaled = grid.unsqueeze(-1) / divisor
            sin_part = scaled[..., 0::2].sin()
            cos_part = scaled[..., 1::2].cos()
            return torch.stack((sin_part, cos_part), dim=-1).flatten(3)

        embed_x = interleaved_sin_cos(cols)
        embed_y = interleaved_sin_cos(rows)

        # Vertical channels first, then move channels to axis 1.
        return torch.cat((embed_y, embed_x), dim=3).permute(0, 3, 1, 2)


def get_activation_fn(activation: str):
    """Map an activation name to the matching ``torch.nn.functional`` callable.

    Args:
        activation: One of ``"relu"``, ``"gelu"`` or ``"glu"``.

    Returns:
        ``F.relu``, ``F.gelu`` or ``F.glu`` respectively.

    Raises:
        RuntimeError: If ``activation`` is not one of the supported names.
    """
    supported = (
        ("relu", F.relu),
        ("gelu", F.gelu),
        ("glu", F.glu),
    )
    for candidate, fn in supported:
        if activation == candidate:
            return fn
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
