import logging
from collections import deque

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

from flare.factory import registry
from flare.flow.flow_matchers import get_flow_matcher
from flare.policies import BasePolicy
from flare.policies.observers.resnet_observer import ResNetObserver
from flare.policies.vita.transformer_flow_net import TokenTransformerFlowNet
from flare.networks.vita.action_ae import get_autoencoder
from flare.utils.normalize import Normalize, Unnormalize
from flare.visualizer.visualizer import plot_trajectory, plot_ode_steps


logger = logging.getLogger(__name__)


def compute_contrastive_loss(image_features, action_features, temperature=0.07):
    """Symmetric InfoNCE-style loss between paired image and action embeddings.

    Row ``i`` of ``image_features`` is the positive match for row ``i`` of
    ``action_features``; every other row in the batch serves as a negative.

    Args:
        image_features: 2-D tensor ``(B, D)`` of observation embeddings
            (callers in this file flatten tokens with ``reshape(B, -1)``).
        action_features: 2-D tensor ``(B, D)`` of action embeddings.
        temperature: Divisor applied to the cosine-similarity logits.

    Returns:
        Scalar tensor: the mean of the image->action and action->image
        cross-entropy losses.
    """
    img = F.normalize(image_features, dim=1)
    act = F.normalize(action_features, dim=1)

    # Cosine similarities scaled by temperature; diagonal entries are the positives.
    similarity = torch.matmul(img, act.T) / temperature
    targets = torch.arange(image_features.size(0), device=similarity.device)

    image_to_action = F.cross_entropy(similarity, targets)
    action_to_image = F.cross_entropy(similarity.T, targets)
    return (image_to_action + action_to_image) / 2


