# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.preprocessing import KBinsDiscretizer
from src.utils.deprecation import deprecated
from scipy.ndimage import convolve1d, gaussian_filter1d
from scipy.signal.windows import triang


class FDSLayer(nn.Module):
    """
    Feature Distribution Smoothing (FDS) layer.

    Keeps a running mean and variance of the feature vectors that fall into
    each target bin, smooths those statistics across neighbouring bins with a
    1-D kernel whenever ``update_last_epoch_stats`` sees a new epoch, and in
    training mode re-calibrates each feature vector from its bin's raw
    statistics to the smoothed ones.

    Bin boundaries are read from ``discretizer.bin_edges_[0]``; all statistics
    are registered as module buffers.
    """

    _SUPPORTED_KERNELS = ('gaussian', 'triang', 'laplace')

    def __init__(self, feature_dim, fds_config, discretizer):
        """
        Args:
            feature_dim: Dimension of input feature vector.
            fds_config: FDS hyperparameter dictionary with keys:
                - num_target_bins: Number of target bins (default 50).
                - fds_momentum: Upper bound on the weight given to a new
                  batch when updating running statistics (default 0.9).
                - start_update_epoch: Epoch to start updating statistics
                  (default 0).
                - start_smooth_epoch: Epoch to start smoothing (default 1).
                - kernel: Smoothing kernel type ('gaussian', 'triang',
                  'laplace'; default 'gaussian').
                - kernel_size: Kernel size (odd number, default 5). An even
                  size is treated as the next smaller odd size.
                - kernel_sigma: Kernel sigma for gaussian/laplace (default 2).
            discretizer: Fitted KBinsDiscretizer instance (anything exposing
                ``bin_edges_`` and ``n_bins`` works).

        Raises:
            ValueError: If the discretizer is missing or unfitted, the kernel
                type is unknown, or ``kernel_size`` is smaller than 1.
        """
        super(FDSLayer, self).__init__()

        # Validate first: the bin-edge buffer below dereferences the
        # discretizer, which would otherwise surface as an AttributeError.
        if discretizer is None or not hasattr(discretizer, 'bin_edges_'):
            raise ValueError("FDSLayer requires a fitted KBinsDiscretizer instance.")

        self.feature_dim = feature_dim
        self.discretizer = discretizer
        self.epsilon = 1e-6

        to_scalar = self._to_scalar
        self.num_target_bins = to_scalar(fds_config.get('num_target_bins', 50), 50, int)
        self.fds_momentum = to_scalar(fds_config.get('fds_momentum', 0.9), 0.9, float)
        self.start_update_epoch = to_scalar(fds_config.get('start_update_epoch', 0), 0, int)
        self.start_smooth_epoch = to_scalar(fds_config.get('start_smooth_epoch', 1), 1, int)

        self.register_buffer(
            'bin_edges',
            torch.tensor(discretizer.bin_edges_[0], dtype=torch.float32)
        )

        self.kernel = fds_config.get('kernel', 'gaussian')
        self.kernel_size = to_scalar(fds_config.get('kernel_size', 5), 5, int)
        self.kernel_sigma = to_scalar(fds_config.get('kernel_sigma', 2), 2, float)
        if self.kernel_size < 1:
            raise ValueError(f"kernel_size must be >= 1, got {self.kernel_size}.")

        # The window always has 2 * half_ks + 1 taps so that padding by
        # half_ks on both sides keeps the number of bins unchanged.
        self.half_ks = (self.kernel_size - 1) // 2
        self.kernel_window = self._get_kernel_window()

        if self.discretizer.n_bins != self.num_target_bins:
            print(f"Warning: num_target_bins ({self.num_target_bins}) differs from "
                  f"Discretizer n_bins ({self.discretizer.n_bins}). Using Discretizer setting.")
            self.num_target_bins = int(self.discretizer.n_bins)

        stats_shape = (self.num_target_bins, feature_dim)
        self.register_buffer('running_mean', torch.zeros(stats_shape))
        self.register_buffer('running_var', torch.ones(stats_shape))
        self.register_buffer('running_mean_last_epoch', torch.zeros(stats_shape))
        self.register_buffer('running_var_last_epoch', torch.ones(stats_shape))
        self.register_buffer('smoothed_mean_last_epoch', torch.zeros(stats_shape))
        self.register_buffer('smoothed_var_last_epoch', torch.ones(stats_shape))
        self.register_buffer('num_samples_tracked', torch.zeros(self.num_target_bins))

        # Last epoch for which the per-epoch snapshot was taken (-1 = never).
        self._last_updated_epoch = -1

        print(f"FDSLayer initialized: feature_dim={feature_dim}, num_bins={self.num_target_bins}, "
              f"momentum={self.fds_momentum}, kernel={self.kernel}({self.kernel_size}/{self.kernel_sigma})")

    @staticmethod
    def _to_scalar(value, default, convert_func=int):
        """Convert a config value (scalar, numpy array, tensor, list) to a scalar.

        ``None`` yields ``convert_func(default)``. Non-string iterables are
        unwrapped with ``.item()`` when available, otherwise with ``[0]``.
        """
        if value is None:
            return convert_func(default)
        if hasattr(value, '__iter__') and not isinstance(value, str):
            if hasattr(value, 'item'):
                return convert_func(value.item())
            if hasattr(value, '__getitem__'):
                return convert_func(value[0])
        return convert_func(value)

    def smooth(self, features, labels, epoch):
        """Deprecated: use forward() instead."""
        return self.forward(features, labels, epoch)

    def update_last_epoch_stats(self, epoch):
        """Snapshot and smooth the running statistics once per epoch.

        Calling again with the same epoch is a no-op. An epoch lower than the
        last one seen (e.g. a new HPO trial reusing the layer) is treated as a
        reset and triggers a fresh snapshot.
        """
        # getattr keeps instances created before this attribute existed working.
        previous = getattr(self, '_last_updated_epoch', -1)
        if epoch == previous:
            return

        is_reset = epoch < previous
        if is_reset:
            print(f"FDS epoch reset detected: {epoch} < {previous}, resetting...")

        self._update_last_epoch_stats()
        self._last_updated_epoch = epoch

        if is_reset:
            print(f"Updated smoothed statistics on Epoch [{epoch}]! (reset)")
        else:
            print(f"Updated smoothed statistics on Epoch [{epoch}]! (previous: {previous})")

    def _get_kernel_window(self):
        """Build the normalised smoothing kernel as a float32 tensor.

        Returns:
            1-D tensor with ``2 * half_ks + 1`` entries that sum to 1.

        Raises:
            ValueError: If ``self.kernel`` is not a supported kernel type.
        """
        if self.kernel not in self._SUPPORTED_KERNELS:
            raise ValueError(
                f"Unsupported FDS kernel '{self.kernel}'; "
                f"expected one of {self._SUPPORTED_KERNELS}."
            )

        width = 2 * self.half_ks + 1
        if self.kernel == 'gaussian':
            # Filtering a unit impulse yields the sampled gaussian itself.
            impulse = np.zeros(width, dtype=np.float32)
            impulse[self.half_ks] = 1.
            kernel_window = gaussian_filter1d(impulse, sigma=self.kernel_sigma)
        elif self.kernel == 'triang':
            # Use the effective width, not kernel_size, so an even
            # kernel_size cannot produce a window that mismatches the padding.
            kernel_window = triang(width)
        else:
            offsets = np.arange(-self.half_ks, self.half_ks + 1)
            kernel_window = np.exp(-np.abs(offsets) / self.kernel_sigma) / (2. * self.kernel_sigma)

        kernel_window = np.asarray(kernel_window)
        return torch.tensor(kernel_window / kernel_window.sum(), dtype=torch.float32)

    def _get_bucket_indices_gpu(self, targets):
        """Map target values to bin indices on the device holding ``bin_edges``.

        Args:
            targets: Tensor or array-like of target values, any shape.

        Returns:
            Flat integer tensor of bin indices in ``[0, num_target_bins - 1]``.
            Values outside the fitted range are clipped into the first/last bin.
        """
        edges = self.bin_edges
        # Cast to the edge dtype so tensor and array inputs bin identically.
        targets = torch.as_tensor(targets, dtype=edges.dtype, device=edges.device)
        clipped = torch.clamp(targets.reshape(-1), edges[0], edges[-1])
        # Only interior edges separate bins; right=True puts a value equal to
        # an edge into the bin that starts at that edge.
        bucket_indices = torch.bucketize(clipped, edges[1:-1], right=True)
        return torch.clamp(bucket_indices, 0, self.num_target_bins - 1)

    def _get_bucket_indices(self, targets):
        """Wrapper for backward compatibility (returns NumPy array)."""
        return self._get_bucket_indices_gpu(targets).cpu().numpy()

    def _smooth_across_bins(self, stats):
        """Convolve a (num_bins, feature_dim) tensor with the kernel along bins."""
        kernel = self.kernel_window.to(device=stats.device, dtype=stats.dtype).view(1, 1, -1)
        # conv1d expects (batch, channels, length): each feature becomes one
        # batch entry whose sequence runs over the bins.
        signal = stats.t().unsqueeze(1)
        if self.half_ks > 0:
            # Reflect padding needs pad < length; fall back to edge
            # replication when there are too few bins for the kernel.
            mode = 'reflect' if self.half_ks < stats.size(0) else 'replicate'
            signal = F.pad(signal, pad=(self.half_ks, self.half_ks), mode=mode)
        return F.conv1d(signal, kernel, padding=0).squeeze(1).t()

    def _update_last_epoch_stats(self):
        """Snapshot the running statistics and store their smoothed versions."""
        self.running_mean_last_epoch.copy_(self.running_mean)
        self.running_var_last_epoch.copy_(self.running_var)
        self.smoothed_mean_last_epoch.copy_(
            self._smooth_across_bins(self.running_mean_last_epoch)
        )
        self.smoothed_var_last_epoch.copy_(
            self._smooth_across_bins(self.running_var_last_epoch)
        )

    @torch.no_grad()
    def update_running_stats(self, features, targets_np, epoch):
        """Fold one batch into the per-bin running mean and variance.

        Args:
            features: Tensor indexed by sample along dim 0.
            targets_np: Targets for the batch (tensor or array-like), one per
                sample.
            epoch: Current epoch; nothing happens before ``start_update_epoch``.
        """
        if epoch < self.start_update_epoch:
            return

        bucket_indices = self._get_bucket_indices_gpu(targets_np).to(features.device)
        is_start_epoch = epoch == self.start_update_epoch

        # Visit only bins that actually occur in this batch.
        for i in torch.unique(bucket_indices).tolist():
            bin_features = features[bucket_indices == i]
            curr_num_sample = bin_features.size(0)
            self.num_samples_tracked[i] += curr_num_sample

            if is_start_epoch:
                # `factor` is the weight of the NEW batch. During the first
                # update epoch the batch statistics replace the placeholder
                # initial values entirely; a weight of 0 here would freeze
                # the statistics for the whole epoch.
                factor = 1.0
            else:
                share = curr_num_sample / float(self.num_samples_tracked[i])
                factor = min(self.fds_momentum, share)

            batch_mean = torch.clamp(bin_features.mean(dim=0), min=-1e8, max=1e8)
            if curr_num_sample == 1:
                # Unbiased variance is undefined for a single sample.
                batch_var = torch.zeros_like(batch_mean)
            else:
                batch_var = bin_features.var(dim=0, unbiased=True)
                batch_var = torch.clamp(batch_var, min=1e-8, max=1e8)

            self.running_mean[i] = self.running_mean[i] * (1 - factor) + batch_mean * factor
            self.running_var[i] = self.running_var[i] * (1 - factor) + batch_var * factor

        # Bins that have never received a sample borrow from their neighbours.
        never_seen = (self.num_samples_tracked == 0).tolist()
        last_bucket = self.num_target_bins - 1
        for bucket, is_empty in enumerate(never_seen):
            if not is_empty:
                continue
            if bucket == 0:
                if last_bucket > 0:
                    self.running_mean[bucket] = self.running_mean[bucket + 1]
                    self.running_var[bucket] = self.running_var[bucket + 1]
            elif bucket == last_bucket:
                self.running_mean[bucket] = self.running_mean[bucket - 1]
                self.running_var[bucket] = self.running_var[bucket - 1]
            else:
                self.running_mean[bucket] = (self.running_mean[bucket - 1] +
                                             self.running_mean[bucket + 1]) / 2.0
                self.running_var[bucket] = (self.running_var[bucket - 1] +
                                            self.running_var[bucket + 1]) / 2.0

    def calibrate_mean_var(self, features, m1, v1, m2, v2):
        """Shift/scale features from statistics (m1, v1) to (m2, v2).

        The variance ratio ``v2 / v1`` is clamped to [0.1, 10] and the result
        to [-1e8, 1e8]. Dimensions whose source variance is exactly zero are
        left untouched. The input tensor is never modified.

        Returns:
            The calibrated tensor, or ``features`` itself when any statistic
            is non-finite or the total source variance is below ``epsilon``.
        """
        if not all(bool(torch.isfinite(stat).all()) for stat in (m1, v1, m2, v2)):
            return features

        if torch.sum(v1) < self.epsilon:
            return features

        valid = (v1 != 0.)
        if bool(valid.all()):
            factor = torch.clamp(v2 / v1, 0.1, 10)
            result = (features - m1) * torch.sqrt(factor) + m2
        else:
            factor = torch.clamp(v2[valid] / v1[valid], 0.1, 10)
            # Work on a copy so the caller's tensor is not modified in place.
            result = features.clone()
            result[:, valid] = (features[:, valid] - m1[valid]) * torch.sqrt(factor) + m2[valid]

        return torch.clamp(result, min=-1e8, max=1e8)

    def forward(self, features, targets=None, current_epoch=None):
        """
        FDS forward pass. Batch-level statistics update is handled in train_utils.py.

        Args:
            features: Feature tensor indexed by sample along dim 0.
            targets: Targets for the batch (tensor or array-like); required
                for calibration.
            current_epoch: Current epoch; calibration starts at
                ``start_smooth_epoch``.

        Returns:
            ``features`` unchanged in eval mode, when targets/epoch are
            missing, or before ``start_smooth_epoch``; otherwise a calibrated
            copy.
        """
        # No calibration at inference. Calibrating with identical source and
        # target statistics is an identity transform that only adds rounding
        # error, so the input is returned directly.
        if not self.training:
            return features

        if targets is None or current_epoch is None:
            return features

        if current_epoch < self.start_smooth_epoch:
            return features

        # Bin on-device; no NumPy round trip is needed.
        bucket_indices = self._get_bucket_indices_gpu(targets).to(features.device)
        calibrated_features = features.clone()

        for i in torch.unique(bucket_indices).tolist():
            mask = bucket_indices == i
            calibrated_features[mask] = self.calibrate_mean_var(
                features[mask],
                self.running_mean_last_epoch[i],
                self.running_var_last_epoch[i],
                self.smoothed_mean_last_epoch[i],
                self.smoothed_var_last_epoch[i]
            )
        return calibrated_features

