# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.neighbors import KernelDensity
import traceback


def _compute_density_thresholds(densities, density_percentiles):
    """Return ``(low, high)`` thresholds that split ``densities`` into three bins.

    The thresholds are the requested percentiles. If they collapse onto (nearly)
    the same value or come out inverted, a narrow band around the median is used
    instead. If the densities are (nearly) constant, a +/-1e-6 band around them
    is used so every sample falls into the middle bin.

    Args:
        densities: Non-empty 1-D array of per-sample densities.
        density_percentiles: Sequence ``[low, high]`` of percentiles in [0, 100].

    Returns:
        Tuple ``(low_threshold, high_threshold)``.
    """
    unique_densities = np.unique(densities)
    if len(unique_densities) <= 1:
        print("Warning: Uniform density, using arbitrary thresholds.")
        return unique_densities[0] - 1e-6, unique_densities[0] + 1e-6

    low = np.percentile(densities, density_percentiles[0])
    high = np.percentile(densities, density_percentiles[1])

    # atol=0.0: a density scales like 1 / (target spread), so for targets with a
    # large range the values sit far below numpy's default atol (1e-8). With the
    # default, clearly distinct thresholds (or min/max) would be judged "equal"
    # and every sample would be pushed into a single bin. Use only the relative
    # tolerance.
    if np.isclose(low, high, atol=0.0) or low > high:
        median_dens = np.median(densities)
        min_dens, max_dens = np.min(densities), np.max(densities)
        if np.isclose(min_dens, max_dens, atol=0.0):
            low, high = min_dens - 1e-6, max_dens + 1e-6
        else:
            low = median_dens - (median_dens - min_dens) * 0.1 if median_dens > min_dens else median_dens - 1e-6
            high = median_dens + (max_dens - median_dens) * 0.1 if max_dens > median_dens else median_dens + 1e-6
            if low >= high:
                low, high = min_dens, max_dens
        print(f"Warning: Adjusted thresholds -> Low: {low:.4e}, High: {high:.4e}")

    return low, high


def precompute_density_and_boundaries(y_train_npy, kde_bandwidth='silverman', density_percentiles=None):
    """
    Fit a Gaussian KDE on training targets and compute per-sample densities and
    low/high density boundaries.

    Args:
        y_train_npy: Training target values with shape (n_samples, 1). A 1-D
            array of shape (n_samples,) is also accepted and treated as a column.
        kde_bandwidth: Bandwidth passed to ``KernelDensity`` (a float or a rule
            name such as 'silverman').
        density_percentiles: Percentile thresholds ``[low, high]`` for the density
            boundaries. Defaults to ``[33.3, 66.7]``.

    Returns:
        Tuple ``(densities, boundaries_dict, kde_model)``:
        - ``densities`` is the array of per-sample KDE densities.
        - ``boundaries_dict`` has keys ``'low'`` and ``'high'``.
        Returns ``(None, None, None)`` on empty input or on any error during
        fitting or thresholding.
    """
    # None sentinel instead of a mutable list default shared across calls.
    if density_percentiles is None:
        density_percentiles = [33.3, 66.7]

    print("CASMIR: Computing density and boundaries...")
    if y_train_npy.ndim == 1:
        # KernelDensity expects 2-D input; treat a flat vector as a single column.
        y_train_npy = y_train_npy.reshape(-1, 1)
    if y_train_npy.ndim != 2 or y_train_npy.shape[1] != 1:
        print(f"Warning: y_train_npy should be (n_samples, 1). Current shape: {y_train_npy.shape}")

    if len(y_train_npy) == 0:
        print("Error: Empty y_train_npy for density computation.")
        return None, None, None

    # Best-effort boundary: any exception raised while fitting or thresholding
    # is printed with its traceback and signalled as (None, None, None).
    try:
        kde = KernelDensity(kernel='gaussian', bandwidth=kde_bandwidth).fit(y_train_npy)
        densities_train_np = np.exp(kde.score_samples(y_train_npy))

        if np.any(densities_train_np <= 1e-9):
            print(f"Warning: Very small density values detected: {np.min(densities_train_np)}")

        print(f"Density computed. Min: {densities_train_np.min():.4e}, Max: {densities_train_np.max():.4e}")

        density_low_threshold, density_high_threshold = _compute_density_thresholds(
            densities_train_np, density_percentiles
        )

        density_boundaries = {'low': density_low_threshold, 'high': density_high_threshold}
        print(f"Density boundaries (percentiles {density_percentiles}): Low < {density_low_threshold:.4e} <= Med <= {density_high_threshold:.4e} < High")

        return densities_train_np, density_boundaries, kde

    except Exception as e:
        print(f"Error: Density computation failed - {e}")
        traceback.print_exc()
        return None, None, None


