import logging
import os
import pickle
import time
import glob

import numpy as np
import torch
from copy import deepcopy
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau

from .experiment import Experiment
from .utils import get_total_trainable_params, set_seeds
from .utils import cycle

# TabDiff modules - located in experiments/models/tabdiff/
from .models.tabdiff.main_modules import UniModMLP, Model
from .models.tabdiff.unified_ctime_diffusion import UnifiedCtimeDiffusion
from .models.tabdiff.utils import update_ema


class Experiment_TabDiff(Experiment):
    """
    TabDiff experiment implementation adapted from TabDiff paper.
    Uses continuous-time masked diffusion for mixed-type tabular data.

    The diffusion model consumes a single tensor laid out as
    ``[numerical features, categorical features]``; the split point between
    the two parts is ``self.d_numerical`` (set by ``train``/``load_model``).
    """

    def __init__(self, data_path, sample_path, config, exp_path, dataset, device, preproc, strategy=0, breaks=30000, run_name='', beta='0p7', use_log='switch'):
        """
        Args:
            strategy: Forwarded to ``UnifiedCtimeDiffusion`` as ``strategy``
                (used together with the per-batch masks for strategy-based masking).
            breaks: Forwarded to ``UnifiedCtimeDiffusion`` as ``breaks``.

        All remaining arguments are forwarded unchanged to ``Experiment.__init__``.
        """
        super().__init__(data_path, sample_path, config, exp_path, dataset, device, preproc, run_name=run_name, beta=beta, use_log=use_log)
        self.strategy = strategy
        self.breaks = breaks

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #
    def _build_diffusion(self):
        """
        Build the UniModMLP denoiser and wrap it in ``UnifiedCtimeDiffusion``.

        Reads feature dimensions from ``self.train_loader`` (which must already
        be set) and ``self.data_wrangler.num_cats``, and sets ``self.d_numerical``.
        Shared by ``train`` and ``load_model`` so that a reloaded model is
        constructed with exactly the same arguments (including ``edm_params``,
        ``strategy`` and ``breaks``) as the trained one.

        Returns:
            The constructed ``UnifiedCtimeDiffusion`` instance (not yet moved
            to ``self.device``; its denoiser already is).
        """
        categories = np.array(self.data_wrangler.num_cats)
        self.d_numerical = self.train_loader.X_cont.shape[1]

        # Override config dimensions with the ones observed in the data.
        unimodmlp_params = dict(self.config.model.unimodmlp_params)
        unimodmlp_params['d_numerical'] = self.d_numerical
        # Add one class per categorical feature for the mask category.
        unimodmlp_params['categories'] = (categories + 1).tolist()

        backbone = UniModMLP(**unimodmlp_params)
        model = Model(backbone, **self.config.model.diffusion_params.edm_params)
        model.to(self.device)

        diffusion_params = dict(self.config.model.diffusion_params)
        # edm_params is passed explicitly (needed for the EDM loss) rather than
        # through **diffusion_params.
        edm_params = dict(diffusion_params.pop('edm_params', {}))

        return UnifiedCtimeDiffusion(
            num_classes=categories,
            num_numerical_features=self.d_numerical,
            denoise_fn=model,
            y_only_model=None,
            edm_params=edm_params,
            strategy=self.strategy,
            breaks=self.breaks,
            **diffusion_params,
            device=self.device,
        )

    def _prepare_batch(self, batch_tuple):
        """
        Move a loader batch to ``self.device`` and assemble the model input.

        Args:
            batch_tuple: Loader batch whose first four entries are
                ``(x_cat, x_cont, m_cat, m_cont)``; any of them may be ``None``.
                Further entries (e.g. ``y``) are ignored.

        Returns:
            ``(x, m_cat, m_cont)``. ``x`` is ``x_cont`` and ``x_cat`` concatenated
            along dim 1 (numerical first), or whichever of the two is present,
            or ``None`` when both are missing.
        """
        x_cat, x_cont, m_cat, m_cont = (
            t.to(self.device) if t is not None else None for t in batch_tuple[:4]
        )
        if x_cont is not None and x_cat is not None:
            x = torch.cat((x_cont, x_cat), dim=1)
        elif x_cont is not None:
            x = x_cont
        else:
            x = x_cat  # may be None; callers decide how to handle that
        return x, m_cat, m_cont

    def _validation_loss(self, max_batches=10):
        """
        Return the mean ``d_loss + c_loss`` over at most ``max_batches`` batches
        of ``self.val_loader``, or ``None`` if no usable batch was seen.

        Runs under ``torch.no_grad()`` with the model in eval mode, and always
        restores train mode afterwards.
        """
        val_losses = []
        self.diffusion.eval()
        try:
            with torch.no_grad():
                for val_batch_tuple in self.val_loader:
                    x_val, m_cat_val, m_cont_val = self._prepare_batch(val_batch_tuple)
                    if x_val is None:
                        continue
                    # Pass masks for strategy-based masking (same as training).
                    d_loss_val, c_loss_val = self.diffusion.mixed_loss(x_val, m_cat=m_cat_val, m_cont=m_cont_val)
                    val_losses.append((d_loss_val + c_loss_val).item())
                    # Only evaluate a few batches for efficiency.
                    if len(val_losses) >= max_batches:
                        break
        finally:
            self.diffusion.train()
        return float(np.mean(val_losses)) if val_losses else None

    def _snapshot_ema_state(self):
        """
        Return deep copies of the current EMA state dicts, keyed like checkpoints.

        ``state_dict()`` tensors share storage with the EMA parameters, which
        ``update_ema`` keeps modifying every training step, so a deep copy is
        needed to freeze the weights at snapshot time.
        """
        return {
            'denoise_fn': deepcopy(self.ema_model.state_dict()),
            'num_schedule': deepcopy(self.ema_num_schedule.state_dict()),
            'cat_schedule': deepcopy(self.ema_cat_schedule.state_dict()),
        }

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def train(self, **kwargs):
        """
        Train the TabDiff model.

        Keyword Args:
            save_model (bool): If True, show the progress bar, save the training
                time and write checkpoints to ``self.ckpt_restore_dir``.

        When ``config.model.use_early_stopping`` is enabled, a validation loss is
        computed every ``validation_interval`` steps. The EMA weights of the best
        validation round are snapshotted (and checkpointed when ``save_model``)
        without touching the live training weights, and training stops after
        ``max_patience`` rounds without improvement.

        On return, ``self.diffusion`` (and the EMA modules) hold EMA weights: the
        best-validation ones when such a round was recorded, otherwise the final ones.
        """
        save_model = kwargs.get("save_model", False)
        set_seeds(self.seed, cuda_deterministic=True)

        # Get train loader (FastTensorDataLoader - same as CDTD, no multiprocessing overhead)
        self.train_loader = self.data_wrangler.get_train_loader(
            self.config.model.batch_size
        )

        # Get validation loader (for validation and early stopping)
        use_early_stopping = self.config.model.get('use_early_stopping', False)
        if use_early_stopping:
            self.val_loader = self.data_wrangler.get_train_loader(
                self.config.model.batch_size, partition='val'
            )

        self.diffusion = self._build_diffusion()

        num_params = get_total_trainable_params(self.diffusion)
        print(f"Total parameters = {int(num_params):,}")

        self.diffusion.to(self.device)
        self.diffusion.train()

        # Setup EMA copies (detached: they are only updated through update_ema).
        self.ema_model = deepcopy(self.diffusion._denoise_fn)
        for param in self.ema_model.parameters():
            param.detach_()
        self.ema_num_schedule = deepcopy(self.diffusion.num_schedule)
        for param in self.ema_num_schedule.parameters():
            param.detach_()
        self.ema_cat_schedule = deepcopy(self.diffusion.cat_schedule)
        for param in self.ema_cat_schedule.parameters():
            param.detach_()

        # Setup optimizer and scheduler
        self.optimizer = torch.optim.AdamW(
            self.diffusion.parameters(),
            lr=self.config.model.lr,
            weight_decay=self.config.model.get('weight_decay', 0.0),
            betas=self.config.model.get('betas', [0.9, 0.999])
        )

        lr_scheduler_type = self.config.model.get('lr_scheduler', 'reduce_lr_on_plateau')
        if lr_scheduler_type == 'reduce_lr_on_plateau':
            self.scheduler = ReduceLROnPlateau(
                self.optimizer,
                mode='min',
                factor=self.config.model.get('lr_factor', 0.9),
                patience=self.config.model.get('lr_patience', 100),
                min_lr=self.config.model.get('min_lr', 1e-6)
            )
        else:
            self.scheduler = None

        # Training parameters
        train_steps = self.config.model.train_steps
        log_steps = self.config.model.get('log_steps', 100)
        ema_decay = self.config.model.get('ema_decay', 0.997)

        # Validation and early stopping parameters
        validation_interval = None
        max_patience = None
        patience = 0
        best_val_loss = float('inf')
        best_step = 0
        best_ema_state = None  # EMA state dicts from the best validation round

        if use_early_stopping:
            # Handle None/null from config - use default if unset or explicitly None
            validation_interval = self.config.model.get('validation_interval', None)
            if validation_interval is None:
                validation_interval = train_steps // 100  # Default: validate every 1% of steps
            # Clamp to >= 1: train_steps < 100 would otherwise give 0 and the
            # modulo check in the loop would raise ZeroDivisionError.
            validation_interval = max(1, validation_interval)

            max_patience = self.config.model.get('max_patience', None)
            if max_patience is None:
                # Default patience: 15% of total validation rounds
                total_validation_rounds = train_steps // validation_interval
                max_patience = max(1, int(total_validation_rounds * 0.15))
            logging.warning(f'Early stopping enabled: validation_interval={validation_interval}, max_patience={max_patience}')

        # Training loop
        step = 0
        training_start_time = time.time()

        # Cycle through the dataloader indefinitely for training (same as CDTD)
        train_iter = cycle(self.train_loader)

        with tqdm(total=train_steps, desc="Training", disable=not save_model) as pbar:
            while step < train_steps:
                self.optimizer.zero_grad()

                # TabDiff expects [num_features, cat_features] format
                x, m_cat_batch, m_cont_batch = self._prepare_batch(next(train_iter))
                if x is None:
                    raise ValueError("Both x_cont and x_cat are None")

                # Compute loss (pass masks for strategy-based masking)
                d_loss, c_loss = self.diffusion.mixed_loss(x, m_cat=m_cat_batch, m_cont=m_cont_batch)
                total_loss = d_loss + c_loss
                total_loss.backward()

                # Gradient clipping (optional)
                if self.config.model.get('clip_grad', False):
                    torch.nn.utils.clip_grad_norm_(
                        self.diffusion.parameters(),
                        self.config.model.get('max_grad_norm', 1.0)
                    )

                self.optimizer.step()

                # Update EMA models
                update_ema(self.ema_model.parameters(), self.diffusion._denoise_fn.parameters(), rate=ema_decay)
                update_ema(self.ema_num_schedule.parameters(), self.diffusion.num_schedule.parameters(), rate=ema_decay)
                update_ema(self.ema_cat_schedule.parameters(), self.diffusion.cat_schedule.parameters(), rate=ema_decay)

                # Logging (the plateau scheduler is driven by the current batch loss)
                if step % log_steps == 0:
                    pbar.set_description(
                        f"Step {step}/{train_steps} | Loss: {total_loss.item():.4f} "
                        f"(D: {d_loss.item():.4f}, C: {c_loss.item():.4f})"
                    )
                    if self.scheduler is not None:
                        self.scheduler.step(total_loss.item())

                # Validation and early stopping
                if use_early_stopping and (step + 1) % validation_interval == 0:
                    current_val_loss = self._validation_loss()
                    if current_val_loss is not None:
                        pbar.set_description(
                            f"Step {step}/{train_steps} | Loss: {total_loss.item():.4f} | "
                            f"Val: {current_val_loss:.4f} | Patience: {patience}/{max_patience}"
                        )

                        if current_val_loss < best_val_loss:
                            patience = 0
                            best_val_loss = current_val_loss
                            best_step = step + 1
                            # Snapshot the EMA weights rather than copying them into
                            # the live model: training must continue from its own
                            # weights, not be reset to the EMA average.
                            best_ema_state = self._snapshot_ema_state()
                            if save_model:
                                self.save_model()
                            logging.warning(f'[Step {best_step}] New best validation loss: {best_val_loss:.4f}')
                        else:
                            patience += 1

                        # Early stopping
                        if patience >= max_patience:
                            logging.warning(f'Early stopping at step {step + 1} (patience {patience}/{max_patience})')
                            logging.warning(f'Best validation loss: {best_val_loss:.4f} at step {best_step}')
                            break

                step += 1
                pbar.update(1)

        # Leave both the EMA modules and the model holding the weights we keep:
        # the best-validation EMA if one was recorded, otherwise the final EMA.
        if best_ema_state is not None:
            self.ema_model.load_state_dict(best_ema_state['denoise_fn'])
            self.ema_num_schedule.load_state_dict(best_ema_state['num_schedule'])
            self.ema_cat_schedule.load_state_dict(best_ema_state['cat_schedule'])
        self._copy_ema_to_model()

        training_duration = time.time() - training_start_time
        if save_model:
            self.save_train_time(training_duration)
            if best_ema_state is None:
                # No checkpoint was written during training (early stopping off,
                # or no validation round ever improved) - save the final EMA now.
                self.save_model()
            else:
                logging.warning(f'Best model (validation loss: {best_val_loss:.4f}) already saved at step {best_step}')

    def _copy_ema_to_model(self):
        """Copy EMA weights to the main model."""
        self.diffusion._denoise_fn.load_state_dict(self.ema_model.state_dict())
        self.diffusion.num_schedule.load_state_dict(self.ema_num_schedule.state_dict())
        self.diffusion.cat_schedule.load_state_dict(self.ema_cat_schedule.state_dict())

    def save_model(self):
        """
        Save model checkpoints to ``self.ckpt_restore_dir``.

        Writes ``model.pt`` (current weights of ``self.diffusion``) and
        ``ema_model.pt`` (current EMA weights), each a dict with keys
        ``'denoise_fn'``, ``'num_schedule'`` and ``'cat_schedule'``.
        """
        os.makedirs(self.ckpt_restore_dir, exist_ok=True)

        # Save main model
        state_dicts = {
            'denoise_fn': self.diffusion._denoise_fn.state_dict(),
            'num_schedule': self.diffusion.num_schedule.state_dict(),
            'cat_schedule': self.diffusion.cat_schedule.state_dict(),
        }
        torch.save(state_dicts, os.path.join(self.ckpt_restore_dir, "model.pt"))

        # Save EMA model
        ema_state_dicts = {
            'denoise_fn': self.ema_model.state_dict(),
            'num_schedule': self.ema_num_schedule.state_dict(),
            'cat_schedule': self.ema_cat_schedule.state_dict(),
        }
        torch.save(ema_state_dicts, os.path.join(self.ckpt_restore_dir, "ema_model.pt"))

    def load_model(self):
        """
        Rebuild the model (same construction as ``train``) and load a checkpoint.

        Checkpoint preference order in ``self.ckpt_restore_dir``:
        ``ema_model.pt``, then ``model.pt``, then the first ``best_ema_model_*``
        file in sorted order. The model is left on ``self.device`` in eval mode.

        Raises:
            FileNotFoundError: If no checkpoint is found.
        """
        set_seeds(self.seed)

        # The train loader provides the feature dimensions used to rebuild the model.
        self.train_loader = self.data_wrangler.get_train_loader(
            self.config.model.batch_size
        )
        self.diffusion = self._build_diffusion()

        ema_checkpoint_path = os.path.join(self.ckpt_restore_dir, "ema_model.pt")
        checkpoint_path = os.path.join(self.ckpt_restore_dir, "model.pt")

        if os.path.exists(ema_checkpoint_path):
            state_dicts = torch.load(ema_checkpoint_path, map_location=self.device)
            print(f"Loading EMA model from {ema_checkpoint_path}")
        elif os.path.exists(checkpoint_path):
            state_dicts = torch.load(checkpoint_path, map_location=self.device)
            print(f"Loading model from {checkpoint_path}")
        else:
            # Try to find best_ema_model_* pattern (if using trainer); sorted so
            # the choice does not depend on filesystem listing order.
            ema_files = sorted(glob.glob(os.path.join(self.ckpt_restore_dir, "best_ema_model_*")))
            if ema_files:
                state_dicts = torch.load(ema_files[0], map_location=self.device)
                print(f"Loading best EMA model from {ema_files[0]}")
            else:
                raise FileNotFoundError(f"No model checkpoint found in {self.ckpt_restore_dir}")

        self.diffusion._denoise_fn.load_state_dict(state_dicts['denoise_fn'])
        # Schedule weights are optional in older/third-party checkpoints.
        if 'num_schedule' in state_dicts:
            self.diffusion.num_schedule.load_state_dict(state_dicts['num_schedule'])
        if 'cat_schedule' in state_dicts:
            self.diffusion.cat_schedule.load_state_dict(state_dicts['cat_schedule'])

        self.diffusion.to(self.device)
        self.diffusion.eval()

    def sample_tabular_data(self, num_samples, **kwargs):
        """
        Generate synthetic samples.

        Args:
            num_samples: Number of samples to generate
            **kwargs: Additional arguments (seed, verbose, etc.)

        Returns:
            X_cat_gen: Generated categorical features (numpy array)
            X_cont_gen: Generated continuous features (numpy array)
            y_gen: Generated target variable (None for unconditional generation)
        """
        seed = kwargs.get("seed", None)
        verbose = kwargs.get("verbose", False)
        batch_size = min(self.config.model.get('sample_batch_size', 4096), num_samples)

        if seed is not None:
            set_seeds(seed, cuda_deterministic=True)

        with torch.no_grad():
            syn_data = self.diffusion.sample_all(
                num_samples,
                batch_size,
                keep_nan_samples=True,
                verbose=verbose
            )

        # Check for NaN rows (all-zero rows indicate NaN instances)
        num_all_zero_row = (syn_data.sum(dim=1) == 0).sum().item()
        if num_all_zero_row:
            print(f"Warning: Generated samples contain {num_all_zero_row} NaN instances (all-zero rows)")

        # Split into continuous (first d_numerical columns) and categorical features
        X_cont_gen = syn_data[:, :self.d_numerical].cpu().numpy()
        X_cat_gen = syn_data[:, self.d_numerical:].cpu().numpy()

        # Postprocess using data_wrangler
        X_cat_gen, X_cont_gen, y_gen = self.data_wrangler.postprocess_gen_data(
            X_cat_gen.astype(np.int64),
            X_cont_gen.astype(np.float64),
            None,  # y_gen is None for unconditional generation
        )

        return X_cat_gen, X_cont_gen, y_gen

