"""
DDP-aware Calibration Buffer System for Function Encoder Pre-training

This module implements distributed calibration with:
1. Dataset calibration tasks distributed across GPUs
2. Coefficient synchronization across all GPUs
3. Optional buffer sharing across GPUs
"""

import logging
import random
from typing import Any, Dict, List, Tuple

import cvxpy as cp
import torch
import torch.distributed as dist
import torch.nn as nn
from cvxpylayers.torch import CvxpyLayer

from prismatic.models.action_heads import FunctionEncoderActionHead

logger = logging.getLogger(__name__)


class CalibrationBuffer:
    """
    Dataset-specific calibration buffer with reservoir sampling.

    Maintains a fixed-size buffer of raw TFDS samples from a specific dataset. Once the
    buffer is full, reservoir sampling (Vitter's Algorithm R) keeps it a uniformly random
    subset of *all* samples offered so far, so it stays representative of the dataset
    instead of drifting towards the most recently offered samples.
    """

    def __init__(self, buffer_size: int = 512):
        """
        Initialize calibration buffer.

        Args:
            buffer_size: Maximum number of samples to maintain in buffer. A size of 0 (or
                less) yields a buffer that never stores anything.
        """
        self.buffer_size = buffer_size

        # Buffer storage: List of raw TFDS samples
        self.buffer: List[Any] = []

        # Total number of samples ever passed to add_sample(). Reservoir sampling needs it
        # so that every offered sample has the same buffer_size / num_seen chance of being kept.
        self.num_seen = 0

    def add_sample(self, raw_sample: Any) -> None:
        """
        Offer a raw TFDS sample to the buffer (reservoir sampling, Algorithm R).

        While the buffer is not full the sample is simply appended. After that, the n-th
        offered sample is kept with probability ``buffer_size / n`` and, if kept, replaces
        a uniformly random existing entry.

        Args:
            raw_sample: Raw sample from TFDS (stored as-is, whatever format it comes in)
        """
        self.num_seen += 1

        if len(self.buffer) < self.buffer_size:
            # Buffer not full, just append
            self.buffer.append(raw_sample)
            return

        # Draw a slot in [0, num_seen); only slots that fall inside the buffer replace an entry.
        slot = random.randrange(self.num_seen)
        if slot < self.buffer_size:
            self.buffer[slot] = raw_sample

    def get_calibration_data(self) -> List[Any]:
        """
        Return all buffer data for FE coefficient computation.

        Returns:
            A shallow copy of the buffered raw TFDS samples (mutating the returned list
            does not affect the buffer).

        Raises:
            ValueError: If no sample has been stored yet.
        """
        if not self.buffer:
            raise ValueError("Buffer is empty. Call add_sample() first.")

        return self.buffer.copy()

    def __len__(self) -> int:
        """Return the number of samples currently stored."""
        return len(self.buffer)

    def __repr__(self) -> str:
        return f"CalibrationBuffer(buffer_size={self.buffer_size}, current_size={len(self.buffer)})"