class CAS_Module(nn.Module):
    """
    Coupled Adaptive Smoothing (CAS) Module.
    Includes learnable feature weights for adaptive distance computation.
    """

    def __init__(self, input_dim, k, feature_bw, label_bw, density_factor,
                 strength_base, density_c, epsilon=1e-6,
                 weight_init_strategy='random', adaptive_k=False):
        """
        Args:
            input_dim: Input feature dimension.
            k: Number of neighbors for k-NN (cast to int).
            feature_bw: Feature bandwidth for the Gaussian feature kernel.
            label_bw: Label bandwidth for the Gaussian label kernel.
            density_factor: Factor for inverse-density neighbour weighting; a value
                <= 0 disables density weighting.
            strength_base: Base smoothing strength.
            density_c: Density constant for smoothing strength computation.
            epsilon: Small constant for numerical stability.
            weight_init_strategy: Raw feature weight initialization. 'random' gives
                1 + 0.1 * N(0, 1), 'uniform' gives ones, and any other value
                (e.g. 'zeros') gives zeros.
            adaptive_k: Whether to adapt k based on batch size.
        """
        super().__init__()
        self.input_dim = input_dim
        self.k = int(k)
        self.feature_bw = feature_bw
        self.label_bw = label_bw
        self.density_factor = density_factor
        self.strength_base = strength_base
        self.density_c = density_c
        self.epsilon = epsilon

        # Raw (unconstrained) weights; forward() maps them through softplus.
        if weight_init_strategy == 'random':
            self.raw_feature_weights = nn.Parameter(torch.randn(input_dim) * 0.1 + 1.0)
        elif weight_init_strategy == 'uniform':
            self.raw_feature_weights = nn.Parameter(torch.ones(input_dim))
        else:
            self.raw_feature_weights = nn.Parameter(torch.zeros(input_dim))

        self.adaptive_k = adaptive_k

    def forward(self, x_batch, y_batch, densities_batch):
        """Apply batch-wise CAS smoothing.

        Each sample's features are moved toward a kernel-weighted mean of its
        ``k`` nearest *other* samples in the batch. Distance uses the
        softplus-transformed learned feature weights.

        Neighbour weights are the product of feature similarity and label
        similarity, optionally scaled by ``density_factor / neighbour_density``.
        The per-sample mixing strength is
        ``strength_base / (density * density_c)``, clamped to [0.05, 0.8].

        Args:
            x_batch: Features, shape (B, D) (required by the pairwise broadcast
                below).
            y_batch: Targets, shape (B, 1) or (B,).
            densities_batch: Per-sample densities, shape (B,) or (B, 1).

        Returns:
            Calibrated features with the same shape as ``x_batch``, or
            ``x_batch`` itself when the batch is too small for k-NN.
        """
        batch_size = x_batch.shape[0]
        device = x_batch.device

        if self.adaptive_k:
            # batch_size // 4, floored at 3, then capped at self.k.
            effective_k = min(self.k, max(3, batch_size // 4))
        else:
            effective_k = self.k

        if batch_size <= effective_k:
            print(f"Warning: CAS - batch size ({batch_size}) <= effective_k ({effective_k}). Returning original features.")
            return x_batch

        # After the guard above batch_size - 1 >= effective_k, so this equals effective_k.
        current_k = min(effective_k, batch_size - 1)
        if current_k < 1:
            print(f"Warning: CAS - effective K ({current_k}) < 1. Returning original features.")
            return x_batch

        positive_feature_weights = F.softplus(self.raw_feature_weights)

        diff_sq = (x_batch.unsqueeze(1) - x_batch.unsqueeze(0)) ** 2          # (B, B, D)
        dist_sq_matrix = torch.sum(diff_sq * positive_feature_weights.view(1, 1, -1), dim=-1)
        dist_sq_matrix_stable = F.relu(dist_sq_matrix)

        # Exclude each sample from its own neighbour list explicitly. The previous
        # "take k+1 smallest and drop column 0" trick is unreliable when rows are
        # duplicated: several entries tie at distance 0, topk may return the
        # duplicate first, and the sample itself then stays among its neighbours.
        self_mask = torch.eye(batch_size, dtype=torch.bool, device=device)
        knn_dist = dist_sq_matrix_stable.masked_fill(self_mask, float('inf'))
        try:
            _, nn_indices = torch.topk(knn_dist, current_k, dim=1, largest=False, sorted=True)
        except RuntimeError as e:
            print(f"Error: torch.topk failed. K={current_k}, batch_size={batch_size}. Error: {e}")
            return x_batch

        # View labels as a column so neighbour labels come out as (B, k) for both
        # (B,) and (B, 1) inputs.
        y_col = y_batch.unsqueeze(-1) if y_batch.ndim == 1 else y_batch
        neighbor_x = x_batch[nn_indices]                 # (B, k, D)
        neighbor_y = y_col[nn_indices].squeeze(-1)       # (B, k)

        if densities_batch.ndim > 1:
            densities_batch = densities_batch.squeeze()

        try:
            # Defensive clamp in case densities_batch is shorter than the batch.
            clamped_indices = torch.clamp(nn_indices, 0, densities_batch.shape[0] - 1)
            neighbor_densities = densities_batch[clamped_indices]
        except IndexError as e:
            print(f"Error: Neighbor density indexing failed. Error: {e}")
            neighbor_densities = torch.full_like(nn_indices, densities_batch.mean(), dtype=torch.float32, device=device)

        # Gather from the unmasked matrix so gradients reach the feature weights.
        dist_sq_neighbors = torch.gather(dist_sq_matrix_stable, 1, nn_indices)
        sim_x = torch.exp(-dist_sq_neighbors / (2 * self.feature_bw ** 2))

        dist_sq_y = (y_col - neighbor_y) ** 2
        sim_y_base = torch.exp(-dist_sq_y / (2 * self.label_bw ** 2))
        if self.density_factor > 0:
            # Up-weight neighbours from low-density regions.
            density_weight = self.density_factor / (neighbor_densities + self.epsilon)
            sim_y = sim_y_base * density_weight
        else:
            sim_y = sim_y_base

        weights_unnorm = sim_x * sim_y
        weights_sum = torch.sum(weights_unnorm, dim=1, keepdim=True) + self.epsilon
        weights = weights_unnorm / weights_sum

        mu_smooth = torch.sum(weights.unsqueeze(-1) * neighbor_x, dim=1)       # (B, D)

        # Lower-density samples get stronger smoothing, within [0.05, 0.8].
        smoothing_strength = self.strength_base / (densities_batch * self.density_c + self.epsilon)
        smoothing_strength = torch.clamp(smoothing_strength, min=0.05, max=0.8)

        x_calibrated = (1 - smoothing_strength.unsqueeze(1)) * x_batch + smoothing_strength.unsqueeze(1) * mu_smooth

        return x_calibrated


class Expert(nn.Module):
    """Plain MLP expert mapping ``input_dim`` features to one output unit.

    Each hidden stage is ``Linear -> ReLU -> norm [-> Dropout]``:
    - The norm is ``LayerNorm`` when ``use_layer_norm`` is true and
      ``BatchNorm1d`` otherwise.
    - Dropout is only inserted when ``dropout_rate > 0``.
    A final ``Linear(..., 1)`` produces the output.
    """

    def __init__(self, input_dim, hidden_dims, use_layer_norm=True, dropout_rate=0.0):
        super().__init__()
        norm_cls = nn.LayerNorm if use_layer_norm else nn.BatchNorm1d
        modules = []
        in_features = input_dim
        for width in hidden_dims:
            modules += [nn.Linear(in_features, width), nn.ReLU(), norm_cls(width)]
            if dropout_rate > 0:
                modules.append(nn.Dropout(dropout_rate))
            in_features = width
        modules.append(nn.Linear(in_features, 1))
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        return self.network(x)


class TabularExpert(nn.Module):
    """Tabular expert: ``Linear -> LayerNorm -> GELU [-> Dropout]`` stages.

    - Hidden linears use Xavier-uniform weights and zero biases.
    - The output ``Linear(..., 1)`` uses Xavier-uniform with gain 0.1.
    - When ``use_residual`` is set and ``input_dim != 1``, a linear projection
      of the raw input, scaled by 0.1, is added to the output.
    """

    def __init__(self, input_dim, hidden_dims, dropout_rate=0.2, use_residual=True):
        super().__init__()
        # The residual projection is disabled for single-feature input.
        self.use_residual = use_residual and (input_dim != 1)

        stages = []
        width_in = input_dim
        for width_out in hidden_dims:
            stages.append(self._xavier_linear(width_in, width_out, gain=1.0))
            stages.append(nn.LayerNorm(width_out))
            stages.append(nn.GELU())
            if dropout_rate > 0:
                stages.append(nn.Dropout(dropout_rate))
            width_in = width_out
        stages.append(self._xavier_linear(width_in, 1, gain=0.1))

        self.network = nn.Sequential(*stages)

        if self.use_residual:
            self.res_proj = self._xavier_linear(input_dim, 1, gain=0.1)

    @staticmethod
    def _xavier_linear(in_features, out_features, gain):
        """Build a Linear layer with Xavier-uniform weights (given gain) and zero bias."""
        layer = nn.Linear(in_features, out_features)
        nn.init.xavier_uniform_(layer.weight, gain=gain)
        nn.init.zeros_(layer.bias)
        return layer

    def forward(self, x):
        out = self.network(x)
        return out + 0.1 * self.res_proj(x) if self.use_residual else out


class BasicExpert(nn.Module):
    """Expert with ``Linear -> ReLU -> BatchNorm1d -> Dropout`` hidden stages.

    - A Dropout layer is always present, even when ``dropout_rate`` is 0.
    - Hidden linears use Xavier-uniform weights and zero biases.
    - The output ``Linear(..., 1)`` uses Xavier-uniform with gain 0.1.
    """

    def __init__(self, input_dim, hidden_dims, dropout_rate=0.1):
        super().__init__()

        def make_linear(fan_in, fan_out, gain):
            # Xavier-uniform weights with the given gain; zero bias.
            linear = nn.Linear(fan_in, fan_out)
            nn.init.xavier_uniform_(linear.weight, gain=gain)
            nn.init.zeros_(linear.bias)
            return linear

        layers = []
        fan_in = input_dim
        for fan_out in hidden_dims:
            layers.extend([
                make_linear(fan_in, fan_out, 1.0),
                nn.ReLU(),
                nn.BatchNorm1d(fan_out),
                nn.Dropout(dropout_rate),
            ])
            fan_in = fan_out
        layers.append(make_linear(fan_in, 1, 0.1))

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)


class GatingNetwork(nn.Module):
    """MLP producing one unnormalised logit per expert.

    - Hidden stages are ``Linear -> LayerNorm -> GELU [-> Dropout]``, with
      Xavier-uniform weights and zero biases.
    - The output layer ``Linear(..., num_experts)`` uses gain 0.1.
    """

    def __init__(self, input_dim, hidden_dims, num_experts, dropout_rate=0.0):
        super().__init__()
        dims = [input_dim, *hidden_dims]
        stack = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            proj = nn.Linear(d_in, d_out)
            nn.init.xavier_uniform_(proj.weight)
            nn.init.zeros_(proj.bias)
            stack += [proj, nn.LayerNorm(d_out), nn.GELU()]
            if dropout_rate > 0:
                stack.append(nn.Dropout(dropout_rate))

        head = nn.Linear(dims[-1], num_experts)
        nn.init.xavier_uniform_(head.weight, gain=0.1)
        nn.init.zeros_(head.bias)
        stack.append(head)

        self.network = nn.Sequential(*stack)

    def forward(self, x):
        return self.network(x)


class FeatureNormalizer(nn.Module):
    """Per-feature affine transform ``x * scales + shifts``.

    Initialised as the identity (scales = 1, shifts = 0).
    - With ``learnable=True`` both vectors are trainable ``nn.Parameter``s.
    - Otherwise they are registered as buffers rather than parameters.
    """

    def __init__(self, input_dim, learnable=True):
        super().__init__()
        init_scales = torch.ones(input_dim)
        init_shifts = torch.zeros(input_dim)
        if not learnable:
            self.register_buffer('scales', init_scales)
            self.register_buffer('shifts', init_shifts)
        else:
            self.scales = nn.Parameter(init_scales)
            self.shifts = nn.Parameter(init_shifts)
        self.learnable = learnable

    def forward(self, x):
        scaled = x * self.scales
        return scaled + self.shifts


class CASMIR_V1(nn.Module):
    """
    CASMIR model v1 (CAS + MoE).

    Pipeline:
    1. Optional learnable per-feature affine normalisation.
    2. CAS smoothing of the features, during training only.
    3. A mixture of experts. The gating network sees the normalised but
       unsmoothed features; the experts see the (possibly) smoothed ones.

    Args:
        input_dim: Input feature dimension.
        num_experts: Number of expert networks.
        expert_hidden_dims: List of hidden dimensions for experts.
        gate_hidden_dims: List of hidden dimensions for gating network.
        cas_params: Dictionary of CAS module keyword arguments (everything except
            ``input_dim``).
        use_feature_norm: Whether to use learnable feature normalization.
        expert_dropout: Dropout rate for experts; the gate uses half of it.
        gate_temperature: Initial temperature for the gating softmax. It is stored
            as a learnable float parameter and floored at 0.5 when used.
        expert_type: Expert architecture type ('tabular' or 'basic', case-insensitive).

    Raises:
        ValueError: If ``expert_type`` is not 'tabular' or 'basic'.
    """

    def __init__(self, input_dim, num_experts, expert_hidden_dims, gate_hidden_dims,
                 cas_params, use_feature_norm=True, expert_dropout=0.2, gate_temperature=1.0,
                 expert_type='tabular'):
        super().__init__()
        # Mirrors CASMIR_V1_OLD so both models expose the same attributes.
        self.num_experts = num_experts
        self.use_feature_norm = use_feature_norm

        if use_feature_norm:
            self.feature_norm = FeatureNormalizer(input_dim, learnable=True)

        self.cas_module = CAS_Module(input_dim=input_dim, **cas_params)

        self.expert_type = expert_type.lower()
        if self.expert_type == 'basic':
            self.experts = nn.ModuleList([
                BasicExpert(input_dim, expert_hidden_dims, dropout_rate=expert_dropout)
                for _ in range(num_experts)
            ])
        elif self.expert_type == 'tabular':
            self.experts = nn.ModuleList([
                TabularExpert(input_dim, expert_hidden_dims, dropout_rate=expert_dropout)
                for _ in range(num_experts)
            ])
        else:
            raise ValueError(f"Unsupported expert_type: {expert_type}. Use 'tabular' or 'basic'.")

        self.gating_network = GatingNetwork(input_dim, gate_hidden_dims, num_experts, dropout_rate=expert_dropout*0.5)
        # float(): an integer temperature (e.g. 1) would otherwise create an
        # integer tensor, which nn.Parameter rejects (only floating point
        # tensors can require gradients).
        self.gate_temperature = nn.Parameter(torch.tensor(float(gate_temperature)))

    def forward(self, x, y=None, density=None, apply_smoothing=True):
        """Run normalisation, optional CAS smoothing and the expert mixture.

        CAS smoothing is applied only when the model is in training mode,
        ``apply_smoothing`` is true, and both ``y`` and ``density`` are given.

        Returns:
            In training mode with ``apply_smoothing`` true:
            ``(prediction, gate_weights, gate_logits)``. This holds even when
            smoothing was skipped because ``y`` or ``density`` is None.
            Otherwise: ``prediction`` only. For 2-D input (B, D), ``prediction``
            is (B, 1) and the gate tensors are (B, num_experts).
        """
        x_norm = self.feature_norm(x) if self.use_feature_norm else x

        if self.training and apply_smoothing and y is not None and density is not None:
            x_processed = self.cas_module(x_norm, y, density)
        else:
            x_processed = x_norm

        gate_logits = self.gating_network(x_norm)
        # Temperature is floored at 0.5 (no gradient flows through the clamp below it).
        gate_weights = F.softmax(gate_logits / self.gate_temperature.clamp(min=0.5), dim=1)

        expert_outputs = torch.stack([expert(x_processed) for expert in self.experts], dim=1)
        final_prediction = torch.sum(gate_weights.unsqueeze(-1) * expert_outputs, dim=1)

        if self.training and apply_smoothing:
            return final_prediction, gate_weights, gate_logits
        else:
            return final_prediction

    def get_feature_weights(self):
        """Return the CAS feature weights after softplus (positive), detached from autograd."""
        return F.softplus(self.cas_module.raw_feature_weights.detach())


class CASMIR_V1_OLD(nn.Module):
    """Legacy CASMIR model v1 (CAS + MoE). Kept for backward compatibility.

    Unlike the current model it has no feature normalisation and no gating
    temperature. It uses ``Expert`` networks, and both the gate and the CAS
    module see the raw input.
    """

    def __init__(self, input_dim, num_experts, expert_hidden_dims, gate_hidden_dims, cas_params):
        super().__init__()
        self.num_experts = num_experts
        self.cas_module = CAS_Module(input_dim=input_dim, **cas_params)
        expert_list = [Expert(input_dim, expert_hidden_dims) for _ in range(num_experts)]
        self.experts = nn.ModuleList(expert_list)
        self.gating_network = GatingNetwork(input_dim, gate_hidden_dims, num_experts)

    def forward(self, x, y=None, density=None, apply_smoothing=True):
        """Mix expert outputs with softmax gate weights computed from ``x``.

        Smoothing runs only in training mode with ``apply_smoothing`` true and
        both ``y`` and ``density`` provided. In training mode with
        ``apply_smoothing`` true, the method returns
        ``(prediction, gate_weights, gate_logits)``; otherwise it returns only
        ``prediction``.
        """
        smoothing_active = self.training and apply_smoothing
        can_smooth = smoothing_active and y is not None and density is not None
        x_processed = self.cas_module(x, y, density) if can_smooth else x

        logits = self.gating_network(x)
        mix = F.softmax(logits, dim=1)

        stacked = torch.stack([expert(x_processed) for expert in self.experts], dim=1)
        prediction = (mix.unsqueeze(-1) * stacked).sum(dim=1)

        if smoothing_active:
            return prediction, mix, logits
        return prediction

    def get_feature_weights(self):
        """Return learned feature weights (positive), detached from autograd."""
        return F.softplus(self.cas_module.raw_feature_weights.detach())


def get_density_range_index(density_value, low_thresh, high_thresh):
    """Map a density value to a range index.

    Returns:
        0 if ``density_value < low_thresh``; otherwise 1 if
        ``density_value <= high_thresh``; otherwise 2. A NaN fails both
        comparisons and therefore maps to 2.
    """
    return 0 if density_value < low_thresh else (1 if density_value <= high_thresh else 2)


# Backward compatibility aliases: the previous public names are bound to the
# renamed classes. Each alias is the very same class object (not a subclass),
# so isinstance checks and construction behave identically under either name.
CFLCS_Module = CAS_Module
AdaSmoothMoEV1 = CASMIR_V1
AdaSmoothMoEV1_OLD = CASMIR_V1_OLD

