"""Tabular data sampler using MLP SCM prior from TabICL."""

import sys
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import torch

# Add tabicl to path
sys.path.append(str(Path(__file__).parent.parent.parent / "tabicl" / "src"))

from tabicl.prior.mlp_scm import MLPSCM
from src.data.preprocess import TabICLScaler
from src.utils import DataAttr


class TabularSampler:
    """Generate tabular regression data using MLP SCM prior.
    
    This sampler generates synthetic tabular regression functions using
    the MLP-based Structural Causal Model from TabICL.
    """
    
    def __init__(
        self,
        dim_x: int | list = 10,
        dim_y: int = 1,
        # MLP SCM parameters
        is_causal: bool = True,
        num_causes: Optional[int] = None,
        num_layers: int = 4,
        hidden_dim: int = 64,
        noise_std: float = 0.01,
        sampling: str = "mixed",
        # Normalization
        normalize_y: bool = True,
        normalize_x: bool = False,
        x_norm_method: str = "power",
        x_outlier_threshold: float = 4.0,
        # Other
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
        **kwargs,
    ):
        """Initialize tabular sampler.
        
        Args:
            dim_x: Number of input features (int or list of ints to sample from)
            dim_y: Number of output dimensions (always 1 for regression)
            is_causal: Whether to use causal generation
            num_causes: Number of root causes (if None, defaults to dim_x//2)
            num_layers: Number of MLP layers in SCM
            hidden_dim: Hidden dimension of MLP
            noise_std: Noise standard deviation
            sampling: Sampling method for initial causes ("normal", "uniform", "mixed")
            normalize_y: Whether to z-normalize y values using context statistics
            device: Device to generate on
            dtype: Data type
        """
        # Handle dim_x as list or int
        if isinstance(dim_x, int):
            self.dim_x_list = [dim_x]
        else:
            self.dim_x_list = dim_x
        
        self.dim_y = dim_y
        self.device = device
        self.dtype = dtype
        self.normalize_y = normalize_y
        self.normalize_x = normalize_x
        self.x_norm_method = x_norm_method
        self.x_outlier_threshold = x_outlier_threshold
        
        # MLP SCM parameters
        self.is_causal = is_causal
        self.num_causes = num_causes
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.noise_std = noise_std
        self.sampling = sampling
        
    def _generate_function(self, num_samples: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate a single regression function."""
        
        # Randomly choose dim_x from list
        dim_x = np.random.choice(self.dim_x_list)
        
        # Determine num_causes robustly for small dim_x
        if self.num_causes is None:
            base_num_causes = max(1, dim_x // 2)
        else:
            base_num_causes = int(np.clip(self.num_causes, 1, dim_x))
        # Sample within [1, dim_x], centered near base
        lo = max(1, base_num_causes - 2)
        hi = min(dim_x, base_num_causes + 3)
        if hi < lo:
            lo, hi = 1, dim_x
        actual_num_causes = int(np.random.randint(lo, hi + 1))
        actual_num_layers = np.random.randint(
            max(2, self.num_layers - 1),
            self.num_layers + 2
        )
        actual_hidden_dim = np.random.randint(
            max(16, self.hidden_dim - 16),
            self.hidden_dim + 32
        )
        
        # Create MLP SCM model
        model = MLPSCM(
            seq_len=num_samples,
            num_features=dim_x,  # Use the chosen dim_x
            num_outputs=self.dim_y,
            is_causal=self.is_causal,
            num_causes=actual_num_causes,
            y_is_effect=True,
            in_clique=False,
            sort_features=True,
            num_layers=actual_num_layers,
            hidden_dim=actual_hidden_dim,
            mlp_activations=torch.nn.Tanh,
            init_std=np.random.uniform(0.8, 2.0),
            block_wise_dropout=True,
            mlp_dropout_prob=np.random.uniform(0.05, 0.2),
            scale_init_std_by_dropout=True,
            sampling=self.sampling,
            pre_sample_cause_stats=True,
            noise_std=self.noise_std,
            pre_sample_noise_std=True,
            device=self.device,
        )
        
        # Generate data
        with torch.no_grad():
            X, y = model()
        
        # Ensure correct shape
        if y.dim() == 1:
            y = y.unsqueeze(-1)
            
        return X.to(self.dtype), y.to(self.dtype)
    
    def generate_batch(
        self,
        batch_size: int,
        num_context: Optional[int | list] = None,
        num_buffer: int = 0,
        num_target: int = 128,
        context_range: Optional[Tuple[int, int]] = None,
        **kwargs,
    ) -> DataAttr:
        """Generate a batch of tabular regression tasks.
        
        All tasks in the batch will have the same feature dimension and context size
        to avoid the need for padding.
        
        Args:
            batch_size: Number of independent tasks
            num_context: Number of context points (int for fixed, list to sample from, None for range)
            num_buffer: Number of buffer points (fixed)
            num_target: Number of target points (fixed)
            context_range: Range for random context size if num_context is None
            
        Returns:
            DataAttr with xc, yc, xb, yb, xt, yt
        """
        # Choose dimensions ONCE for the entire batch
        # 1. Choose feature dimension
        dim_x = np.random.choice(self.dim_x_list)
        
        # 2. Choose context size
        if num_context is None:
            if context_range is None:
                context_range = (32, 256)
            nc = np.random.randint(context_range[0], context_range[1] + 1)
        elif isinstance(num_context, int):
            nc = num_context
        else:
            nc = np.random.choice(num_context)
            
        # Buffer and target are fixed
        nb = num_buffer
        nt = num_target
        
        # Total samples needed per task
        total_samples = nc + nb + nt
        
        xc_list, yc_list = [], []
        xb_list, yb_list = [], []
        xt_list, yt_list = [], []
        
        for _ in range(batch_size):
            # Generate a regression function with fixed dim_x
            X, y = self._generate_function_fixed_dim(total_samples, dim_x)
            
            # Shuffle the data
            perm = torch.randperm(total_samples)
            X = X[perm]
            y = y[perm]
            
            # Split into context, buffer, target
            xc = X[:nc]
            yc = y[:nc]
            
            xb = X[nc:nc + nb]
            yb = y[nc:nc + nb]
            
            xt = X[nc + nb:]
            yt = y[nc + nb:]
            
            # Apply scaling
            if self.normalize_x or self.normalize_y:
                # Fit scaler on context; include y if normalize_y
                scaler = TabICLScaler(
                    normalization_method=self.x_norm_method,
                    outlier_threshold=self.x_outlier_threshold,
                    random_state=None,
                )
                # Fit on context (features always; targets if requested)
                scaler.fit(xc.cpu().numpy(), yc.cpu().numpy() if self.normalize_y else None)
                # Transform entire sample (context, buffer, target)
                from src.utils import DataAttr
                sample = DataAttr(xc=xc, yc=yc, xb=xb, yb=yb, xt=xt, yt=yt)
                sample_scaled = scaler.transform_batch(sample)
                xc, yc, xb, yb, xt, yt = (
                    sample_scaled.xc.to(self.dtype),
                    sample_scaled.yc.to(self.dtype) if sample_scaled.yc is not None else yc,
                    sample_scaled.xb.to(self.dtype),
                    sample_scaled.yb.to(self.dtype) if sample_scaled.yb is not None else yb,
                    sample_scaled.xt.to(self.dtype),
                    sample_scaled.yt.to(self.dtype) if sample_scaled.yt is not None else yt,
                )
            
            xc_list.append(xc)
            yc_list.append(yc)
            xb_list.append(xb)
            yb_list.append(yb)
            xt_list.append(xt)
            yt_list.append(yt)
        
        # Stack into batches - no padding needed!
        return DataAttr(
            xc=torch.stack(xc_list),
            yc=torch.stack(yc_list),
            xb=torch.stack(xb_list) if nb > 0 else torch.zeros(batch_size, 0, dim_x, device=self.device, dtype=self.dtype),
            yb=torch.stack(yb_list) if nb > 0 else torch.zeros(batch_size, 0, self.dim_y, device=self.device, dtype=self.dtype),
            xt=torch.stack(xt_list),
            yt=torch.stack(yt_list),
        )
    
    def _generate_function_fixed_dim(self, num_samples: int, dim_x: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate a single regression function with fixed dimensionality."""
        
        # Determine num_causes robustly for small dim_x
        if self.num_causes is None:
            base_num_causes = max(1, dim_x // 2)
        else:
            base_num_causes = int(np.clip(self.num_causes, 1, dim_x))
        lo = max(1, base_num_causes - 2)
        hi = min(dim_x, base_num_causes + 3)
        if hi < lo:
            lo, hi = 1, dim_x
        actual_num_causes = int(np.random.randint(lo, hi + 1))
        actual_num_layers = np.random.randint(
            max(2, self.num_layers - 1),
            self.num_layers + 2
        )
        actual_hidden_dim = np.random.randint(
            max(16, self.hidden_dim - 16),
            self.hidden_dim + 32
        )
        
        # Create MLP SCM model
        model = MLPSCM(
            seq_len=num_samples,
            num_features=dim_x,  # Use the fixed dim_x
            num_outputs=self.dim_y,
            is_causal=self.is_causal,
            num_causes=actual_num_causes,
            y_is_effect=True,
            in_clique=False,
            sort_features=True,
            num_layers=actual_num_layers,
            hidden_dim=actual_hidden_dim,
            mlp_activations=torch.nn.Tanh,
            init_std=np.random.uniform(0.8, 2.0),
            block_wise_dropout=True,
            mlp_dropout_prob=np.random.uniform(0.05, 0.2),
            scale_init_std_by_dropout=True,
            sampling=self.sampling,
            pre_sample_cause_stats=True,
            noise_std=self.noise_std,
            pre_sample_noise_std=True,
            device=self.device,
        )
        
        # Generate data
        with torch.no_grad():
            X, y = model()
        
        # Ensure correct shape
        if y.dim() == 1:
            y = y.unsqueeze(-1)
            
        return X.to(self.dtype), y.to(self.dtype)
