import numpy as np
import torch
import torch.nn as nn
from loguru import logger
from tqdm import tqdm
from sklearn.cluster import MiniBatchKMeans, KMeans
from sklearn.decomposition import PCA
from typing import Tuple, List, Optional, Dict
from ..base_mapper import BaseMoEMapper
from ..tree_structure import BottomUpHierarchyTree, TreeNode
from ..core.mlp import SimpleLinearModel, SimpleLinearMapper
from .lora_config import LoRAConfig
from .lora_expert import LoRAExpert


class HierarchicalLoRAMoEMapper(BaseMoEMapper):
    """Hierarchical mixture-of-experts mapper: one shared base model plus per-leaf LoRA adapters.

    Training (``_fit_from_loader``) runs five steps:

    1. Stream the source embeddings through ``MiniBatchKMeans`` to obtain
       ``branch_factor ** (num_levels - 1)`` leaf clusters, and assign every
       training sample to a leaf.
    2. Build a tree bottom-up by repeatedly clustering child centroids with
       ``KMeans`` (about ``branch_factor`` children per parent).
    3. Optionally (``enable_mixing`` / ``enable_chaining``), fit a
       2-component PCA on the target embeddings to get directions ``u1`` and
       ``u2``.
    4. Train a single ``SimpleLinearMapper`` on all data and freeze it.
    5. Train one ``LoRAExpert`` on top of the frozen base model for every
       leaf with enough samples.

    At inference (``transform``), each embedding descends the tree. At every
    level it moves to the child whose centroid is nearest under
    ``distance_metric``. The row is then mapped by that leaf's adapter, or by
    the base model alone when the leaf has no adapter.

    NOTE(review): the clustering steps use k-means (Euclidean), while routing
    uses ``distance_metric`` (cosine by default). Routing is also a greedy
    tree descent rather than a nearest-leaf search. Inference routing can
    therefore differ from the leaf a sample was trained under.
    """

    def __init__(
        self,
        num_levels: int = 3,
        branch_factor: int = 4,
        lora_rank: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.1,
        share_base_model: bool = True,
        base_model_epochs: int = 200,
        lora_epochs: int = 100,
        lora_learning_rate: float = 1e-3,
        mapper_config = None,
        distance_metric: str = "cosine",
        transform_strategy: str = "cluster_then_route",
        enable_mixing: bool = False,
        enable_chaining: bool = False,
        alpha: float = 0.5,
        beta: float = 0.7,
        **kwargs
    ):
        """Store the configuration; no training happens here.

        Args:
            num_levels: Depth of the routing tree. Leaves live at level
                ``num_levels - 1``.
            branch_factor: Target fan-out. The number of leaf clusters is
                ``branch_factor ** (num_levels - 1)``.
            lora_rank: LoRA rank, forwarded to ``LoRAConfig``.
            lora_alpha: LoRA scaling alpha, forwarded to ``LoRAConfig``.
            lora_dropout: LoRA dropout, forwarded to ``LoRAConfig``.
            share_base_model: Stored only. Every adapter built by this class
                uses the same ``self.base_model`` regardless of this flag.
            base_model_epochs: Epoch count for the shared base model.
            lora_epochs: Epoch count for each leaf adapter.
            lora_learning_rate: Adam learning rate for the adapters.
            mapper_config: Object exposing ``model_dump()``. Its fields are
                forwarded to ``SimpleLinearMapper``. ``None`` means the
                mapper's own defaults.
            distance_metric: ``"cosine"`` or ``"euclidean"``, used for routing.
            transform_strategy: Accepted for interface compatibility. Not used
                by this class.
            enable_mixing: Add a base-model loss term that pushes residuals
                onto the first principal direction of the targets.
            enable_chaining: Add an adapter loss term that pushes residuals
                onto the second principal direction of the targets.
            alpha: Stored and logged only. No loss in this class reads it.
            beta: Weight of both the mixing and the chaining loss terms.
            **kwargs: Ignored (logged at debug level).
        """
        super().__init__()
        self.num_levels = num_levels
        self.branch_factor = branch_factor
        self.mapper_config = mapper_config
        self.distance_metric = distance_metric
        self.share_base_model = share_base_model
        self.num_leaf_clusters = branch_factor ** (num_levels - 1)
        self.lora_config = LoRAConfig(
            rank=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout
        )
        self.base_model_epochs = base_model_epochs
        self.lora_epochs = lora_epochs
        self.lora_learning_rate = lora_learning_rate
        self.enable_mixing = enable_mixing
        self.enable_chaining = enable_chaining
        self.alpha = alpha
        self.beta = beta
        # State filled in by _fit_from_loader and its helpers.
        self.tree: Optional[BottomUpHierarchyTree] = None
        self.train_loader = None
        self.num_samples: Optional[int] = None
        self.input_dim: Optional[int] = None
        self.output_dim: Optional[int] = None
        self.sample_to_leaf: Optional[np.ndarray] = None
        self.leaf_to_indices: Optional[List[np.ndarray]] = None
        self.base_model: Optional[nn.Module] = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Leaf node id -> trained adapter. Leaves without an entry fall back
        # to the base model in transform().
        self.lora_adapters: Dict[int, LoRAExpert] = {}
        # Principal directions of the targets (only computed when mixing or
        # chaining is enabled).
        self.u1: Optional[torch.Tensor] = None
        self.u2: Optional[torch.Tensor] = None
        if kwargs:
            logger.debug(f"Ignoring unused keyword arguments: {sorted(kwargs)}")
        logger.info(
            f"Initialized HierarchicalLoRAMoEMapper: {self.num_levels} levels, "
            f"branch_factor={self.branch_factor}, "
            f"num_leaf_clusters={self.num_leaf_clusters}, "
            f"LoRA: rank={lora_rank}, alpha={lora_alpha}"
        )
        if self.enable_mixing or self.enable_chaining:
            logger.info(
                f"Multi-model control enabled: mixing={self.enable_mixing}, "
                f"chaining={self.enable_chaining}, alpha={self.alpha}, beta={self.beta}"
            )

    @staticmethod
    def _as_numpy(batch):
        """Return ``batch`` as a NumPy array if it is a torch tensor; otherwise return it unchanged."""
        if isinstance(batch, torch.Tensor):
            return batch.cpu().numpy()
        return batch

    def _to_device(self, batch) -> torch.Tensor:
        """Return ``batch`` (tensor or array-like) as a float32 tensor on ``self.device``."""
        # float32 matches the conversions used by transform() and the node loader.
        return torch.as_tensor(batch).float().to(self.device)

    def _fit_from_loader(self, train_loader):
        """Run the full training pipeline on ``train_loader``.

        ``train_loader`` is consumed in several separate passes: sample
        counting (when it lacks ``total_samples``), clustering, assignment,
        optional PCA, and base-model training. It must therefore be
        re-iterable and yield ``(source, target)`` batch pairs.
        """
        logger.info(f"Training HierarchicalLoRAMoE ({self.num_levels} levels, "
                   f"branch_factor={self.branch_factor}, "
                   f"num_leaf_clusters={self.num_leaf_clusters})")
        self._bind_base_dataset(train_loader)
        logger.info("=" * 70)
        logger.info("Step 1: MiniBatchKMeans streaming clustering...")
        logger.info("=" * 70)
        leaf_centroids = self._stream_cluster_bottom_layer(train_loader)
        logger.info(f"✓ Got {len(leaf_centroids)} leaf cluster centroids")
        logger.info("=" * 70)
        logger.info("Step 2: Bottom-up tree construction...")
        logger.info("=" * 70)
        self._build_hierarchy_bottom_up(leaf_centroids)
        logger.info(f"✓ Built hierarchy tree with {len(self.tree.nodes)} nodes")
        logger.info("=" * 70)
        logger.info("Step 3: Computing principal directions (PCA)...")
        logger.info("=" * 70)
        if self.enable_mixing or self.enable_chaining:
            self._compute_principal_directions(train_loader)
        else:
            logger.info("Skipped: mixing and chaining are both disabled")
        logger.info("=" * 70)
        logger.info("Step 4: Training shared base model on all data...")
        logger.info("=" * 70)
        self._train_base_model(train_loader)
        logger.info("✓ Trained and froze shared base model")
        logger.info("=" * 70)
        logger.info("Step 5: Training LoRA adapters for leaf nodes...")
        logger.info("=" * 70)
        self._train_lora_adapters_streaming()
        logger.info(f"✓ Trained {len(self.lora_adapters)} LoRA adapters")
        logger.info("=" * 70)
        logger.info("✓ HierarchicalLoRAMoE training completed")
        logger.info("=" * 70)

    def _bind_base_dataset(self, train_loader):
        """Remember ``train_loader`` and record the data dimensions and sample count.

        The loader must expose ``source_embedding_dim`` and
        ``target_embedding_dim``. If it also exposes ``total_samples``, that
        value is used. Otherwise the loader is iterated once to count rows.
        """
        self.train_loader = train_loader
        self.input_dim = train_loader.source_embedding_dim
        self.output_dim = train_loader.target_embedding_dim
        if hasattr(train_loader, 'total_samples'):
            self.num_samples = train_loader.total_samples
        else:
            logger.warning("Counting samples by iterating (slow)...")
            self.num_samples = 0
            for batch in train_loader:
                # Batches are (source, target) pairs, as either a tuple or a
                # list. torch's default collate produces lists.
                src_batch = batch[0] if isinstance(batch, (tuple, list)) else batch
                if isinstance(src_batch, torch.Tensor):
                    self.num_samples += src_batch.shape[0]
                else:
                    self.num_samples += len(src_batch)
        logger.info(
            f"Bound train_loader: {self.num_samples:,} samples, "
            f"input_dim={self.input_dim}, output_dim={self.output_dim}"
        )

    def _stream_cluster_bottom_layer(
        self,
        train_loader,
        batch_size: int = 1024
    ) -> np.ndarray:
        """Cluster the source embeddings into leaf clusters in two streaming passes.

        Pass 1 fits ``MiniBatchKMeans`` incrementally with ``partial_fit``.
        Pass 2 labels every sample and fills two attributes:

        - ``self.sample_to_leaf``: one leaf id per stream position.
        - ``self.leaf_to_indices``: the ascending stream positions per leaf.

        NOTE(review): stream positions are later used as global indices by
        ``_create_node_loader``. This assumes ``train_loader`` yields samples
        in a fixed, unshuffled order — confirm.

        Args:
            train_loader: Re-iterable loader yielding ``(source, target)``
                batches.
            batch_size: Passed to ``MiniBatchKMeans``. The arrays actually
                given to ``partial_fit`` are the loader's own batches.

        Returns:
            Array of leaf centroids, shape ``(num_leaf_clusters, input_dim)``.

        Raises:
            ValueError: if the loader yields fewer samples than there are
                leaf clusters.
        """
        logger.info(f"Pass 1: Streaming clustering with {self.num_leaf_clusters} clusters...")
        mbk = MiniBatchKMeans(
            n_clusters=self.num_leaf_clusters,
            batch_size=batch_size,
            random_state=42,
            n_init=3,
            verbose=0
        )
        # The first partial_fit call needs at least n_clusters rows. Buffer
        # leading batches until that many are available. When the first batch
        # is already large enough, this is exactly one partial_fit per batch.
        warmup: List[np.ndarray] = []
        warmup_rows = 0
        initialized = False
        for src_batch, _ in tqdm(train_loader, desc="Clustering"):
            src_batch = self._as_numpy(src_batch)
            if initialized:
                mbk.partial_fit(src_batch)
                continue
            warmup.append(src_batch)
            warmup_rows += len(src_batch)
            if warmup_rows >= self.num_leaf_clusters:
                mbk.partial_fit(np.concatenate(warmup, axis=0))
                warmup = []
                initialized = True
        if not initialized:
            raise ValueError(
                f"Need at least {self.num_leaf_clusters} samples to build "
                f"{self.num_leaf_clusters} leaf clusters, got {warmup_rows}"
            )
        leaf_centroids = mbk.cluster_centers_
        logger.info(f"Learned {len(leaf_centroids)} leaf cluster centers")

        logger.info("Pass 2: Assigning samples to leaf clusters...")
        N = self.num_samples
        sample_to_leaf = np.empty(N, dtype=np.int32)
        global_idx = 0
        for src_batch, _ in tqdm(train_loader, desc="Assigning"):
            labels = mbk.predict(self._as_numpy(src_batch))
            take = min(len(labels), N - global_idx)
            sample_to_leaf[global_idx:global_idx + take] = labels[:take]
            global_idx += take
            if take < len(labels):
                logger.warning(
                    f"Loader yielded more than the expected {N} samples; "
                    f"ignoring the extra rows."
                )
                break
        if global_idx != N:
            logger.warning(
                f"Sample count mismatch: processed {global_idx}, expected {N}. "
                f"Truncating sample_to_leaf array."
            )
            sample_to_leaf = sample_to_leaf[:global_idx]
            self.num_samples = global_idx

        # Group stream positions by leaf. A stable argsort keeps positions
        # ascending within each leaf, which is the same result as a
        # sequential scan.
        order = np.argsort(sample_to_leaf, kind="stable").astype(np.int64)
        counts = np.bincount(sample_to_leaf, minlength=self.num_leaf_clusters)
        self.leaf_to_indices = np.split(order, np.cumsum(counts)[:-1])
        self.sample_to_leaf = sample_to_leaf
        cluster_sizes = [len(indices) for indices in self.leaf_to_indices]
        logger.info(f"Leaf cluster sizes: min={min(cluster_sizes)}, "
                   f"max={max(cluster_sizes)}, "
                   f"mean={np.mean(cluster_sizes):.1f}")
        return leaf_centroids

    def _build_hierarchy_bottom_up(self, leaf_centroids: np.ndarray):
        """Build ``self.tree`` from the leaf centroids upward.

        Leaf nodes are created first, one per centroid, and carry the
        matching ``self.leaf_to_indices`` entry. Each higher level clusters
        the current level's centroids into
        ``max(1, n_children // branch_factor)`` groups. It uses ``KMeans``, or
        a plain mean when only one parent is needed. A parent's
        ``data_indices`` is the concatenation of its children's.

        Raises:
            ValueError: if no root node could be created.
        """
        self.tree = BottomUpHierarchyTree(self.num_levels, self.branch_factor)
        logger.info("Creating leaf nodes...")
        leaf_level = self.num_levels - 1
        K = leaf_centroids.shape[0]
        for leaf_id in range(K):
            node_id = len(self.tree.nodes)
            node = TreeNode(
                node_id=node_id,
                level=leaf_level,
                centroid=leaf_centroids[leaf_id].copy()
            )
            node.data_indices = self.leaf_to_indices[leaf_id].copy()
            self.tree.nodes.append(node)
            self.tree.level_nodes[leaf_level].append(node_id)
        logger.info(f"Created {K} leaf nodes")
        current_level_ids = self.tree.level_nodes[leaf_level].copy()
        for level in range(self.num_levels - 2, -1, -1):
            logger.info(f"Building level {level} from {len(current_level_ids)} children...")
            num_children = len(current_level_ids)
            num_parents = max(1, num_children // self.branch_factor)
            node_centroids = np.stack(
                [self.tree.nodes[nid].centroid for nid in current_level_ids],
                axis=0
            )
            if num_parents == 1:
                labels = np.zeros(num_children, dtype=np.int32)
                parent_centroids = node_centroids.mean(axis=0, keepdims=True)
            else:
                kmeans = KMeans(
                    n_clusters=num_parents,
                    random_state=42,
                    n_init=10,
                    verbose=0
                )
                labels = kmeans.fit_predict(node_centroids)
                parent_centroids = kmeans.cluster_centers_
            parent_ids = []
            for p in range(num_parents):
                child_ids = [current_level_ids[i] for i in np.flatnonzero(labels == p)]
                if not child_ids:
                    # KMeans can leave a cluster empty, e.g. with duplicate
                    # child centroids. Routing treats any childless node as a
                    # leaf, so an empty parent would capture samples that have
                    # no adapter. Don't create it.
                    logger.warning(
                        f"Level {level}: parent cluster {p} has no children, skipping"
                    )
                    continue
                parent_node_id = len(self.tree.nodes)
                parent = TreeNode(
                    node_id=parent_node_id,
                    level=level,
                    centroid=parent_centroids[p].copy()
                )
                all_indices = []
                for cid in child_ids:
                    child = self.tree.nodes[cid]
                    child.parent_id = parent_node_id
                    parent.child_ids.append(cid)
                    if child.data_indices is not None and len(child.data_indices) > 0:
                        all_indices.append(child.data_indices)
                if len(all_indices) > 0:
                    parent.data_indices = np.concatenate(all_indices)
                else:
                    parent.data_indices = np.zeros(0, dtype=np.int64)
                self.tree.nodes.append(parent)
                parent_ids.append(parent_node_id)
            self.tree.level_nodes[level] = parent_ids
            current_level_ids = parent_ids
            logger.info(f"Created {len(parent_ids)} parent nodes at level {level}")
        if len(self.tree.level_nodes[0]) > 0:
            self.tree.root_id = self.tree.level_nodes[0][0]
            logger.info(f"Root node ID: {self.tree.root_id}")
        else:
            raise ValueError("Failed to create root node")
        self._log_tree_statistics()

    def _log_tree_statistics(self):
        """Log the node count and total ``data_indices`` length at every tree level."""
        logger.info("Tree statistics:")
        for level in range(self.num_levels):
            nodes = self.tree.get_nodes_at_level(level)
            total_samples = sum(
                len(node.data_indices) if node.data_indices is not None else 0
                for node in nodes
            )
            logger.info(
                f"  Level {level}: {len(nodes)} nodes, "
                f"{total_samples:,} total samples"
            )

    def _compute_principal_directions(self, train_loader):
        """Fit a 2-component PCA on all target embeddings; store PC1/PC2 as ``u1``/``u2``.

        All target embeddings are concatenated in memory before fitting.
        ``u1`` and ``u2`` are float32 tensors on ``self.device``.
        """
        logger.info("Computing PCA on target embeddings to extract principal directions...")
        target_embeddings = [
            self._as_numpy(tgt_batch)
            for _, tgt_batch in tqdm(train_loader, desc="Collecting target embeddings")
        ]
        target_embeddings = np.concatenate(target_embeddings, axis=0)
        logger.info(f"Collected {target_embeddings.shape[0]:,} target embeddings")
        pca = PCA(n_components=2)
        pca.fit(target_embeddings)
        self.u1 = torch.from_numpy(pca.components_[0]).float().to(self.device)
        self.u2 = torch.from_numpy(pca.components_[1]).float().to(self.device)
        explained_var = pca.explained_variance_ratio_
        logger.info("✓ Computed principal directions:")
        logger.info(f"  PC1 (u1): explains {explained_var[0]:.2%} variance")
        logger.info(f"  PC2 (u2): explains {explained_var[1]:.2%} variance")
        logger.info(f"  u1 ⊥ u2: {np.abs(np.dot(self.u1.cpu().numpy(), self.u2.cpu().numpy())):.6f}")

    def _train_base_model(self, train_loader):
        """Train the shared ``SimpleLinearMapper`` on all data, then freeze it.

        Fields of ``self.mapper_config`` (if set) are forwarded to the mapper.
        ``input_dim``, ``output_dim`` and ``num_epochs`` always come from this
        object. With mixing enabled and ``u1`` available, the custom mixing
        loop replaces ``fit_with_loader``. The resulting model is moved to
        ``self.device``, set to eval mode, and has all gradients disabled.
        """
        logger.info("Training shared base model on all training data...")
        logger.info(f"  Input dim: {self.input_dim}, Output dim: {self.output_dim}")
        logger.info(f"  Epochs: {self.base_model_epochs}")
        if self.enable_mixing:
            logger.info("  Mixing regularization enabled (directing residuals along u1)")
        base_config = (
            self.mapper_config.model_dump() if self.mapper_config is not None else {}
        )
        # Explicit values override same-named config fields. Passing a name
        # both as a keyword and inside **config would raise TypeError.
        base_config.update(
            input_dim=self.input_dim,
            output_dim=self.output_dim,
            num_epochs=self.base_model_epochs,
        )
        base_trainer = SimpleLinearMapper(**base_config)
        if self.enable_mixing and self.u1 is not None:
            self._train_base_model_with_mixing(base_trainer, train_loader)
        else:
            base_trainer.fit_with_loader(train_loader)
        self.base_model = base_trainer.model
        for param in self.base_model.parameters():
            param.requires_grad = False
        self.base_model.to(self.device)
        self.base_model.eval()
        total_params = sum(p.numel() for p in self.base_model.parameters())
        logger.info(f"✓ Base model trained and frozen: {total_params:,} parameters")

    def _train_base_model_with_mixing(self, base_trainer, train_loader):
        """Train ``base_trainer.model`` with MSE plus a residual-direction penalty.

        The extra term is the mean squared norm of the residual component
        orthogonal to ``self.u1`` (a unit PCA component), weighted by
        ``self.beta``. Minimising it pushes prediction errors to lie along
        ``u1``.

        NOTE(review): this uses ``beta`` rather than ``alpha``, and ``alpha``
        is not read anywhere in this class — confirm that is intended.
        """
        logger.info("Training base model with mixing regularization...")
        model = base_trainer.model.to(self.device)
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=base_trainer.learning_rate
        )
        criterion = nn.MSELoss()
        num_epochs = base_trainer.num_epochs
        model.train()
        for epoch in range(num_epochs):
            total_loss = 0.0
            total_mix_loss = 0.0
            num_batches = 0
            for src_batch, tgt_batch in train_loader:
                src_batch = self._to_device(src_batch)
                tgt_batch = self._to_device(tgt_batch)
                optimizer.zero_grad()
                outputs = model(src_batch)
                reg_loss = criterion(outputs, tgt_batch)
                residuals = outputs - tgt_batch
                # Projection of each residual onto u1: (r . u1) * u1.
                residual_proj_u1 = (residuals @ self.u1).unsqueeze(1) * self.u1.unsqueeze(0)
                mix_loss = torch.mean(torch.sum((residuals - residual_proj_u1) ** 2, dim=1))
                loss = reg_loss + self.beta * mix_loss
                loss.backward()
                optimizer.step()
                total_loss += reg_loss.item()
                total_mix_loss += mix_loss.item()
                num_batches += 1
            avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
            avg_mix_loss = total_mix_loss / num_batches if num_batches > 0 else 0.0
            if epoch == 0 or (epoch + 1) % 50 == 0:
                logger.debug(
                    f"  Epoch {epoch+1}/{num_epochs}, "
                    f"Reg Loss: {avg_loss:.6f}, "
                    f"Mix Loss: {avg_mix_loss:.6f}"
                )
        model.eval()
        base_trainer.model = model

    def _train_lora_adapters_streaming(self, min_samples: int = 10):
        """Train one LoRA adapter per leaf node that has at least ``min_samples`` samples.

        Each adapter wraps the shared frozen ``self.base_model`` and is stored
        in ``self.lora_adapters`` under its leaf node id. Leaves below the
        threshold are skipped; ``transform`` then uses the base model for
        them. Parameter statistics are logged at the end.
        """
        leaf_level = self.num_levels - 1
        leaf_nodes = self.tree.get_nodes_at_level(leaf_level)
        logger.info(f"Training LoRA adapters for {len(leaf_nodes)} leaf nodes...")
        logger.info(f"  LoRA config: rank={self.lora_config.rank}, "
                   f"alpha={self.lora_config.alpha}, dropout={self.lora_config.dropout}")
        logger.info(f"  Epochs per adapter: {self.lora_epochs}")
        logger.info(f"  Learning rate: {self.lora_learning_rate}")
        trained_count = 0
        skipped_count = 0
        for node in tqdm(leaf_nodes, desc="Training LoRA adapters"):
            if node.data_indices is None or len(node.data_indices) < min_samples:
                logger.warning(
                    f"Leaf node {node.node_id}: too few samples "
                    f"({len(node.data_indices) if node.data_indices is not None else 0}), "
                    f"skipping"
                )
                skipped_count += 1
                continue
            node_loader = self._create_node_loader(
                node.data_indices,
                batch_size=min(1024, len(node.data_indices)),
                shuffle=True
            )
            lora_expert = LoRAExpert(
                base_model=self.base_model,
                lora_config=self.lora_config,
                freeze_base=True
            )
            lora_expert.to(self.device)
            logger.debug(f"Training LoRA adapter for node {node.node_id} ({len(node.data_indices):,} samples)...")
            self._train_single_lora_adapter(lora_expert, node_loader)
            self.lora_adapters[node.node_id] = lora_expert
            trained_count += 1
            del node_loader
        logger.info("LoRA adapter training complete:")
        logger.info(f"  Trained: {trained_count}")
        logger.info(f"  Skipped: {skipped_count}")
        if len(self.lora_adapters) > 0:
            base_params = sum(p.numel() for p in self.base_model.parameters())
            lora_params_per_adapter = sum(
                p.numel() for p in next(iter(self.lora_adapters.values())).get_lora_parameters()
            )
            total_lora_params = lora_params_per_adapter * trained_count
            total_params = base_params + total_lora_params
            logger.info("Parameter statistics:")
            logger.info(f"  Base model: {base_params:,}")
            logger.info(f"  LoRA per adapter: {lora_params_per_adapter:,}")
            logger.info(f"  Total LoRA: {total_lora_params:,} ({trained_count} adapters)")
            logger.info(f"  Grand total: {total_params:,}")
            # Ratio vs. training one full copy of the base model per leaf.
            logger.info(f"  Compression ratio: {base_params * trained_count / total_params:.1f}x vs full experts")

    def _train_single_lora_adapter(self, lora_expert: LoRAExpert, node_loader):
        """Optimise only ``lora_expert``'s LoRA parameters with Adam on ``node_loader``.

        The loss is MSE. When chaining is enabled and ``u2`` is available, it
        adds ``beta`` times the squared norm of the residual component
        orthogonal to ``u2``. Leaves the expert in eval mode.
        """
        optimizer = torch.optim.Adam(
            lora_expert.get_lora_parameters(),
            lr=self.lora_learning_rate
        )
        criterion = nn.MSELoss()
        chaining_active = self.enable_chaining and self.u2 is not None
        lora_expert.train()
        for epoch in range(self.lora_epochs):
            total_loss = 0.0
            total_chain_loss = 0.0
            num_batches = 0
            for src_batch, tgt_batch in node_loader:
                src_batch = self._to_device(src_batch)
                tgt_batch = self._to_device(tgt_batch)
                optimizer.zero_grad()
                outputs = lora_expert(src_batch)
                reg_loss = criterion(outputs, tgt_batch)
                loss = reg_loss
                if chaining_active:
                    residuals = outputs - tgt_batch
                    # Projection of each residual onto u2: (r . u2) * u2.
                    residual_proj_u2 = (residuals @ self.u2).unsqueeze(1) * self.u2.unsqueeze(0)
                    chain_loss = torch.mean(torch.sum((residuals - residual_proj_u2) ** 2, dim=1))
                    loss = loss + self.beta * chain_loss
                    total_chain_loss += chain_loss.item()
                loss.backward()
                optimizer.step()
                total_loss += reg_loss.item()
                num_batches += 1
            avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
            if epoch == 0 or (epoch + 1) % 5 == 0:
                if chaining_active:
                    avg_chain_loss = total_chain_loss / num_batches if num_batches > 0 else 0.0
                    logger.debug(
                        f"  Epoch {epoch+1}/{self.lora_epochs}, "
                        f"Reg Loss: {avg_loss:.6f}, "
                        f"Chain Loss: {avg_chain_loss:.6f}"
                    )
                else:
                    logger.debug(f"  Epoch {epoch+1}/{self.lora_epochs}, Loss: {avg_loss:.6f}")
        lora_expert.eval()

    def _create_node_loader(
        self,
        indices: np.ndarray,
        batch_size: int = 1024,
        shuffle: bool = True
    ):
        """Return a ``DataLoader`` over the subset ``indices`` of the bound training data.

        Items are ``(source, target)`` float32 tensors read one sample at a
        time.

        NOTE(review): relies on the bound loader's private
        ``_find_dataset_idx(global_idx) -> (dataset index, local index)`` and
        on ``datasets[i].source_embeddings`` / ``.target_embeddings`` being
        indexable arrays (presumably memory-mapped, given the class name) —
        confirm.

        Raises:
            ValueError: if no training loader has been bound.
        """
        from torch.utils.data import Dataset, DataLoader
        if self.train_loader is None:
            raise ValueError("train_loader not bound. Call _bind_base_dataset() first.")

        class MemmapNodeDataset(Dataset):
            """Map-style view over selected global sample indices of ``loader``."""

            def __init__(self, loader, indices: np.ndarray):
                self.loader = loader
                self.indices = indices

            def __len__(self) -> int:
                return len(self.indices)

            def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
                global_idx = self.indices[idx]
                ds_idx, local_idx = self.loader._find_dataset_idx(global_idx)
                dataset = self.loader.datasets[ds_idx]
                src_emb = dataset.source_embeddings[local_idx]
                tgt_emb = dataset.target_embeddings[local_idx]
                # Copy so the tensors don't alias the underlying storage.
                return (
                    torch.from_numpy(src_emb.copy()).float(),
                    torch.from_numpy(tgt_emb.copy()).float()
                )

        dataset = MemmapNodeDataset(self.train_loader, indices)
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=0,
            pin_memory=torch.cuda.is_available(),
            drop_last=False
        )

    def _compute_distances(
        self,
        embeddings: np.ndarray,
        centroids: np.ndarray
    ) -> np.ndarray:
        """Return the ``(len(embeddings), len(centroids))`` distance matrix under ``self.distance_metric``.

        Raises:
            ValueError: for a metric other than ``"cosine"`` or ``"euclidean"``.
        """
        if self.distance_metric == "cosine":
            from sklearn.metrics.pairwise import cosine_distances
            return cosine_distances(embeddings, centroids)
        elif self.distance_metric == "euclidean":
            from sklearn.metrics.pairwise import euclidean_distances
            return euclidean_distances(embeddings, centroids)
        else:
            raise ValueError(f"Unsupported distance metric: {self.distance_metric}")

    def _route_single_to_leaf(self, embedding: np.ndarray) -> int:
        """Descend from the root to a childless node, always taking the nearest child centroid.

        Returns:
            The node id of the childless node that was reached.

        Raises:
            ValueError: if the tree has not been built.
        """
        if self.tree is None:
            raise ValueError("Tree not built. Call fit() first.")
        current_node_id = self.tree.root_id
        embedding_reshaped = embedding.reshape(1, -1)
        while True:
            node = self.tree.nodes[current_node_id]
            if len(node.child_ids) == 0:
                return current_node_id
            child_centroids = np.stack(
                [self.tree.nodes[cid].centroid for cid in node.child_ids],
                axis=0
            )
            distances = self._compute_distances(embedding_reshaped, child_centroids)[0]
            nearest_child_idx = np.argmin(distances)
            current_node_id = node.child_ids[nearest_child_idx]

    def get_expert_assignments(self, embeddings: np.ndarray) -> np.ndarray:
        """Route every row of ``embeddings`` to a leaf node id.

        Gives the same result as calling ``_route_single_to_leaf`` on each
        row. However, each tree node handles all rows that reach it with a
        single distance computation, instead of one call per row per level.

        Args:
            embeddings: 2-D array with one embedding per row.

        Returns:
            ``int32`` array of leaf node ids, one per row.

        Raises:
            ValueError: if the tree has not been built.
        """
        if self.tree is None:
            raise ValueError("Tree not built. Call fit() first.")
        n_samples = embeddings.shape[0]
        expert_ids = np.empty(n_samples, dtype=np.int32)
        if n_samples == 0:
            return expert_ids
        # Work list of (node id, row indices that reached that node).
        frontier: List[Tuple[int, np.ndarray]] = [
            (self.tree.root_id, np.arange(n_samples))
        ]
        while frontier:
            node_id, rows = frontier.pop()
            node = self.tree.nodes[node_id]
            if len(node.child_ids) == 0:
                expert_ids[rows] = node_id
                continue
            child_centroids = np.stack(
                [self.tree.nodes[cid].centroid for cid in node.child_ids],
                axis=0
            )
            distances = self._compute_distances(embeddings[rows], child_centroids)
            # argmin takes the first minimum, matching the per-row routing.
            nearest = np.argmin(distances, axis=1)
            for child_pos, child_id in enumerate(node.child_ids):
                child_rows = rows[nearest == child_pos]
                if child_rows.size > 0:
                    frontier.append((child_id, child_rows))
        return expert_ids

    def transform(self, embeddings: np.ndarray, batch_size: int = 8192) -> np.ndarray:
        """Map source embeddings into the target space via hierarchical LoRA routing.

        Each row is routed to a leaf (see ``get_expert_assignments``). It is
        then mapped by that leaf's LoRA adapter, or by the frozen base model
        when the leaf has no adapter (e.g. it had too few training samples).

        Args:
            embeddings: 2-D array of source embeddings, one per row.
            batch_size: Maximum number of rows per forward pass, to bound
                device memory. Results should not depend on it, assuming each
                model processes rows independently in eval mode.

        Returns:
            Array of shape ``(n_samples, output_dim)`` with the same dtype as
            ``embeddings``.

        Raises:
            ValueError: if the mapper is not fitted, or ``batch_size < 1``.
        """
        if self.tree is None:
            raise ValueError("Tree not built. Call fit() first.")
        if self.base_model is None:
            raise ValueError("Base model not trained. Call fit() first.")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        n_samples = embeddings.shape[0]
        results = np.zeros((n_samples, self.output_dim), dtype=embeddings.dtype)
        logger.info(f"Transforming {n_samples:,} embeddings using hierarchical LoRA routing...")
        expert_ids = self.get_expert_assignments(embeddings)
        unique_experts, counts = np.unique(expert_ids, return_counts=True)
        logger.info(f"Assignment distribution: {len(unique_experts)} experts used")
        for expert_id, count in zip(unique_experts, counts):
            logger.debug(f"  Expert {expert_id}: {count:,} samples ({count/n_samples*100:.1f}%)")
        self.base_model.eval()
        with torch.no_grad():
            for expert_id in unique_experts:
                expert_id = int(expert_id)
                rows = np.flatnonzero(expert_ids == expert_id)
                lora_expert = self.lora_adapters.get(expert_id)
                if lora_expert is not None:
                    lora_expert.eval()
                    model = lora_expert
                else:
                    logger.warning(
                        f"Leaf node {expert_id} has no LoRA adapter. "
                        f"Using base model only."
                    )
                    model = self.base_model
                for start in range(0, len(rows), batch_size):
                    chunk = rows[start:start + batch_size]
                    chunk_tensor = torch.from_numpy(embeddings[chunk]).float().to(self.device)
                    results[chunk] = model(chunk_tensor).cpu().numpy()
        return results

    def save_lora_adapters(self, save_dir: str):
        """Write each adapter's LoRA weights to ``<save_dir>/lora_adapter_<node_id>.pt``.

        Creates ``save_dir`` if needed.
        """
        import os
        os.makedirs(save_dir, exist_ok=True)
        for node_id, lora_expert in self.lora_adapters.items():
            adapter_path = os.path.join(save_dir, f"lora_adapter_{node_id}.pt")
            lora_expert.save_lora_weights(adapter_path)
        logger.info(f"Saved {len(self.lora_adapters)} LoRA adapters to {save_dir}")

    def load_lora_adapters(self, save_dir: str):
        """Load adapter weights written by ``save_lora_adapters``.

        Adapters already in ``self.lora_adapters`` are updated in place. A
        warning is logged for any of them whose file is missing.

        Files for node ids with no adapter yet get a new ``LoRAExpert``, built
        on ``self.base_model`` with ``self.lora_config`` (as in training).
        If no base model is available, those files are skipped with a
        warning.
        """
        import os
        import re

        loaded = 0
        missing: List[int] = []
        for node_id, lora_expert in self.lora_adapters.items():
            adapter_path = os.path.join(save_dir, f"lora_adapter_{node_id}.pt")
            if os.path.exists(adapter_path):
                lora_expert.load_lora_weights(adapter_path)
                loaded += 1
            else:
                missing.append(node_id)

        # Only canonical names (no leading zeros), exactly as
        # save_lora_adapters writes them.
        adapter_file = re.compile(r"lora_adapter_(0|[1-9][0-9]*)\.pt")
        if os.path.isdir(save_dir):
            for fname in sorted(os.listdir(save_dir)):
                match = adapter_file.fullmatch(fname)
                if match is None:
                    continue
                node_id = int(match.group(1))
                if node_id in self.lora_adapters:
                    continue
                if self.base_model is None:
                    logger.warning(
                        f"Found {fname} but no base model is available; skipping"
                    )
                    continue
                lora_expert = LoRAExpert(
                    base_model=self.base_model,
                    lora_config=self.lora_config,
                    freeze_base=True
                )
                lora_expert.to(self.device)
                lora_expert.load_lora_weights(os.path.join(save_dir, fname))
                lora_expert.eval()
                self.lora_adapters[node_id] = lora_expert
                loaded += 1

        if missing:
            logger.warning(
                f"No saved weights in {save_dir} for {len(missing)} adapters: {missing}"
            )
        logger.info(f"Loaded {loaded} LoRA adapters from {save_dir}")
