import math
import torch
import einops
import torch.nn as nn
import torch.nn.functional as F

from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.optimization import get_scheduler
from huggingface_hub import PyTorchModelHubMixin

from flare.factory import registry
from flare.policies import BasePolicy
from flare.policies.observers.resnet_observer import ResNetObserver
from flare.utils.normalize import Normalize, Unnormalize, NormalizationMode


def _make_noise_scheduler(name: str, **kwargs) -> DDPMScheduler | DDIMScheduler:
    """Instantiate the diffusers noise scheduler selected by ``name``.

    Args:
        name: Scheduler identifier, either ``"DDPM"`` or ``"DDIM"``.
        **kwargs: Forwarded unchanged to the scheduler constructor.

    Returns:
        A freshly constructed ``DDPMScheduler`` or ``DDIMScheduler``.

    Raises:
        ValueError: If ``name`` matches neither supported identifier.
    """
    supported = (("DDPM", DDPMScheduler), ("DDIM", DDIMScheduler))
    for label, scheduler_cls in supported:
        if name == label:
            return scheduler_cls(**kwargs)
    raise ValueError(f"Unsupported noise scheduler type {name}")


@registry.register_policy("diffusion")
class DiffusionPolicy(BasePolicy):
    """Action-diffusion policy registered under the name ``"diffusion"``.

    Observations are encoded by a ``ResNetObserver`` and flattened into a
    global conditioning vector. A ``DiffusionConditionalUnet1d`` is trained to
    denoise action sequences under a diffusers DDPM/DDIM noise schedule, and
    the same network is run iteratively at inference time to sample actions.
    """

    def __init__(self, config, stats):
        """Build normalizers, observer, denoising network and noise scheduler.

        Args:
            config: Experiment config. This class reads ``config.policy``,
                ``config.task``, ``config.resize_shape`` and
                ``config.crop_shape``.
            stats: Dataset statistics handed to the (un)normalizers.
        """
        super().__init__(config, stats)
        self.config = config
        policy_config = config.policy
        self.stats = stats

        self.pred_horizon = policy_config.pred_horizon
        self.action_horizon = policy_config.action_horizon
        self.action_dim = self.config.task.action_dim

        # NOTE(review): `image_keys` is used below as ONE dict key (the whole
        # sequence), not one entry per image key. That requires a hashable
        # sequence type (a plain `list` would raise TypeError) and leaves no
        # per-key entry in the map -- confirm how Normalize consumes norm_map.
        self.normalize_inputs = Normalize(
            self.config.task.image_keys + [self.config.task.state_key],
            stats,
            norm_map={
                self.config.task.state_key: NormalizationMode.MIN_MAX,
                self.config.task.image_keys: NormalizationMode.MEAN_STD,
            }
        )
        self.normalize_targets = Normalize(
            [self.config.task.action_key],
            stats,
            norm_map={
                self.config.task.action_key: NormalizationMode.MIN_MAX,
            }
        )
        # Only the action key is un-normalized; the state/image entries in the
        # map are passed along as well.
        self.unnormalize_outputs = Unnormalize(
            [self.config.task.action_key],
            stats,
            norm_map={
                self.config.task.state_key: NormalizationMode.MIN_MAX,
                self.config.task.image_keys: NormalizationMode.MEAN_STD,
                self.config.task.action_key: NormalizationMode.MIN_MAX,
            }
        )
        # NOTE(review): `reset()` is not defined in this class; presumably it
        # is inherited from BasePolicy and replaces this placeholder with a
        # queue object (select_action uses len/extend/popleft) -- confirm.
        self._action_queue = None

        self.observer = ResNetObserver(
            state_key=config.task.state_key,
            image_keys=config.task.image_keys,
            resize_shape=config.resize_shape,
            crop_shape=config.crop_shape,
            state_dim=config.task.state_dim,
            tokenize=policy_config.observer.tokenize,
        )
        # NOTE(review): 512 features per image key is hard-coded; presumably
        # the per-camera output width of ResNetObserver -- confirm.
        self.obs_dim = (
            len(self.config.task.image_keys) * 512 + self.config.task.state_dim
        )
        global_cond_dim = self.obs_dim

        # Denoising network (UNet)
        self.unet = DiffusionConditionalUnet1d(
            input_dim=self.action_dim,
            global_cond_dim=global_cond_dim,
            config=policy_config.unet,
        )

        # Noise scheduler from diffusers
        ns_config = policy_config.noise_scheduler
        self.noise_scheduler = _make_noise_scheduler(
            name=ns_config.type,
            num_train_timesteps=ns_config.num_train_timesteps,
            beta_start=ns_config.beta_start,
            beta_end=ns_config.beta_end,
            beta_schedule=ns_config.beta_schedule,
            clip_sample=ns_config.clip_sample,
            clip_sample_range=ns_config.clip_sample_range,
            prediction_type=ns_config.prediction_type,
        )

        # Inference steps: fall back to the full training schedule when the
        # config does not specify a step count.
        num_inference_steps = policy_config.inference.num_steps
        if num_inference_steps is None:
            self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
        else:
            self.num_inference_steps = num_inference_steps

        self.reset()

    def compute_loss(
        self, batch: dict[str, torch.Tensor]
    ) -> tuple[torch.Tensor, dict]:
        """Compute the denoising MSE loss for one batch.

        Args:
            batch: Mapping holding the observation entries consumed by the
                observer, the action tensor under ``config.task.action_key``
                and, optionally, a padding mask under ``"<action_key>_is_pad"``.

        Returns:
            ``(loss, metrics)`` where ``loss`` is a scalar tensor and
            ``metrics`` is ``{"mse_loss": float}``.

        Raises:
            ValueError: If the configured prediction type is neither
                ``"epsilon"`` nor ``"sample"``.
        """
        # Get features and actions
        obs_features = self.observer(batch)
        global_cond = obs_features.flatten(start_dim=1)
        target_actions = batch[self.config.task.action_key]

        # Sample noise and timesteps for forward diffusion. The noise matches
        # the action dtype so that the noised actions keep that dtype instead
        # of being promoted to the default dtype.
        noise = torch.randn(
            target_actions.shape,
            device=target_actions.device,
            dtype=target_actions.dtype,
        )
        timesteps = torch.randint(
            low=0,
            high=self.noise_scheduler.config.num_train_timesteps,
            size=(target_actions.shape[0],),
            device=target_actions.device,
        ).long()

        # Add noise to clean actions
        noisy_actions = self.noise_scheduler.add_noise(target_actions, noise, timesteps)

        # Predict noise or original sample from noisy actions
        pred = self.unet(noisy_actions, timesteps, global_cond=global_cond)

        # Determine target for loss calculation
        prediction_type = self.config.policy.noise_scheduler.prediction_type
        if prediction_type == "epsilon":
            target = noise
        elif prediction_type == "sample":
            target = target_actions
        else:
            raise ValueError(
                f"Unsupported prediction type: {prediction_type}"
            )

        loss = F.mse_loss(pred, target, reduction="none")

        # Mask loss for padded actions. Padded steps are zeroed out but still
        # counted in the denominator of the mean below.
        pad_key = f"{self.config.task.action_key}_is_pad"
        if self.config.policy.mask_loss_for_padding and pad_key in batch:
            in_episode_mask = ~batch[pad_key]
            loss = loss * in_episode_mask.unsqueeze(-1)

        mean_loss = loss.mean()
        metrics = {"mse_loss": mean_loss.item()}
        return mean_loss, metrics

    def generate_actions(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """Sample an action sequence by iteratively denoising Gaussian noise.

        Args:
            batch: Observation mapping consumed by the observer; the batch
                size is read from the first entry's leading axis.

        Returns:
            The final denoised sample. Denoising starts from a tensor of
            shape ``(batch, pred_horizon, action_dim)``; ``select_action``
            treats the result as normalized actions.
        """
        batch_size = next(iter(batch.values())).shape[0]
        reference_param = next(self.parameters())
        device = reference_param.device
        dtype = reference_param.dtype

        # Get observation features to condition on
        obs_features = self.observer(batch)
        global_cond = obs_features.flatten(start_dim=1)

        # Sample initial noise
        sample = torch.randn(
            size=(batch_size, self.pred_horizon, self.action_dim),
            dtype=dtype,
            device=device,
        )

        # Set inference timesteps and denoise
        self.noise_scheduler.set_timesteps(self.num_inference_steps)
        for t in self.noise_scheduler.timesteps:
            # The same timestep is broadcast to every element of the batch.
            model_output = self.unet(
                sample,
                torch.full((batch_size,), t, dtype=torch.long, device=sample.device),
                global_cond=global_cond,
            )
            sample = self.noise_scheduler.step(model_output, t, sample).prev_sample

        return sample

    @torch.no_grad()
    def select_action(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the next action, re-planning only when the queue is empty.

        Puts the module in eval mode, normalizes the observation entries and,
        when no queued actions remain, generates a new sequence, keeps its
        first ``action_horizon`` steps, un-normalizes them and queues them.

        Args:
            batch: Mapping with the current observation; image/state keys
                missing from it are skipped.

        Returns:
            The action popped from the front of the queue.
        """
        self.eval()
        obs_keys = self.config.task.image_keys + [self.config.task.state_key]
        seq_batch = {}
        for key in obs_keys:
            if key not in batch:
                continue
            # Insert a length-1 axis at position 1
            # (presumably the observation-time axis -- TODO confirm).
            value = batch[key].unsqueeze(1)
            seq_batch[key] = value

        seq_batch = self.normalize_inputs(seq_batch)

        if len(self._action_queue) == 0:
            actions_pred_norm = self.generate_actions(seq_batch)
            actions_pred_norm = actions_pred_norm[:, : self.action_horizon]
            action_key = self.config.task.action_key
            actions_pred = self.unnormalize_outputs(
                {action_key: actions_pred_norm}
            )[action_key]
            # Swap axes 0 and 1 so that extend() enqueues one entry per step
            # of the horizon rather than one per batch element.
            self._action_queue.extend(actions_pred.transpose(0, 1))

        return self._action_queue.popleft()

    def get_optimizer(self) -> torch.optim.Optimizer:
        """Create an AdamW optimizer with a scaled learning rate for the observer.

        Parameters whose qualified name contains ``"observer"`` use
        ``lr * backbone_lr_scale``; all other parameters use ``lr``.

        Returns:
            The configured ``torch.optim.AdamW`` instance.
        """
        opt_config = self.config.policy.optimizer

        # Partition the parameters in a single pass, preserving their order.
        observer_params, other_params = [], []
        for name, param in self.named_parameters():
            if "observer" in name:
                observer_params.append(param)
            else:
                other_params.append(param)

        return torch.optim.AdamW(
            [
                {"params": other_params, "lr": opt_config.lr},
                {
                    "params": observer_params,
                    "lr": opt_config.lr * opt_config.backbone_lr_scale,
                },
            ],
            betas=opt_config.betas,
            eps=opt_config.eps,
            weight_decay=opt_config.weight_decay,
        )

    def get_scheduler(
        self, optimizer: torch.optim.Optimizer, num_training_steps: int
    ):
        """Create the learning-rate scheduler described by the policy config.

        Args:
            optimizer: Optimizer whose learning rate is scheduled.
            num_training_steps: Total number of training steps.

        Returns:
            The scheduler built by the module-level ``get_scheduler`` helper
            imported from ``diffusers.optimization``.
        """
        sched_config = self.config.policy.scheduler
        # Inside the method body the bare name resolves to the module-level
        # import, not to this method.
        return get_scheduler(
            name=sched_config.name,
            optimizer=optimizer,
            num_warmup_steps=sched_config.warmup_steps,
            num_training_steps=num_training_steps,
        )


class DiffusionSinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar positions (e.g. diffusion timesteps).

    The output's last axis has width ``2 * (dim // 2)``: the first half holds
    sines, the second half cosines. ``dim`` must be at least 4, because the
    frequency spacing divides by ``dim // 2 - 1``.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` by appending a trailing sin/cos feature axis."""
        n_freqs = self.dim // 2
        # Geometric frequency ladder running from 1 down to 1/10000.
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=x.device) * -log_step)
        angles = x[..., None] * freqs[None]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class DiffusionConv1dBlock(nn.Module):
    """Conv1d -> GroupNorm -> Mish, applied as one sequential unit.

    The convolution pads by ``kernel_size // 2`` on each side, so an odd
    kernel keeps the sequence length unchanged.
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        conv = nn.Conv1d(
            inp_channels,
            out_channels,
            kernel_size,
            padding=kernel_size // 2,
        )
        norm = nn.GroupNorm(n_groups, out_channels)
        activation = nn.Mish()
        self.block = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Run ``x`` through the conv/norm/activation stack."""
        return self.block(x)


