import logging
import os
import time

import numpy as np
import torch
from copy import deepcopy
from tqdm import tqdm

from .experiment import Experiment
from .utils import get_total_trainable_params, set_seeds
from .utils import cycle

# TabDDPM modules - located in experiments/models/tabddpm/
from .models.tabddpm.gaussian_multinomial_diffusion import GaussianMultinomialDiffusion


class Experiment_TabDDPM(Experiment):
    """
    TabDDPM experiment implementation based on synthcity.
    Uses Gaussian-Multinomial diffusion for mixed-type tabular data.
    """
    
    def __init__(
        self,
        data_path,
        sample_path,
        config,
        exp_path,
        dataset,
        device,
        preproc,
        strategy=0,
        breaks=30000,
        run_name='',
        beta='0p7',
        use_log='switch',
    ):
        """
        Set up the experiment and remember the TabDDPM-specific options.

        All arguments except ``strategy`` and ``breaks`` are handed to the
        parent ``Experiment`` constructor unchanged. ``strategy`` and
        ``breaks`` are stored on the instance and later passed to
        ``GaussianMultinomialDiffusion`` when the model is built.
        """
        super().__init__(
            data_path,
            sample_path,
            config,
            exp_path,
            dataset,
            device,
            preproc,
            run_name=run_name,
            beta=beta,
            use_log=use_log,
        )
        self.breaks = breaks
        self.strategy = strategy
        
    def train(self, **kwargs):
        """
        Train the TabDDPM model.

        Builds a ``GaussianMultinomialDiffusion`` from ``self.config.model``,
        optimizes it with AdamW for ``config.model.train_steps`` steps while
        linearly annealing the learning rate, and maintains an exponential
        moving average (EMA) of the denoiser weights. When training ends, the
        EMA weights are copied into ``self.diffusion.denoise_fn``.

        Args:
            **kwargs: Only ``save_model`` (default ``False``) is read. When
                true, the progress bar is displayed and the training time and
                checkpoints are saved after training.

        Raises:
            ValueError: If a batch contains neither continuous nor
                categorical features.
        """
        save_model = kwargs.get("save_model", False)
        set_seeds(self.seed, cuda_deterministic=True)

        # Get train loader (FastTensorDataLoader - same as CDTD, no multiprocessing overhead)
        self.train_loader = self.data_wrangler.get_train_loader(
            self.config.model.batch_size
        )
        X_cont_train = self.train_loader.X_cont
        X_cat_train = self.train_loader.X_cat

        # Add 1 to each category to account for unknown value remapping:
        # unknown values (9999) are remapped to n_cats, so the valid range is
        # [0, n_cats], which requires n_cats + 1 classes.
        categories = tuple(c + 1 for c in self.data_wrangler.num_cats)
        # The batch loop below accepts batches without continuous features,
        # so the feature count must not assume X_cont exists.
        self.n_num_feat = X_cont_train.shape[1] if X_cont_train is not None else 0

        # Sanity check: every categorical column must fit the class count the
        # model is built with. Only warns, it does not abort training.
        if X_cat_train is not None and len(X_cat_train) > 0:
            for col, n_classes in zip(range(X_cat_train.shape[1]), categories):
                col_min = int(X_cat_train[:, col].min())
                col_max = int(X_cat_train[:, col].max())
                if col_min < 0 or col_max >= n_classes:
                    logging.warning(
                        "Categorical column %d has values in [%d, %d] but the "
                        "model expects indices in [0, %d]",
                        col, col_min, col_max, n_classes - 1,
                    )

        # Initialize diffusion model
        # Convert OmegaConf to dict for model_params (needed for attribute access in MLP)
        model_params_dict = dict(self.config.model.model_params)

        self.diffusion = GaussianMultinomialDiffusion(
            num_numerical_features=self.n_num_feat,
            num_categorical_features=categories,
            model_params=type('obj', (object,), model_params_dict),  # Convert dict to object for attribute access
            num_timesteps=self.config.model.num_timesteps,
            dim_emb=self.config.model.dim_emb,
            gaussian_loss_type=self.config.model.gaussian_loss_type,
            scheduler=self.config.model.scheduler,
            strategy=self.strategy,
            breaks=self.breaks,
            device=self.device,
        ).to(self.device)

        num_params = get_total_trainable_params(self.diffusion)
        print(f"Total parameters = {int(num_params):,}")

        self.diffusion.train()

        # Setup EMA model: a detached copy of the denoiser that is only ever
        # updated through _update_ema, never by the optimizer.
        self.ema_model = deepcopy(self.diffusion.denoise_fn)
        for param in self.ema_model.parameters():
            param.detach_()

        # Setup optimizer
        self.optimizer = torch.optim.AdamW(
            self.diffusion.parameters(),
            lr=self.config.model.lr,
            weight_decay=self.config.model.weight_decay
        )

        # Training parameters
        train_steps = self.config.model.train_steps
        log_steps = self.config.model.get('log_steps', 100)
        ema_decay = self.config.model.get('ema_decay', 0.999)

        # Running sums for the current logging window (weighted by batch size)
        curr_loss_multi = 0.0
        curr_loss_gauss = 0.0
        curr_count = 0
        training_start_time = time.time()

        # Cycle through the dataloader indefinitely for training (same as CDTD)
        train_iter = cycle(self.train_loader)

        # The context manager closes the progress bar even if a step raises.
        with tqdm(total=train_steps, desc="Training", disable=not save_model) as pbar:
            for step in range(train_steps):
                self.optimizer.zero_grad()

                # FastTensorDataLoader returns a tuple: x_cat, x_cont, m_cat, m_cont, y
                batch_tuple = next(train_iter)
                x_cat_batch, x_cont_batch, m_cat_batch, m_cont_batch = (
                    part.to(self.device) if part is not None else None
                    for part in batch_tuple[:4]
                )

                # Concatenate: TabDDPM expects [num_features, cat_features] format
                # Use column_stack to match original TabDDPM behavior
                if x_cont_batch is not None and x_cat_batch is not None:
                    x = torch.column_stack((x_cont_batch, x_cat_batch))
                elif x_cont_batch is not None:
                    x = x_cont_batch
                elif x_cat_batch is not None:
                    x = x_cat_batch
                else:
                    raise ValueError("Both x_cont and x_cat are None")

                # Compute loss; the masks are forwarded to mixed_loss
                loss_multi, loss_gauss = self.diffusion.mixed_loss(
                    x, cond=None, m_cat=m_cat_batch, m_cont=m_cont_batch
                )
                loss = loss_multi + loss_gauss

                # Backward pass
                loss.backward()
                self.optimizer.step()

                # Update EMA
                self._update_ema(
                    self.ema_model.parameters(),
                    self.diffusion.denoise_fn.parameters(),
                    rate=ema_decay
                )

                # Track losses
                batch_len = len(x)
                curr_count += batch_len
                curr_loss_multi += loss_multi.item() * batch_len
                curr_loss_gauss += loss_gauss.item() * batch_len

                # Learning rate annealing
                self._anneal_lr(step, train_steps)

                # Logging: report once per full window of log_steps batches so
                # every displayed average covers the same number of steps.
                if (step + 1) % log_steps == 0:
                    mloss = np.around(curr_loss_multi / curr_count, 4)
                    gloss = np.around(curr_loss_gauss / curr_count, 4)
                    total_loss = mloss + gloss
                    pbar.set_description(
                        f"Step {step + 1}/{train_steps} | Loss: {total_loss:.4f} "
                        f"(Multi: {mloss:.4f}, Gauss: {gloss:.4f})"
                    )
                    curr_count = 0
                    curr_loss_gauss = 0.0
                    curr_loss_multi = 0.0

                pbar.update(1)

        # Copy EMA weights to model for saving
        self._copy_ema_to_model()

        training_duration = time.time() - training_start_time
        if save_model:
            self.save_train_time(training_duration)
            self.save_model()
    
    def _anneal_lr(self, step: int, total_steps: int) -> None:
        """Linear learning rate annealing."""
        frac_done = step / total_steps
        lr = self.config.model.lr * (1 - frac_done)
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def _update_ema(self, target_params, source_params, rate=0.999):
        """
        Update target parameters to be closer to those of source parameters using
        an exponential moving average.
        :param target_params: the target parameter sequence.
        :param source_params: the source parameter sequence.
        :param rate: the EMA rate (closer to 1 means slower).
        """
        for targ, src in zip(target_params, source_params):
            targ.detach().mul_(rate).add_(src.detach(), alpha=1 - rate)
    
    def _copy_ema_to_model(self):
        """Copy EMA weights to the main model."""
        self.diffusion.denoise_fn.load_state_dict(self.ema_model.state_dict())
    
    def save_model(self):
        """
        Save model checkpoints including EMA models.
        """
        os.makedirs(self.ckpt_restore_dir, exist_ok=True)
        
        # Save main model
        torch.save(
            self.diffusion.state_dict(),
            os.path.join(self.ckpt_restore_dir, "model.pt")
        )
        torch.save(
            self.ema_model.state_dict(),
            os.path.join(self.ckpt_restore_dir, "ema_model.pt")
        )
    
    def load_model(self):
        """
        Rebuild the diffusion model and restore its weights from disk.

        The architecture is reconstructed from ``self.config.model`` and the
        dimensions of the training data. If ``ema_model.pt`` exists in
        ``self.ckpt_restore_dir`` it is loaded into the denoiser; otherwise
        ``model.pt`` is loaded into the whole diffusion model. The model is
        then moved to ``self.device`` and switched to eval mode.

        Raises:
            FileNotFoundError: If neither checkpoint file exists.
        """
        set_seeds(self.seed)

        # Get data dimensions from train loader
        self.train_loader = self.data_wrangler.get_train_loader(
            self.config.model.batch_size
        )
        X_cont_train = self.train_loader.X_cont

        # Add 1 to each category to account for unknown value remapping:
        # unknown values (9999) are remapped to n_cats, so the valid range is
        # [0, n_cats], which requires n_cats + 1 classes.
        categories = tuple(c + 1 for c in self.data_wrangler.num_cats)
        # A dataset without continuous features has no X_cont to inspect.
        self.n_num_feat = X_cont_train.shape[1] if X_cont_train is not None else 0

        # Rebuild model architecture
        # Convert OmegaConf to dict for model_params (needed for attribute access in MLP)
        model_params_dict = dict(self.config.model.model_params)

        self.diffusion = GaussianMultinomialDiffusion(
            num_numerical_features=self.n_num_feat,
            num_categorical_features=categories,
            model_params=type('obj', (object,), model_params_dict),  # Convert dict to object for attribute access
            num_timesteps=self.config.model.num_timesteps,
            dim_emb=self.config.model.dim_emb,
            gaussian_loss_type=self.config.model.gaussian_loss_type,
            scheduler=self.config.model.scheduler,
            strategy=self.strategy,
            breaks=self.breaks,
            device=self.device,
        )

        # Prefer the EMA checkpoint, fall back to the regular one
        ema_checkpoint_path = os.path.join(self.ckpt_restore_dir, "ema_model.pt")
        checkpoint_path = os.path.join(self.ckpt_restore_dir, "model.pt")

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted locations (consider weights_only=True where the
        # installed torch version supports it).
        if os.path.exists(ema_checkpoint_path):
            state_dict = torch.load(ema_checkpoint_path, map_location=self.device)
            logging.info(f"Loading EMA model from {ema_checkpoint_path}")
            # The EMA checkpoint only holds the denoiser's weights
            self.diffusion.denoise_fn.load_state_dict(state_dict)
        elif os.path.exists(checkpoint_path):
            state_dict = torch.load(checkpoint_path, map_location=self.device)
            logging.info(f"Loading model from {checkpoint_path}")
            self.diffusion.load_state_dict(state_dict)
        else:
            raise FileNotFoundError(f"No model checkpoint found in {self.ckpt_restore_dir}")

        self.diffusion.to(self.device)
        self.diffusion.eval()
            
    def sample_tabular_data(self, num_samples, **kwargs):
        """
        Generate synthetic samples.
        
        Args:
            num_samples: Number of samples to generate
            **kwargs: Additional arguments (seed, verbose, etc.)
        
        Returns:
            X_cat_gen: Generated categorical features (numpy array)
            X_cont_gen: Generated continuous features (numpy array)
            y_gen: Generated target variable (None for unconditional generation)
        """
        seed = kwargs.get("seed", None)
        verbose = kwargs.get("verbose", False)
        batch_size = min(self.config.model.get('sample_batch_size', 2000), num_samples)
        
        if seed is not None:
            set_seeds(seed, cuda_deterministic=True)
        
        with torch.no_grad():
            syn_data = self.diffusion.sample_all(
                num_samples,
                cond=None,
                max_batch_size=batch_size,
                ddim=False
            )
        
        # Check for NaN rows
        num_all_zero_row = (syn_data.sum(dim=1) == 0).sum().item()
        if num_all_zero_row:
            print(f"Warning: Generated samples contain {num_all_zero_row} NaN instances (all-zero rows)")
    
        # Split into categorical and continuous features
        # NOTE: TabDDPM uses [num_features, cat_features] order
        X_cont_gen = syn_data[:, :self.n_num_feat].cpu().numpy()
        X_cat_gen = syn_data[:, self.n_num_feat:].cpu().numpy()

        # Postprocess using data_wrangler
        X_cat_gen, X_cont_gen, y_gen = self.data_wrangler.postprocess_gen_data(
            X_cat_gen.astype(np.int64),
            X_cont_gen.astype(np.float64),
            None,  # y_gen is None for unconditional generation
        )

        return X_cat_gen, X_cont_gen, y_gen

