import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from diffusers.optimization import get_scheduler
from huggingface_hub import PyTorchModelHubMixin

from flare.factory import registry
from flare.policies import BasePolicy
from flare.models.unet.unet import ConditionalUnet1D
from flare.models.flow_transformer.flow import FlowTransformer
from flare.flow.flow_matchers import get_flow_matcher
from flare.policies.observers.resnet_observer import ResNetObserver
from flare.policies.observers.dino_observer import Dino2DObserver
from flare.utils.normalize import Normalize, Unnormalize
from flare.visualizer.visualizer import plot_trajectory, plot_ode_steps


@registry.register_policy("flow")
class FlowPolicy(BasePolicy):
    """Policy that samples action sequences from a conditional flow matcher.

    An observer module turns the observation batch into a conditioning
    feature, and a velocity network (``flow_net``) is driven by the flow
    matcher (``FM``) either to compute a training loss or to sample an
    action trajectory of shape ``(batch, pred_horizon, action_dim)``.
    """

    # Per-image feature width used to size the conditioning vector when the
    # 'resnet18' observer is selected.
    # NOTE(review): presumably the pooled ResNet-18 output width -- confirm
    # against ResNetObserver.
    _RESNET18_FEATURE_DIM = 512

    def __init__(self, config, stats):
        """Build normalizers, the observer, the flow matcher and the flow net.

        Args:
            config: Experiment configuration. Reads ``config.policy.*``,
                ``config.task.*``, ``config.resize_shape`` and
                ``config.crop_shape``.
            stats: Dataset statistics handed to ``Normalize``/``Unnormalize``.

        Raises:
            ValueError: If the configured observer or flow network name is
                not supported.
        """
        super().__init__(config, stats)

        self.config = config
        self.stats = stats
        self.num_sampling_steps = config.policy.flow_matcher.num_sampling_steps
        self.pred_horizon = config.policy.pred_horizon
        self.action_horizon = config.policy.action_horizon
        self.action_dim = config.task.action_dim

        self.normalize_inputs = Normalize(config.task.image_keys + [config.task.state_key], stats)
        self.normalize_targets = Normalize([config.task.action_key], stats)
        self.unnormalize_outputs = Unnormalize([config.task.action_key], stats)
        # Placeholder only; select_action() needs a deque-like object here.
        # NOTE(review): presumably populated by self.reset() below (inherited
        # from BasePolicy) -- confirm.
        self._action_queue = None
        obs_config = config.policy.observer
        if obs_config.name == 'resnet18':
            self.observer = ResNetObserver(
                state_key=config.task.state_key,
                image_keys=config.task.image_keys,
                resize_shape=config.resize_shape,
                crop_shape=config.crop_shape,
                state_dim=config.task.state_dim,
                tokenize=obs_config.tokenize,
            )
            # One feature vector per camera plus the raw state vector.
            self.obs_dim = (
                len(config.task.image_keys) * self._RESNET18_FEATURE_DIM
                + config.task.state_dim
            )
        elif obs_config.name == 'dinov2':
            q_former = obs_config.q_former
            self.observer = Dino2DObserver(
                state_key=config.task.state_key,
                image_keys=config.task.image_keys,
                resize_shape=config.resize_shape,
                crop_shape=config.crop_shape,
                freeze_backbone=obs_config.freeze_backbone,
                state_dim=config.task.state_dim,
                q_former_dim_model=q_former.dim_model,
                q_former_pool_n_tokens=q_former.pool_n_tokens,
                q_former_pool_n_layers=q_former.pool_n_layers,
                q_former_n_heads=q_former.n_heads,
                q_former_dropout=q_former.dropout,
                q_former_mlp_ratio=q_former.mlp_ratio,
            )
            self.obs_dim = q_former.dim_model
        else:
            raise ValueError(f"Observer {obs_config.name} is not supported.")

        self.FM = get_flow_matcher(**config.policy.flow_matcher)
        self.flow_net = self._init_flow_net(condition_dim=self.obs_dim)

        self.reset()

    def _init_flow_net(self, condition_dim):
        """Initialize the velocity prediction network.

        Args:
            condition_dim: Width of the conditioning feature fed to the net.

        Returns:
            A ``ConditionalUnet1D`` or ``FlowTransformer`` depending on
            ``config.policy.flow_net.name``.

        Raises:
            ValueError: If the configured flow network name is not supported.
        """
        net_config = self.config.policy.flow_net
        if net_config.name == 'unet':
            return ConditionalUnet1D(
                input_dim=self.config.task.action_dim,
                global_cond_dim=condition_dim,
            )
        if net_config.name == 'flow_transformer':
            return FlowTransformer(
                input_dim=self.config.task.action_dim,
                condition_dim=condition_dim,
                output_dim=self.config.task.action_dim,
                hidden_dim=net_config.hidden_dim,
                num_layers=net_config.num_layers,
                num_heads=net_config.num_heads,
                block_type=net_config.block_type,
            )
        # Fail at construction time instead of returning None and crashing
        # later on the first forward pass.
        raise ValueError(f"Flow net {net_config.name} is not supported.")

    def compute_loss(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict]:
        """Compute the flow-matching loss for one batch.

        The action target is read from ``batch`` as-is; this method applies
        no normalization itself.
        NOTE(review): presumably the caller (BasePolicy.forward) normalizes
        the batch beforehand -- confirm.

        Returns:
            ``(loss, metrics)`` where ``metrics['flow_loss']`` holds the loss
            as a Python float.
        """
        features = self.observer(batch)
        target = batch[self.config.task.action_key]
        loss, metrics = self.FM.compute_loss(self.flow_net, target=target, cond=features)
        metrics['flow_loss'] = loss.item()
        return loss, metrics

    def generate_actions(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """Sample an action trajectory conditioned on the observation batch.

        The batch size is taken from the leading dimension of the state
        entry. The result of ``FM.sample`` is returned unchanged; no
        unnormalization is applied here.
        """
        batch_size = batch[self.config.task.state_key].shape[0]
        features = self.observer(batch)
        actions = self.FM.sample(
            self.flow_net,
            (batch_size, self.pred_horizon, self.action_dim),
            features.device,
            self.num_sampling_steps,
            cond=features,
            return_traces=False
        )

        return actions

    @torch.no_grad()
    def select_action(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the next action, sampling a new chunk when the queue is empty.

        Puts the module in eval mode. When the action queue is empty, the
        image/state entries of ``batch`` get a new dimension inserted at
        index 1, are normalized, and a trajectory is sampled; its first
        ``action_horizon`` steps are unnormalized and queued one time step
        per queue entry. Otherwise ``batch`` is not touched.

        Returns:
            The oldest queued action (one time step across the batch).
        """
        self.eval()
        if len(self._action_queue) == 0:
            # Preprocess only when a new chunk is actually needed.
            obs_keys = self.config.task.image_keys + [self.config.task.state_key]
            obs = {k: v.unsqueeze(1) for k, v in batch.items() if k in obs_keys}
            obs = self.normalize_inputs(obs)
            actions = self.generate_actions(obs)
            actions = actions[:, :self.action_horizon]
            # Use the configured key: Unnormalize was built with it.
            action_key = self.config.task.action_key
            actions = self.unnormalize_outputs({action_key: actions})[action_key]
            # (batch, horizon, dim) -> (horizon, batch, dim) so that iterating
            # yields one entry per time step.
            self._action_queue.extend(actions.transpose(0, 1))
        return self._action_queue.popleft()

    def forward(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict]:
        """Run the parent forward pass and record the loss under ``flow_loss``.

        Returns:
            ``(loss, metrics)`` as produced by ``BasePolicy.forward``, with
            ``metrics['flow_loss']`` set to the loss as a Python float.
        """
        loss, metrics = super().forward(batch)
        metrics['flow_loss'] = loss.item()
        return loss, metrics

    def get_optimizer(self) -> torch.optim.Optimizer:
        """Create an AdamW optimizer over all parameters of this policy.

        Hyperparameters come from the ``optimizer_*`` fields of the config.
        """
        return torch.optim.AdamW(
            params=self.parameters(),
            lr=self.config.optimizer_lr,
            betas=self.config.optimizer_betas,
            eps=self.config.optimizer_eps,
            weight_decay=self.config.optimizer_weight_decay
        )

    def get_scheduler(self, optimizer: torch.optim.Optimizer, num_training_steps: int) -> torch.optim.lr_scheduler.LambdaLR | None:
        """Create the learning-rate scheduler named in the config.

        Delegates to the module-level ``get_scheduler`` imported from
        ``diffusers.optimization`` (not to this method).
        """
        return get_scheduler(
            name=self.config.scheduler_name,
            optimizer=optimizer,
            num_warmup_steps=self.config.scheduler_warmup_steps,
            num_training_steps=num_training_steps,
        )

    def visualize(self, batch: dict[str, torch.Tensor], num_samples: int = 1) -> dict[str, plt.Figure]:
        """Plot predictions against ground truth and the sampling trace.

        For each visualized sample two figures are produced: ``cmp_{i}``
        compares the first two action dimensions of the prediction with the
        ground truth, and ``denoise_{i}`` shows the intermediate states
        returned by the sampler. The trace comes from a sampling run that is
        separate from the one that produced the compared prediction.

        The module is put in eval mode and left there. The caller's
        ``batch`` dict is not modified. The returned figures are not closed
        here.

        Args:
            batch: Observation/action batch; must contain the action key.
            num_samples: Maximum number of leading batch rows to visualize.

        Returns:
            Mapping from figure name to matplotlib figure.
        """
        self.eval()
        action_key = self.config.task.action_key
        # Slice into a fresh dict so the caller's batch is left untouched.
        batch = {key: value[:num_samples] for key, value in batch.items()}
        batch = self.normalize_inputs(batch)
        gt = batch[action_key]
        device = gt.device
        # The batch may hold fewer rows than requested; never go past them.
        num_samples = gt.shape[0]

        # Everything below is inference only: avoid building an autograd
        # graph and keep the tensors convertible to numpy.
        with torch.no_grad():
            pred_norm = self.generate_actions(batch)
            pred = self.unnormalize_outputs({action_key: pred_norm})[action_key]

            features = self.observer(batch)
            _, (traj_hist, _) = self.FM.sample(
                self.flow_net,
                (num_samples, self.pred_horizon, self.action_dim),
                device,
                self.num_sampling_steps,
                cond=features,
                return_traces=True
            )

        viz: dict[str, plt.Figure] = {}
        for i in range(num_samples):
            # --- Fig 1: GT vs Pred ---
            fig1, ax1 = plt.subplots()
            traj_pred = pred[i, :, :2].cpu().numpy()
            traj_gt = gt[i, :, :2].cpu().numpy()
            plot_trajectory(ax=ax1, pred=traj_pred, target=traj_gt)
            viz[f"cmp_{i}"] = fig1

            # --- Fig 2: ODE denoising steps ---
            traj_steps = [step[i].cpu().numpy() for step in traj_hist]
            fig2 = plot_ode_steps(traj_steps)
            viz[f"denoise_{i}"] = fig2

        return viz