class DiffusionConditionalResidualBlock1d(nn.Module):
    """Residual block of two conv blocks with conditioning injected in between.

    The conditioning vector is projected to a per-channel bias, or to a
    per-channel scale and bias when ``use_film_scale_modulation`` is set, and
    applied to the output of the first conv block. A 1x1 convolution adapts
    the shortcut when the channel counts differ.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        cond_dim: int,
        kernel_size: int,
        n_groups: int,
        use_film_scale_modulation: bool = False,
    ):
        super().__init__()
        self.use_film_scale_modulation = use_film_scale_modulation
        self.out_channels = out_channels

        self.conv1 = DiffusionConv1dBlock(
            in_channels, out_channels, kernel_size, n_groups=n_groups
        )

        # Scale modulation needs twice the width: one half scale, one half bias.
        if use_film_scale_modulation:
            film_width = out_channels * 2
        else:
            film_width = out_channels
        self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, film_width))

        self.conv2 = DiffusionConv1dBlock(
            out_channels, out_channels, kernel_size, n_groups=n_groups
        )

        # Project the shortcut only when the channel count changes.
        if in_channels != out_channels:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Apply the conditioned residual block to ``x``.

        Args:
            x: Input feature map fed to the first conv block and the shortcut.
            cond: Conditioning tensor whose last axis has width ``cond_dim``.

        Returns:
            The second conv block's output plus the (possibly projected) input.
        """
        hidden = self.conv1(x)
        # Trailing axis lets the modulation broadcast along the sequence.
        film = self.cond_encoder(cond).unsqueeze(-1)
        if not self.use_film_scale_modulation:
            hidden = hidden + film
        else:
            split = self.out_channels
            gamma = film[:, :split]
            beta = film[:, split:]
            hidden = gamma * hidden + beta
        hidden = self.conv2(hidden)
        shortcut = self.residual_conv(x)
        return hidden + shortcut


