"""
Optimization algorithms for MoE knowledge editing
"""

import torch
from typing import List, Dict, Any
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from .logger import get_logger, info, debug, warning


class BlockCoordinateDescent:
    """Block Coordinate Descent optimizer for MoE down_proj weights with historical data support"""

    def __init__(self, num_experts: int, d_model: int, d_k: int, device: str, lambda_reg: float = 1e-2):
        """Configure the optimizer and start with an empty edit history.

        Args:
            num_experts: Number of experts whose ``down_proj`` updates are solved for.
            d_model: Row dimension of each per-expert update matrix.
            d_k: Column dimension of each per-expert update matrix.
            device: Device on which optimization tensors are allocated.
            lambda_reg: Ridge (L2) regularization strength added in every block solve.
        """
        # Problem dimensions and solver settings.
        self.num_experts, self.d_model, self.d_k = num_experts, d_model, d_k
        self.device = device
        self.lambda_reg = lambda_reg

        # One list per field; each list entry is the snapshot of one earlier edit
        # batch (appended by add_historical_data, consumed by get_combined_data).
        history_fields = ('gating', 'keys', 'values', 'residuals',
                          'is_new', 'entity_mapping', 'prefix_mapping')
        self.historical_data = {field: [] for field in history_fields}

        self.logger = get_logger()

        info(f"BCD Optimizer initialized with GPU acceleration on {device}")
        info(f"Experts: {num_experts}, d_model: {d_model}, d_k: {d_k}, λ: {lambda_reg}")
        info("Historical data support: ENABLED")

    def add_historical_data(self, stats: Dict):
        """Add current editing data to historical storage

        Args:
            stats: Current editing statistics that will become historical data
        """
        # Store current data as historical data for future edits
        self.historical_data['gating'].append(stats['gating'].clone())

        # Handle nested structure: stats['keys'] and stats['values'] are List[List[Optional[Tensor]]]
        # Clone keys
        cloned_keys = []
        for sample_keys in stats['keys']:  # Each sample's expert keys
            cloned_sample_keys = []
            for expert_key in sample_keys:  # Each expert's key (tensor or None)
                if expert_key is not None and hasattr(expert_key, 'clone'):
                    cloned_sample_keys.append(expert_key.clone())
                else:
                    cloned_sample_keys.append(expert_key)  # None or non-tensor
            cloned_keys.append(cloned_sample_keys)
        self.historical_data['keys'].append(cloned_keys)

        # Clone values
        cloned_values = []
        for sample_values in stats['values']:  # Each sample's expert values
            cloned_sample_values = []
            for expert_value in sample_values:  # Each expert's value (tensor or None)
                if expert_value is not None and hasattr(expert_value, 'clone'):
                    cloned_sample_values.append(expert_value.clone())
                else:
                    cloned_sample_values.append(expert_value)  # None or non-tensor
            cloned_values.append(cloned_sample_values)
        self.historical_data['values'].append(cloned_values)

        # For historical data, residuals should be zero (previous edits are satisfied)
        zero_residuals = [torch.zeros_like(r) for r in stats['residuals']]
        self.historical_data['residuals'].append(zero_residuals)

        self.historical_data['is_new'].append([False] * len(stats['is_new']))  # Historical data is not new
        self.historical_data['entity_mapping'].append(stats['entity_mapping'].copy())
        self.historical_data['prefix_mapping'].append(stats['prefix_mapping'].copy())

        info(f"Added {len(stats['residuals'])} samples to historical data. Total historical batches: {len(self.historical_data['gating'])}")

    def get_combined_data(self, current_stats: Dict) -> Dict:
        """Combine historical data with current editing data

        Args:
            current_stats: Current editing statistics

        Returns:
            Combined statistics with historical + current data
        """
        if not self.historical_data['gating']:
            # No historical data, return current stats as-is
            return current_stats

        # Efficiently combine historical and current data using torch operations
        if self.historical_data['gating']:
            # Concatenate all gating tensors at once (most efficient)
            combined_gating = torch.cat(self.historical_data['gating'] + [current_stats['gating']], dim=0)

            # Use itertools.chain for efficient flattening (faster than list comprehensions)
            from itertools import chain
            all_keys = list(chain.from_iterable(self.historical_data['keys'])) + current_stats['keys']
            all_values = list(chain.from_iterable(self.historical_data['values'])) + current_stats['values']
            all_residuals = list(chain.from_iterable(self.historical_data['residuals'])) + current_stats['residuals']
            all_is_new = list(chain.from_iterable(self.historical_data['is_new'])) + current_stats['is_new']
            all_entity_mapping = list(chain.from_iterable(self.historical_data['entity_mapping'])) + current_stats['entity_mapping']
            all_prefix_mapping = list(chain.from_iterable(self.historical_data['prefix_mapping'])) + current_stats['prefix_mapping']
        else:
            # No historical data, just use current
            combined_gating = current_stats['gating']
            all_keys = current_stats['keys']
            all_values = current_stats['values']
            all_residuals = current_stats['residuals']
            all_is_new = current_stats['is_new']
            all_entity_mapping = current_stats['entity_mapping']
            all_prefix_mapping = current_stats['prefix_mapping']

        combined_stats = {
            'gating': combined_gating,
            'keys': all_keys,
            'values': all_values,
            'residuals': all_residuals,
            'is_new': all_is_new,
            'entity_mapping': all_entity_mapping,
            'prefix_mapping': all_prefix_mapping
        }

        info(f"Combined data: {len(all_residuals)} total samples ({len(current_stats['residuals'])} new + {len(all_residuals) - len(current_stats['residuals'])} historical)")
        return combined_stats

    def reset_cumulative_state(self):
        """Reset historical data storage"""
        self.historical_data = {
            'gating': [],
            'keys': [],
            'values': [],
            'residuals': [],
            'is_new': [],
            'entity_mapping': [],
            'prefix_mapping': []
        }
        info("Historical data cleared")

    def save_state(self, filepath: str):
        """Save optimizer state including historical data"""
        state = {
            'num_experts': self.num_experts,
            'd_model': self.d_model,
            'd_k': self.d_k,
            'lambda_reg': self.lambda_reg,
            'historical_data': self.historical_data
        }
        torch.save(state, filepath)
        info(f"BCD state saved to {filepath}")

    def load_state(self, filepath: str):
        """Load optimizer state including historical data"""
        state = torch.load(filepath, map_location=self.device)

        # Verify compatibility
        assert state['num_experts'] == self.num_experts
        assert state['d_model'] == self.d_model
        assert state['d_k'] == self.d_k

        # Load configuration
        self.lambda_reg = state['lambda_reg']

        # Load historical data if available
        if 'historical_data' in state:
            self.historical_data = state['historical_data']
            total_historical = sum(len(batch) for batch in self.historical_data['residuals'])
            info(f"Loaded {total_historical} historical samples from {len(self.historical_data['gating'])} batches")

        info(f"BCD state loaded from {filepath}")
    
    def optimize(self, stats: Dict, num_passes: int = 2, projection_matrices: Dict = None,
             incremental: bool = False, seed: int = 42) -> List[torch.Tensor]:
        """Random-Permutation BCD (RP-BCD) optimization with optional historical data support.

        Args:
            stats: Current batch statistics (keys, values, gating, residuals).
            num_passes: Number of BCD passes over the data.
            projection_matrices: Optional expert_id -> projection matrix P_j mapping.
            incremental: Whether to include historical data in this optimization.
            seed: Random seed for reproducible expert permutation.
        """

        # Combine with historical data if incremental mode is enabled
        if incremental:
            working_stats = self.get_combined_data(stats)
            print(f"Running RP-BCD optimization with {num_passes} passes (incremental mode)...")
        else:
            working_stats = stats
            print(f"Running RP-BCD optimization with {num_passes} passes (fresh mode)...")

        # --- Step 1: Prepare tensors ---
        residuals_batch = working_stats['residuals']
        R = torch.stack([r.clone().float() for r in residuals_batch])
        T = len(residuals_batch)
        gating = working_stats['gating'].float()

        keys_tensor = torch.zeros(T, self.num_experts, self.d_k, device=self.device, dtype=torch.float32)
        values_tensor = torch.zeros(T, self.num_experts, self.d_k, device=self.device, dtype=torch.float32)
        valid_mask = torch.zeros(T, self.num_experts, device=self.device, dtype=torch.bool)

        for t in range(T):
            for j in range(self.num_experts):
                if working_stats['keys'][t][j] is not None:
                    key_tj = working_stats['keys'][t][j].float()
                    keys_tensor[t, j] = key_tj

                    if working_stats['values'][t][j] is not None:
                        values_tensor[t, j] = working_stats['values'][t][j].float()
                    elif projection_matrices and j in projection_matrices:
                        P_j = projection_matrices[j].to(self.device)
                        values_tensor[t, j] = P_j @ key_tj
                    else:
                        values_tensor[t, j] = key_tj
                    valid_mask[t, j] = True

        print(f"RP-BCD on {T} samples, {self.num_experts} experts.")

        # --- Step 2: Initialize deltas ---
        deltas = [torch.zeros(self.d_model, self.d_k, device=self.device, dtype=torch.float32)
                for _ in range(self.num_experts)]

        # Fix random seed for reproducibility
        # torch.randperm requires a CPU generator; using a CUDA generator raises:
        # RuntimeError: Expected a 'cpu' device type for generator but found 'cuda'
        g = torch.Generator(device='cpu')
        g.manual_seed(seed)

        # --- Step 3: Random-Permutation BCD ---
        for pass_idx in range(num_passes):
            # Uniformly sample a random permutation of experts (without replacement)
            perm = torch.randperm(self.num_experts, generator=g)
            # perm = list(range(self.num_experts))

            for j in perm:
                j = j.item()
                active_mask = (gating[:, j] > 0) & valid_mask[:, j]
                if not active_mask.any():
                    continue

                # Compute residual adjusted by other experts' contributions
                S_j = R.clone()
                for ell in range(self.num_experts):
                    if ell == j:
                        continue
                    contrib = gating[:, ell:ell+1] * (values_tensor[:, ell, :] @ deltas[ell].T)
                    S_j += contrib

                j_indices = torch.where(active_mask)[0]
                g_j = gating[j_indices, j]
                v_j = values_tensor[j_indices, j]
                S_j_active = S_j[j_indices]

                weighted_vs = g_j.unsqueeze(1) * v_j
                B_j = S_j_active.T @ weighted_vs
                weights = g_j ** 2
                M_j = v_j.T @ (weights.unsqueeze(1) * v_j)
                reg_term = self.lambda_reg * torch.eye(self.d_k, device=self.device, dtype=torch.float32)
                M_j_total = M_j + reg_term

                try:
                    deltas[j] = -B_j @ torch.linalg.inv(M_j_total)
                except RuntimeError:
                    deltas[j] = -B_j @ torch.linalg.pinv(M_j_total)

        print("RP-BCD optimization completed.")
        return deltas
    def _verify_optimization(self, deltas, R, gating, values_tensor, valid_mask, is_new, entity_mapping, prefix_mapping):
        """Enhanced verification with detailed analysis for prefix-expanded samples"""
        T = R.shape[0]  # Total samples (with prefix expansion)

        print(f"\n=== Optimization Verification ===")

        # Verify a few new samples
        new_indices = [i for i in range(T) if is_new[i]]
        for idx in new_indices[:4]:  # Check first 4 new samples
            # Calculate total contribution using projected values
            total_contribution = torch.zeros_like(R[idx])
            active_experts = []

            for j in range(self.num_experts):
                if gating[idx, j] > 0 and valid_mask[idx, j]:
                    g_ij = gating[idx, j]
                    v_ij = values_tensor[idx, j]  # Use projected values
                    delta_j = deltas[j]
                    expert_contribution = g_ij * (delta_j @ v_ij)
                    total_contribution += expert_contribution
                    active_experts.append(j)

            residual_i = R[idx]
            entity_i = entity_mapping[idx]
            prefix_j = prefix_mapping[idx]

            print(f"Sample {idx} (entity {entity_i}, prefix {prefix_j}):")
            print(f"  Residual[:8]: {[f'{x:.4f}' for x in residual_i[:8].tolist()]}")
            print(f"  Total contrib[:8]: {[f'{x:.4f}' for x in total_contribution[:8].tolist()]}")
            print(f"  Active experts: {len(active_experts)}")

            # Check if they're negatives of each other
            neg_cosine = F.cosine_similarity(
                total_contribution.unsqueeze(0), (-residual_i).unsqueeze(0)
            ).item()
            print(f"  Cosine with -residual: {neg_cosine:.3f}")

        print(f"Verified {min(len(new_indices), 4)} new samples out of {len(new_indices)} total new samples")