class CalibrationManager:
    """
    DDP-aware multi-dataset calibration buffer manager.

    Manages calibration buffers for multiple datasets and coordinates periodic
    recalibration of Function Encoder coefficients in distributed training.

    Key features:
    - Each rank only stores buffers for the datasets assigned to it (``ds_id % world_size == rank``)
    - Samples of datasets owned by another rank are queued and routed to that rank
    - Each rank calibrates only its local buffers
    - Coefficients are all-gathered so that every rank ends up with every dataset's coefficients
    """

    def __init__(
        self,
        dataset_names: List[str],
        buffer_size: int = 512,
        max_routed_samples_per_rank: int = 8,
        calibration_mini_batch_size: int = 32,
    ):
        """
        Initialize DDP calibration manager.

        Args:
            dataset_names: Complete list of dataset names. A name's position is its dataset id,
                which determines the owning rank, so the list must have the same order on every rank.
            buffer_size: Size of each dataset's calibration buffer
            max_routed_samples_per_rank: Maximum number of queued samples sent to each other rank
                per ``route_pending_samples`` call; only the most recently queued ones are sent,
                older ones are dropped.
            calibration_mini_batch_size: Number of samples per VLA forward pass during calibration
                (lower it to reduce peak VRAM).

        Raises:
            ValueError: If ``calibration_mini_batch_size`` is smaller than 1.
        """
        if calibration_mini_batch_size < 1:
            raise ValueError(f"calibration_mini_batch_size must be >= 1, got {calibration_mini_batch_size}")

        self.buffer_size = buffer_size
        self.max_routed_samples_per_rank = max_routed_samples_per_rank
        self.calibration_mini_batch_size = calibration_mini_batch_size

        # DDP info
        self.is_distributed = dist.is_initialized()
        self.world_size = dist.get_world_size() if self.is_distributed else 1
        self.rank = dist.get_rank() if self.is_distributed else 0

        # Calibration options read by _calibrate_single_dataset:
        # - disable_dims: action-dim indices zeroed in both targets and basis functions before fitting
        # - eval_mode: when True, the fitted coefficients are not box-constrained to [-1, 1]
        self.disable_dims = None
        self.eval_mode = False

        # Store dataset list and create ID mapping
        self.dataset_names = dataset_names
        self.dataset_id_map = {name: i for i, name in enumerate(dataset_names)}

        # Datasets whose buffers live on this rank (ds_id % world_size == rank)
        self.my_datasets = {
            ds_name for ds_name, ds_id in self.dataset_id_map.items() if ds_id % self.world_size == self.rank
        }

        # Dataset name -> CalibrationBuffer, only for the datasets owned by this rank
        self.buffers: Dict[str, CalibrationBuffer] = {
            ds_name: CalibrationBuffer(buffer_size=buffer_size) for ds_name in self.my_datasets
        }

        # (raw_sample, dataset_name) pairs waiting to be routed to their owning rank
        self.pending_samples_for_other_nodes: List[Tuple[Any, str]] = []

    def _get_node_for_dataset(self, dataset_name: str) -> int:
        """
        Determine which rank owns a dataset's buffer.

        Args:
            dataset_name: Name of the dataset (must be one of ``dataset_names``)

        Returns:
            Rank that owns this dataset's buffer

        Raises:
            KeyError: If ``dataset_name`` is unknown.
        """
        return self.dataset_id_map[dataset_name] % self.world_size

    def add_training_sample(self, raw_sample: Any, dataset_name: str) -> None:
        """
        Add a new raw training sample to the appropriate buffer.

        Samples of datasets owned by this rank go straight into the local buffer; all
        others are queued until the next ``route_pending_samples`` call.

        Args:
            raw_sample: Raw TFDS sample
            dataset_name: Name of the source dataset
        """
        if self._get_node_for_dataset(dataset_name) == self.rank:
            self.buffers[dataset_name].add_sample(raw_sample)
        else:
            self.pending_samples_for_other_nodes.append((raw_sample, dataset_name))

    def _round_robin_partners(self) -> List[int]:
        """
        Return this rank's partner for each round of a round-robin pairing (circle method).

        Over ``n - 1`` rounds (``n`` = world size rounded up to an even number) every pair of
        ranks meets exactly once, and both members of a pair compute the same pairing. A
        partner of ``-1`` is the dummy player added for odd world sizes and means this rank
        sits that round out.
        """
        n = self.world_size + (self.world_size % 2)
        players = list(range(self.world_size)) + [-1] * (n - self.world_size)

        partners: List[int] = []
        for _ in range(n - 1):
            for j in range(n // 2):
                a, b = players[j], players[n - 1 - j]
                if a == self.rank:
                    partners.append(b)
                    break
                if b == self.rank:
                    partners.append(a)
                    break
            # Keep the first player fixed and rotate the others by one position
            players = [players[0], players[-1], *players[1:-1]]
        return partners

    def route_pending_samples(self) -> None:
        """
        Route queued samples to the ranks that own their datasets.

        Uses object-based point-to-point communication scheduled by ``_round_robin_partners``,
        so in every round each rank exchanges with at most one partner and both sides agree on
        who it is. In the distributed case this is a collective call: every rank must call it,
        even with nothing queued. At most ``max_routed_samples_per_rank`` of the most recently
        queued samples are sent per destination rank; the queue is cleared afterwards.
        """
        if not self.is_distributed:
            # With a single process every dataset is local, so nothing should be queued;
            # clear anyway so the queue can never grow without bound.
            self.pending_samples_for_other_nodes = []
            return

        # Group pending samples by destination rank, keeping only the most recent ones
        samples_by_rank: List[List[Tuple[Any, str]]] = [[] for _ in range(self.world_size)]
        for raw_sample, dataset_name in self.pending_samples_for_other_nodes:
            samples_by_rank[self._get_node_for_dataset(dataset_name)].append((raw_sample, dataset_name))
        cap = self.max_routed_samples_per_rank
        # Guard cap <= 0 explicitly: samples[-0:] would keep the whole list
        samples_by_rank = [samples[-cap:] if cap > 0 else [] for samples in samples_by_rank]

        for partner in self._round_robin_partners():
            if partner == -1:
                continue

            outbound = [samples_by_rank[partner]]
            inbound = [None]

            # Order within the pair: lower rank sends first, so the blocking calls always match up
            if self.rank < partner:
                dist.send_object_list(outbound, dst=partner)
                dist.recv_object_list(inbound, src=partner)
            else:
                dist.recv_object_list(inbound, src=partner)
                dist.send_object_list(outbound, dst=partner)

            for raw_sample, dataset_name in inbound[0] or []:
                buffer = self.buffers.get(dataset_name)
                if buffer is not None:
                    buffer.add_sample(raw_sample)

        # Clear pending samples after routing
        self.pending_samples_for_other_nodes = []

    def _extract_action_hidden_states(
        self,
        vla: nn.Module,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Run the VLA in eval mode (in mini-batches, without gradients) and collect the
        last-layer hidden states at the action-token positions.

        The per-submodule train/eval flags of ``vla`` are restored afterwards, even on error.

        Returns:
            Tensor reshaped to (batch, NUM_ACTIONS_CHUNK * ACTION_DIM, -1).
        """
        from prismatic.training.train_utils import get_current_action_mask, get_next_actions_mask
        from prismatic.vla.constants import ACTION_DIM, NUM_ACTIONS_CHUNK

        # The vision patches precede the text tokens in the hidden states; check on the unwrapped
        # module, since a DDP wrapper does not expose `vision_backbone` itself.
        vla_module = vla.module if hasattr(vla, "module") else vla
        if hasattr(vla_module, "vision_backbone"):
            backbone = vla_module.vision_backbone
            num_patches = backbone.get_num_patches() * backbone.get_num_images_in_input()
        else:
            num_patches = 0

        mini_batch_size = self.calibration_mini_batch_size
        previous_modes = [(module, module.training) for module in vla.modules()]
        vla.eval()
        chunks = []
        try:
            with torch.no_grad():
                for start in range(0, input_ids.shape[0], mini_batch_size):
                    end = start + mini_batch_size  # slicing clamps at the end of the batch
                    mini_input_ids = input_ids[start:end]
                    mini_labels = labels[start:end]

                    outputs = vla(
                        input_ids=mini_input_ids,
                        attention_mask=attention_mask[start:end],
                        pixel_values=pixel_values[start:end].to(torch.bfloat16),
                        labels=mini_labels,
                        output_hidden_states=True,
                        return_dict=True,
                    )
                    hidden_states = outputs.hidden_states[-1]

                    # Labels are shifted by one for the causal-LM objective
                    ground_truth_token_ids = mini_labels[:, 1:]
                    action_mask = get_current_action_mask(ground_truth_token_ids) | get_next_actions_mask(
                        ground_truth_token_ids
                    )

                    # Text positions: skip the vision patches and drop the last token
                    text_hidden_states = hidden_states[:, num_patches:-1]
                    chunks.append(
                        text_hidden_states[action_mask].reshape(
                            mini_input_ids.shape[0], NUM_ACTIONS_CHUNK * ACTION_DIM, -1
                        )
                    )
        finally:
            # Restore each submodule's own flag (vla.train() would also flip deliberately-frozen ones)
            for module, was_training in previous_modes:
                module.training = was_training

        return torch.cat(chunks, dim=0)

    def _solve_l1_coefficients(self, basis: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Fit basis coefficients by minimising the L1 error ``sum |A x - b|`` with a CvxpyLayer.

        The absolute value is expressed with an auxiliary residual variable ``r`` constrained by
        ``r >= A x - b`` and ``r >= -(A x - b)``. Unless ``self.eval_mode`` is set, the coefficients
        are also box-constrained to [-1, 1].

        Args:
            basis: (N, k, v) basis-function values for N timesteps and v action dims
            targets: (N, v) ground-truth actions

        Returns:
            Coefficient tensor of shape (k,); zeros if ``v == 0``.
        """
        n_samples, k, v = basis.shape
        if v == 0:
            # Nothing to fit; avoid building a zero-sized cvxpy problem
            return torch.zeros(k, device=basis.device, dtype=torch.float32)

        x = cp.Variable(k)  # coefficients
        r = cp.Variable(n_samples * v)  # residual bounds
        A_param = cp.Parameter((n_samples * v, k))
        b_param = cp.Parameter(n_samples * v)

        residual = A_param @ x - b_param
        constraints = [r >= residual, r >= -residual]
        if not self.eval_mode:
            constraints += [x <= 1, x >= -1]

        problem = cp.Problem(cp.Minimize(cp.sum(r)), constraints)
        layer = CvxpyLayer(problem, parameters=[A_param, b_param], variables=[x])

        # Row (i * v + d) of A holds the k basis values for timestep i, action dim d,
        # matching element (i * v + d) of the flattened targets.
        A = basis.permute(0, 2, 1).reshape(n_samples * v, k)
        b = targets.reshape(n_samples * v)
        (coefficients,) = layer(A.to(torch.float32), b.to(torch.float32))
        return coefficients

    def _calibrate_single_dataset(
        self, dataset_name: str, fe_action_head: "FunctionEncoderActionHead", vla: nn.Module, collator
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calibrate FE coefficients for a single dataset.

        Runs the dataset's buffered samples through ``vla`` to get the hidden states of the
        action tokens and evaluates the FE basis functions on them. It then fits coefficients
        to the ground-truth actions with L1 regression: one fit over the first
        ``fe_action_head.n_continuous`` action dims and, if any dims remain, a separate fit
        over those.

        Args:
            dataset_name: Name of the dataset to calibrate
            fe_action_head: Function Encoder action head to calibrate
            vla: Full VLA model to process raw samples
            collator: Data collator to batch raw samples

        Returns:
            Tuple of (l1_coefficients, l2_coefficients), each of shape (k,). ``l2_coefficients``
            cover the remaining (discrete) dims. They are also fitted with an L1 loss, and are
            zeros when no such dims exist.

        Raises:
            ValueError: If this rank has no buffer for ``dataset_name``, the buffer is empty,
                or the collated batch lacks a required field.
        """
        from prismatic.vla.constants import ACTION_DIM, NUM_ACTIONS_CHUNK

        buffer = self.buffers.get(dataset_name)
        if buffer is None:
            raise ValueError(f"No buffer for dataset {dataset_name}")
        if len(buffer) == 0:
            raise ValueError(f"Buffer for dataset {dataset_name} is empty")

        # Batch the raw samples and move every tensor field to the model's device
        batched_samples = collator(buffer.get_calibration_data())
        device = next(vla.parameters()).device
        for key in batched_samples:
            if isinstance(batched_samples[key], torch.Tensor):
                batched_samples[key] = batched_samples[key].to(device)

        # All of these are used unconditionally below
        required_fields = ("pixel_values", "input_ids", "attention_mask", "labels", "actions")
        missing_fields = [name for name in required_fields if batched_samples.get(name) is None]
        if missing_fields:
            raise ValueError(f"Missing required fields in batched samples: {missing_fields}")

        input_ids = batched_samples["input_ids"]
        actions = batched_samples["actions"]
        batch_size = input_ids.shape[0]

        actions_hidden_states = self._extract_action_hidden_states(
            vla,
            input_ids=input_ids,
            attention_mask=batched_samples["attention_mask"],
            pixel_values=batched_samples["pixel_values"],
            labels=batched_samples["labels"],
        )

        with torch.no_grad():
            basis_functions, _ = fe_action_head.forward_basis_functions(actions_hidden_states)
            gt_actions = actions.reshape(batch_size, NUM_ACTIONS_CHUNK, ACTION_DIM)

            if self.disable_dims is not None:
                for dim in self.disable_dims:
                    gt_actions[..., dim] = 0
                    basis_functions[..., dim] = 0

            k = fe_action_head.k
            n_continuous = fe_action_head.n_continuous
            basis_functions_flat = basis_functions.reshape(-1, k, ACTION_DIM)  # (N, k, ACTION_DIM)
            gt_actions_flat = gt_actions.reshape(-1, ACTION_DIM)  # (N, ACTION_DIM)

            l1_coefficients = self._solve_l1_coefficients(
                basis_functions_flat[..., :n_continuous], gt_actions_flat[..., :n_continuous]
            )
            if n_continuous < ACTION_DIM:
                l2_coefficients = self._solve_l1_coefficients(
                    basis_functions_flat[..., n_continuous:], gt_actions_flat[..., n_continuous:]
                )
            else:
                l2_coefficients = torch.zeros(k, device=device, dtype=torch.float32)

        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "[Rank %d] Calibrated %s - L1 norm: %.4f, L2 norm: %.4f",
                self.rank,
                dataset_name,
                l1_coefficients.norm().item(),
                l2_coefficients.norm().item(),
            )

        return l1_coefficients, l2_coefficients

    def calibrate_all_datasets(
        self, fe_action_head: "FunctionEncoderActionHead", vla: nn.Module, collator
    ) -> None:
        """
        Calibrate the datasets owned by this rank and share the coefficients with all ranks.

        Each rank first routes its pending samples. It then calibrates its own non-empty
        buffers (empty ones are skipped with a warning, keeping their previous coefficients).
        Finally, all coefficients are all-gathered and installed on every rank via
        ``set_dataset_coefficients``. In the distributed case every rank must call this.

        Args:
            fe_action_head: Function Encoder action head to calibrate (may be DDP-wrapped)
            vla: Full VLA model to process raw samples (may be DDP-wrapped)
            collator: Data collator to batch raw samples
        """
        # First, route any pending samples
        self.route_pending_samples()

        # Unwrap DDP module if needed
        fe_head_module = fe_action_head.module if hasattr(fe_action_head, "module") else fe_action_head
        vla_module = vla.module if hasattr(vla, "module") else vla

        local_coefficients: Dict[str, Dict[str, torch.Tensor]] = {}
        # Sorted so the calibration order does not depend on set iteration order
        for dataset_name in sorted(self.my_datasets):
            if len(self.buffers[dataset_name]) == 0:
                # Raising here would leave the other ranks blocked in all_gather_object
                logger.warning("[Rank %d] Skipping calibration of %s: buffer is empty", self.rank, dataset_name)
                continue
            l1_coef, l2_coef = self._calibrate_single_dataset(dataset_name, fe_head_module, vla_module, collator)
            # Keep on CPU: all_gather_object pickles its payload, and PyTorch advises
            # against sending GPU tensors through it.
            local_coefficients[dataset_name] = {
                "l1": l1_coef.to(torch.bfloat16).cpu(),
                "l2": l2_coef.to(torch.bfloat16).cpu(),
            }

        if self.is_distributed:
            all_coefficients_list = [None] * self.world_size
            dist.all_gather_object(all_coefficients_list, local_coefficients)
            merged_coefficients: Dict[str, Dict[str, torch.Tensor]] = {}
            for rank_coefficients in all_coefficients_list:
                if rank_coefficients:
                    merged_coefficients.update(rank_coefficients)
        else:
            merged_coefficients = local_coefficients

        if merged_coefficients:
            target_device = next(fe_head_module.parameters()).device
            for dataset_name, coeffs in merged_coefficients.items():
                fe_head_module.set_dataset_coefficients(
                    dataset_name,
                    coeffs["l1"].to(target_device),
                    coeffs["l2"].to(target_device),
                )

        # Ensure all processes are synchronized
        if self.is_distributed:
            dist.barrier()

    def __repr__(self) -> str:
        return (
            f"CalibrationManager("
            f"my_buffers={len(self.buffers)}/{len(self.dataset_names)}, "
            f"rank={self.rank}/{self.world_size})"
        )