@registry.register_policy("vita_token")
class VitaTokenPolicy(BasePolicy):
    """Flow-matching policy that operates on token latents.

    Observations are encoded by a tokenizing ``ResNetObserver`` and linearly
    projected to ``token_dim``. Ground-truth action chunks are encoded into token
    latents by an action autoencoder. A transformer flow network is trained with
    the configured flow matcher to transport observation tokens (``start``) to
    action tokens (``target``). At inference, the sampled latents are decoded
    back into actions by the autoencoder's decoder.
    """

    def __init__(self, config, stats):
        """Build observer, projections, flow matcher, action autoencoder and flow net.

        Args:
            config: Experiment config; reads ``config.task.*``, ``config.policy.*``
                and the top-level resize/crop/optimizer/scheduler fields.
            stats: Dataset statistics used by the input/target normalizers.
        """
        super().__init__(config, stats)

        self.config = config
        self.stats = stats

        self.num_sampling_steps = config.policy.flow_matcher.num_sampling_steps
        self.action_horizon = config.policy.action_horizon
        self.action_dim = config.task.action_dim
        self.obs_horizon = config.policy.obs_horizon

        self.normalize_inputs = Normalize(config.task.image_keys + [config.task.state_key], stats)
        self.normalize_targets = Normalize([config.task.action_key], stats)
        self.unnormalize_outputs = Unnormalize([config.task.action_key], stats)
        self._action_queue = None

        # --- Observer: tokenized ResNet features ---
        self.observer = ResNetObserver(
            state_key=config.task.state_key,
            image_keys=config.task.image_keys,
            resize_shape=config.resize_shape,
            crop_shape=config.crop_shape,
            state_dim=config.task.state_dim,
            tokenize=True,
        )

        # Token geometry: ResNetObserver with 3x3 pooling -> 9 tokens per image per step
        num_image_tokens_per_step = 9 * len(config.task.image_keys)
        self.num_tokens = 1 + self.obs_horizon * num_image_tokens_per_step  # +1 for state token

        # Channel dimension for token latents
        self.token_dim = config.policy.vita.token_dim

        # Project raw ResNet tokens (512) to token_dim.
        self.obs_token_proj = nn.Sequential(
            nn.Linear(512, self.token_dim),
        )

        # Flow matcher
        self.FM = get_flow_matcher(**config.policy.flow_matcher)

        # --- Action autoencoder on token latents ---
        action_ae_net_config = config.policy.action_ae.net
        self.flow_action_recon_weight = config.policy.action_ae.flow_recon_weight
        self.enc_action_recon_weight = config.policy.action_ae.enc_recon_weight

        recon_loss_type = config.policy.action_ae.recon_loss_type
        self.action_kl_weight = config.policy.action_ae.kl_weight
        if recon_loss_type == "l1":
            self.recon_loss_fn = F.l1_loss
        elif recon_loss_type == "l2":
            self.recon_loss_fn = F.mse_loss
        else:
            raise ValueError(f"Unsupported recon_loss_type: {recon_loss_type}. Use 'l1' or 'l2'.")

        # Contrastive / consistency weights
        self.enc_contrastive_weight = config.policy.vita.enc_contrastive_weight
        self.flow_contrastive_weight = config.policy.vita.flow_contrastive_weight
        self.decode_flow_latents = config.policy.vita.decode_flow_latents
        self.consistency_weight = config.policy.vita.consistency_weight

        # Initialize action encoder/decoder
        self._init_action_ae(action_ae_net_config)

        # --- Flow network over token latents ---
        self.flow_net = TokenTransformerFlowNet(
            token_dim=self.token_dim,
            hidden_dim=config.policy.flow_net.hidden_dim,
            num_layers=config.policy.flow_net.num_layers,
            num_heads=config.policy.flow_net.num_heads,
            mlp_ratio=config.policy.flow_net.mlp_ratio,
            dropout=config.policy.flow_net.dropout,
            max_seq_len=self.num_tokens,
        )

        # Log the trainable parameter count excluding the action encoder, then a breakdown.
        total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        encoder_params = sum(p.numel() for p in self.action_encoder.parameters() if p.requires_grad)
        decoder_params = sum(p.numel() for p in self.action_decoder.parameters() if p.requires_grad)
        flow_net_params = sum(p.numel() for p in self.flow_net.parameters() if p.requires_grad)
        logger.info("VitaTokenPolicy: %s params.", total_params - encoder_params)
        logger.info(
            " - Action encoder: %s | decoder: %s | flow net: %s params.",
            encoder_params,
            decoder_params,
            flow_net_params,
        )

        self.reset()

    def _init_action_ae(self, action_ae_net_config):
        """Create ``self.action_encoder`` / ``self.action_decoder`` via ``get_autoencoder``.

        ``action_ae_net_config`` is only used for logging the encoder/decoder types;
        construction itself reads ``self.config`` and ``self.config.policy``.
        """
        logger.info(
            "Using token CNN encoder (%s) and decoder (%s) for action autoencoder.",
            action_ae_net_config.encoder_type,
            action_ae_net_config.decoder_type,
        )
        self.action_encoder, self.action_decoder = get_autoencoder(self.config, self.config.policy)

    def compute_loss(self, batch: dict[str, torch.Tensor]):
        """Compute the total training loss and a metrics dict.

        The loss is the flow-matching loss plus any enabled auxiliary terms:
        encoder contrastive, and (when ``decode_flow_latents``) consistency,
        flow contrastive and flow-decoded reconstruction, plus encoder
        reconstruction.

        Args:
            batch: Training batch; must contain the observation keys and
                ``config.task.action_key``.

        Returns:
            ``(loss, metrics)``. ``metrics`` is the flow matcher's dict, extended
            with the ``.item()`` value of each loss term that was computed.
        """
        # --- Observation tokens ---
        obs_tokens_raw = self.observer(batch)  # (B, num_tokens, 512)
        obs_tokens = self.obs_token_proj(obs_tokens_raw)  # (B, num_tokens, token_dim)

        batch_size = obs_tokens.shape[0]
        gt_actions = batch[self.config.task.action_key]

        action_tokens = self.action_encoder(gt_actions)  # (B, num_tokens, token_dim)

        # Flow matching in token space: obs tokens are the source, action tokens the target.
        flow_loss, metrics = self.FM.compute_loss(
            self.flow_net,
            target=action_tokens,
            start=obs_tokens,
        )
        metrics["flow_loss"] = flow_loss.item()
        # Terms are accumulated out-of-place (``loss = loss + ...``). An in-place
        # ``+=`` would alias and mutate ``flow_loss``, which can trip autograd's
        # in-place checks if that tensor was saved for backward.
        loss = flow_loss

        # Contrastive loss between obs and action tokens
        if self.enc_contrastive_weight > 0:
            contrastive_loss = compute_contrastive_loss(
                obs_tokens.reshape(batch_size, -1),
                action_tokens.reshape(batch_size, -1),
            )
            loss = loss + self.enc_contrastive_weight * contrastive_loss
            metrics["enc_contrastive_loss"] = contrastive_loss.item()

        # --- Decode flow latents for reconstruction / consistency ---
        if self.decode_flow_latents:
            action_tokens_pred = self.FM.sample(
                self.flow_net,
                shape=(batch_size, self.num_tokens, self.token_dim),
                device=obs_tokens.device,
                start=obs_tokens,
                num_steps=self.num_sampling_steps,
            )

            if self.consistency_weight > 0:
                consistency_loss = F.mse_loss(action_tokens_pred, action_tokens)
                loss = loss + self.consistency_weight * consistency_loss
                metrics["consistency_loss"] = consistency_loss.item()

            if self.flow_contrastive_weight > 0:
                # ``reshape`` (not ``view``) so non-contiguous sampler outputs are accepted.
                contrastive_loss = compute_contrastive_loss(
                    obs_tokens.reshape(batch_size, -1),
                    action_tokens_pred.reshape(batch_size, -1),
                )
                loss = loss + self.flow_contrastive_weight * contrastive_loss
                metrics["flow_contrastive_loss"] = contrastive_loss.item()

            if self.flow_action_recon_weight > 0:
                actions_recon = self.action_decoder(action_tokens_pred)
                action_recon_loss = self.recon_loss_fn(actions_recon, gt_actions)
                metrics["flow_action_recon_loss"] = action_recon_loss.item()
                loss = loss + self.flow_action_recon_weight * action_recon_loss

        # Encoder reconstruction loss
        if self.enc_action_recon_weight > 0:
            actions_recon = self.action_decoder(action_tokens)
            action_recon_loss = self.recon_loss_fn(actions_recon, gt_actions)
            metrics["enc_action_recon_loss"] = action_recon_loss.item()
            loss = loss + self.enc_action_recon_weight * action_recon_loss

        return loss, metrics

    def generate_actions(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """Sample action latents from observations and decode them into actions.

        Both callers in this class pass an already-normalized ``batch``. The
        returned actions are in normalized space; callers unnormalize them.

        Args:
            batch: Normalized observation batch containing ``config.task.state_key``
                and the image keys.

        Returns:
            Decoded actions from ``self.action_decoder``. Callers slice the second
            axis as time, so this is presumably ``(B, T, action_dim)``; confirm
            against the decoder.
        """
        batch_size = batch[self.config.task.state_key].shape[0]

        obs_tokens_raw = self.observer(batch)
        obs_tokens = self.obs_token_proj(obs_tokens_raw)

        action_tokens_pred = self.FM.sample(
            self.flow_net,
            shape=(batch_size, self.num_tokens, self.token_dim),
            device=obs_tokens.device,
            num_steps=self.num_sampling_steps,
            start=obs_tokens,
            return_traces=False,
        )

        actions_pred = self.action_decoder(action_tokens_pred)

        return actions_pred

    @torch.no_grad
    def select_action(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the next action for closed-loop rollout, replanning when the queue is empty.

        Only the image and state keys of ``batch`` are used. Each gets a new
        time axis at dim 1 before normalization. When the action queue is
        empty (or not yet initialized), a fresh chunk of ``action_horizon``
        actions is generated, unnormalized and queued step by step.

        Args:
            batch: Single-step observations without a time dimension.

        Returns:
            ``actions[:, t]`` for the next queued step ``t`` (one action per batch element).
        """
        self.eval()
        obs_keys = set(self.config.task.image_keys) | {self.config.task.state_key}
        batch = {k: v.unsqueeze(1) for k, v in batch.items() if k in obs_keys}
        batch = self.normalize_inputs(batch)
        # ``not`` covers both an empty deque and the initial ``None`` set in ``__init__``.
        if not self._action_queue:
            action_key = self.config.task.action_key
            actions = self.generate_actions(batch)
            actions = actions[:, :self.action_horizon]
            # Unnormalize under the same key the Unnormalize module was built with.
            actions = self.unnormalize_outputs({action_key: actions})[action_key]
            self._action_queue = deque(actions.transpose(0, 1), maxlen=self.action_horizon)
        return self._action_queue.popleft()

    def get_optimizer(self) -> torch.optim.Optimizer:
        """Return an AdamW optimizer over all parameters, configured from ``self.config``."""
        return torch.optim.AdamW(
            params=self.parameters(),
            lr=self.config.optimizer_lr,
            betas=self.config.optimizer_betas,
            eps=self.config.optimizer_eps,
            weight_decay=self.config.optimizer_weight_decay,
        )

    def get_scheduler(
        self,
        optimizer: torch.optim.Optimizer,
        num_training_steps: int,
    ) -> torch.optim.lr_scheduler.LambdaLR | None:
        """Return a ``diffusers`` LR scheduler named by ``config.scheduler_name``."""
        from diffusers.optimization import get_scheduler

        return get_scheduler(
            name=self.config.scheduler_name,
            optimizer=optimizer,
            num_warmup_steps=self.config.scheduler_warmup_steps,
            num_training_steps=num_training_steps,
        )

    @torch.no_grad
    def visualize(self, batch: dict[str, torch.Tensor], num_samples: int = 1) -> dict[str, plt.Figure]:
        """Plot predicted vs. ground-truth trajectories and decoded ODE intermediates.

        The caller's ``batch`` is not modified. The module's train/eval mode is
        restored on exit.

        Args:
            batch: Batch with observation keys and ``config.task.action_key``.
            num_samples: Maximum number of batch items to visualize. It is
                clamped to the batch size.

        Returns:
            Dict with ``cmp_{i}`` (first two action dims, pred vs. gt) and
            ``denoise_{i}`` (decoded latents along the sampling trajectory) figures.
        """
        was_training = self.training
        self.eval()
        try:
            action_key = self.config.task.action_key
            # Slice into a new dict so the caller's batch is left untouched.
            batch = {key: value[:num_samples] for key, value in batch.items()}
            batch = self.normalize_inputs(batch)
            gt = batch[action_key]
            device = gt.device
            # Clamp so a batch smaller than ``num_samples`` does not cause index/shape errors.
            num_samples = gt.shape[0]

            pred_norm = self.generate_actions(batch)
            pred = self.unnormalize_outputs({action_key: pred_norm})[action_key]

            obs_tokens_raw = self.observer(batch)
            obs_tokens = self.obs_token_proj(obs_tokens_raw)

            # Separate sampling pass that also returns the intermediate latents for plotting.
            _, (latents_hist, _) = self.FM.sample(
                self.flow_net,
                shape=(num_samples, self.num_tokens, self.token_dim),
                device=device,
                num_steps=self.num_sampling_steps,
                start=obs_tokens,
                return_traces=True,
            )

            viz: dict[str, plt.Figure] = {}
            for i in range(num_samples):
                fig1, ax1 = plt.subplots()
                traj_pred = pred[i, :, :2].cpu().numpy()
                traj_gt = gt[i, :, :2].cpu().numpy()
                plot_trajectory(ax=ax1, pred=traj_pred, target=traj_gt)
                viz[f"cmp_{i}"] = fig1

                traj_actions = []
                for lh in latents_hist:
                    lat_i = lh[i].to(device)
                    act_traj = self.action_decoder(lat_i.unsqueeze(0)).squeeze(0).cpu().numpy()
                    traj_actions.append(act_traj)

                viz[f"denoise_{i}"] = plot_ode_steps(traj_actions)

            return viz
        finally:
            # Do not leave the policy in eval mode if it was called mid-training.
            self.train(was_training)