@deprecated(replacement="FDSLayer")
class FDSLayer_OLD(nn.Module):
    """Legacy FDS layer, retained only for backward compatibility.

    Keeps one running mean/variance row per target bin and, in training
    mode, standardises each feature vector with the statistics of its bin.
    """

    def __init__(self, feature_dim, fds_config, discretizer):
        """Read hyperparameters from ``fds_config`` and allocate the buffers.

        Raises:
            ValueError: If ``discretizer`` is None or has no ``bin_edges_``.
        """
        super(FDSLayer_OLD, self).__init__()
        self.epsilon = 1e-6
        self.discretizer = discretizer
        self.feature_dim = feature_dim
        self.start_update_epoch = fds_config.get('start_update_epoch', 0)
        self.fds_momentum = fds_config.get('fds_momentum', 0.9)
        self.num_target_bins = fds_config.get('num_target_bins', 50)

        is_fitted = discretizer is not None and hasattr(discretizer, 'bin_edges_')
        if not is_fitted:
            raise ValueError("FDSLayer requires a fitted KBinsDiscretizer instance.")

        # The discretizer's bin count wins over the configured one.
        if self.discretizer.n_bins != self.num_target_bins:
            print(f"Warning: num_target_bins ({self.num_target_bins}) differs from "
                  f"Discretizer n_bins ({self.discretizer.n_bins}). Using Discretizer setting.")
            self.num_target_bins = self.discretizer.n_bins

        self.register_buffer('running_mean', torch.zeros(self.num_target_bins, feature_dim))
        self.register_buffer('running_var', torch.ones(self.num_target_bins, feature_dim))

        print(f"FDSLayer_OLD initialized: feature_dim={feature_dim}, num_bins={self.num_target_bins}, "
              f"momentum={self.fds_momentum}, start_epoch={self.start_update_epoch}")

    def _get_bucket_indices(self, targets_np):
        """Return the integer bin index of every target via the discretizer.

        Targets are clipped to the fitted edge range first. If the
        discretizer raises ``ValueError``, a warning is printed and every
        index is set to 0.
        """
        column = targets_np.reshape(-1, 1)
        edges = self.discretizer.bin_edges_[0]
        clipped = np.clip(column, edges[0], edges[-1])
        try:
            return self.discretizer.transform(clipped).flatten().astype(int)
        except ValueError as e:
            print(f"Warning: FDS target binning error ({e}). Setting indices to 0.")
            return np.zeros(len(targets_np), dtype=int)

    @torch.no_grad()
    def update_running_stats(self, features, targets_np, current_epoch):
        """Blend the batch mean/variance of each bin into the running buffers.

        Does nothing before ``start_update_epoch``. The running value is
        weighted by ``fds_momentum`` and the batch value by its complement.
        """
        if current_epoch < self.start_update_epoch:
            return

        device = features.device
        bucket_indices = self._get_bucket_indices(targets_np)
        keep = self.fds_momentum

        for bin_id in range(self.num_target_bins):
            in_bin = (bucket_indices == bin_id)
            if not in_bin.any():
                continue
            members = features[torch.from_numpy(in_bin).to(device)]
            mean_now = members.mean(dim=0)
            var_now = members.var(dim=0, unbiased=False)
            self.running_mean[bin_id, :] = self.running_mean[bin_id, :] * keep + mean_now * (1 - keep)
            self.running_var[bin_id, :] = self.running_var[bin_id, :] * keep + var_now * (1 - keep)

    def forward(self, features, targets=None, current_epoch=None):
        """Standardise features with their bin's running statistics.

        Calibration happens only in training mode with both ``targets`` and
        ``current_epoch`` supplied; otherwise ``features`` is returned as is.
        """
        device = features.device

        can_calibrate = self.training and targets is not None and current_epoch is not None
        if not can_calibrate:
            return features

        targets_np = targets.view(-1).cpu().numpy()
        # Statistics are refreshed with this batch before being applied to it.
        self.update_running_stats(features, targets_np, current_epoch)

        bin_ids = torch.from_numpy(self._get_bucket_indices(targets_np)).long().to(device)
        mu_running = self.running_mean[bin_ids, :]
        sigma_running = torch.sqrt(self.running_var[bin_ids, :] + self.epsilon)

        return (features - mu_running) / sigma_running