class DiffusionConditionalUnet1d(nn.Module):
    """1D convolutional U-Net conditioned on a timestep and optional global features.

    Args:
        input_dim: Width of the last axis of the input (and output) sequence.
        global_cond_dim: Width of the global conditioning vector. It is always
            added to the conditioning width of every residual block, so it
            must be 0 if ``forward`` is called without ``global_cond``.
        config: Object exposing ``diffusion_step_embed_dim``, ``down_dims``,
            ``kernel_size``, ``n_groups`` and ``use_film_scale_modulation``.
    """

    def __init__(self, input_dim: int, global_cond_dim: int, config):
        super().__init__()
        step_dim = config.diffusion_step_embed_dim
        self.diffusion_step_encoder = nn.Sequential(
            DiffusionSinusoidalPosEmb(step_dim),
            nn.Linear(step_dim, step_dim * 4),
            nn.Mish(),
            nn.Linear(step_dim * 4, step_dim),
        )
        cond_dim = step_dim + global_cond_dim

        # (in_channels, out_channels) for every encoder stage.
        in_out = [(input_dim, config.down_dims[0])] + list(
            zip(config.down_dims[:-1], config.down_dims[1:], strict=True)
        )
        res_block_kwargs = {
            "cond_dim": cond_dim,
            "kernel_size": config.kernel_size,
            "n_groups": config.n_groups,
            "use_film_scale_modulation": config.use_film_scale_modulation,
        }

        # Encoder: every stage but the last halves the sequence length.
        self.down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            self.down_modules.append(
                nn.ModuleList(
                    [
                        DiffusionConditionalResidualBlock1d(dim_in, dim_out, **res_block_kwargs),
                        DiffusionConditionalResidualBlock1d(dim_out, dim_out, **res_block_kwargs),
                        nn.Conv1d(dim_out, dim_out, 3, 2, 1)
                        if not is_last
                        else nn.Identity(),
                    ]
                )
            )

        self.mid_modules = nn.ModuleList(
            [
                DiffusionConditionalResidualBlock1d(
                    config.down_dims[-1], config.down_dims[-1], **res_block_kwargs
                ),
                DiffusionConditionalResidualBlock1d(
                    config.down_dims[-1], config.down_dims[-1], **res_block_kwargs
                ),
            ]
        )

        # Decoder: one stage per encoder downsample (len(in_out) - 1 stages).
        # Every stage upsamples; the previous `is_last` test compared the loop
        # index against len(in_out) - 1, which this shorter loop never reaches,
        # so its nn.Identity() branch was unreachable. Always upsampling is
        # also what restores the length removed by the encoder downsamples.
        self.up_modules = nn.ModuleList([])
        for dim_out, dim_in in reversed(in_out[1:]):
            self.up_modules.append(
                nn.ModuleList(
                    [
                        # dim_in * 2: the input is concatenated with a skip
                        # feature along the channel axis in forward().
                        DiffusionConditionalResidualBlock1d(
                            dim_in * 2, dim_out, **res_block_kwargs
                        ),
                        DiffusionConditionalResidualBlock1d(
                            dim_out, dim_out, **res_block_kwargs
                        ),
                        nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1),
                    ]
                )
            )

        self.final_conv = nn.Sequential(
            DiffusionConv1dBlock(
                config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size
            ),
            nn.Conv1d(config.down_dims[0], input_dim, 1),
        )

    def forward(self, x: torch.Tensor, timestep: torch.Tensor, global_cond=None):
        """Run the U-Net on a batch of sequences.

        Args:
            x: Input laid out as ``(b, t, d)``; it is rearranged to
                ``(b, d, t)`` for the convolutions.
            timestep: Diffusion timesteps fed to the step encoder.
            global_cond: Optional conditioning tensor concatenated to the
                timestep embedding along the last axis.

        Returns:
            Tensor rearranged back to the ``(b, t, d)`` layout.
        """
        x = einops.rearrange(x, "b t d -> b d t")
        timesteps_embed = self.diffusion_step_encoder(timestep)
        if global_cond is not None:
            global_feature = torch.cat([timesteps_embed, global_cond], dim=-1)
        else:
            global_feature = timesteps_embed

        # Encoder: remember each stage's output (before downsampling).
        encoder_skip_features = []
        for resnet, resnet2, downsample in self.down_modules:
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            encoder_skip_features.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        # Decoder: consume skips from deepest to shallowest. The first
        # encoder stage's skip is left unused (one fewer decoder stage).
        for resnet, resnet2, upsample in self.up_modules:
            x = torch.cat((x, encoder_skip_features.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)
        x = einops.rearrange(x, "b d t -> b t d")
        return x
