import torch
from dataset.dataset_local import DataTransformer, ParquetDataset, process_datasets, DataLoader, collate_fn
# Import the LMDB dataset loader
from dataset.dataset_lmdb import Dataset_lmdb
from vectorizer.columnVectorizer import *
from vectorizer.TableVectorizer import TableVectorizer
from latent.vae.perceive.perceive import LatentAutoEncoder, MultiModalLatentAutoEncoder, SimpleAutoEncoder, DisentangledMultiModalLatentAutoEncoder
from latent.vae.perceive.trainer import VAETrainer
import os
import tempfile
import shutil
import json
import pandas as pd
import math


#import wandb


# Module-wide debug switch read by TableLatentModel. When True:
#   * train() builds the optimizer from the autoencoder parameters only
#     (the table vectorizer's parameters are left out);
#   * table_to_latent() returns the vectorized table tensor in place of the
#     encoder output;
#   * latent_to_table() hands the given tensor straight to inverse_vectorize
#     without running the decoder.
DEBUGGING = False



class TableLatentModel:
    def __init__(
        self,
        d_lm=1024,
        d_latent_len=16,
        d_latent_width=64,
        max_n_cols=100,
        output_dim=1024,
        encoder_depth=2,
        encoder_dim_head=64,
        encoder_ff_mult=4,
        fuse_option="flatten",
        decoder_depth=2,
        decoder_num_heads=8,
        vectorizer_mapping=None,
        vectorizer_configs=None,
        numerical_transformation="ple",
        device="cuda",
        autoencoder_type="unimodal",
        combination_method="mopoe",
        init_wandb=False
    ):
        """Build the table vectorizer and the latent autoencoder.

        Args:
            d_lm: stored as ``self.d_lm``; used as encoder ``dim`` and decoder ``lm_emb``.
            d_latent_len: encoder ``num_latents``.
            d_latent_width: encoder ``dim_latent`` and decoder ``aggregated_dim``.
            max_n_cols: encoder ``max_seq_len``.
            output_dim: ``output_dim`` of both the TableVectorizer and the decoder.
            encoder_depth: encoder ``depth``.
            encoder_dim_head: encoder ``dim_head``.
            encoder_ff_mult: encoder ``ff_mult``.
            fuse_option: encoder ``fuse_option``.
            decoder_depth: decoder ``depth``.
            decoder_num_heads: decoder ``num_heads``.
            vectorizer_mapping: column-type -> vectorizer class mapping handed to
                TableVectorizer. When None, a default mapping is built in which
                numerical columns use QuantileEmbeddingVectorizer if
                ``numerical_transformation == "quantile"`` and PLEVectorizer otherwise.
            vectorizer_configs: per-column-type config dicts handed to
                TableVectorizer. When None, built-in defaults are used.
            numerical_transformation: see ``vectorizer_mapping``; also stored on the
                instance and read by prepare_data().
            device: device both sub-models are moved to. A falsy value selects
                "cuda" when available and "cpu" otherwise.
            autoencoder_type: case-insensitive; one of 'unimodal', 'multimodal',
                'disentangled' or 'ae'.
            combination_method: forwarded to the 'multimodal' and 'disentangled'
                autoencoders only.
            init_wandb: not used by this constructor.

        Raises:
            ValueError: if ``autoencoder_type`` is not one of the supported names.
        """
        # Validate up front so a typo fails immediately, before any vectorizer
        # or model is constructed.
        autoencoder_kind = autoencoder_type.lower()
        supported_kinds = ('unimodal', 'multimodal', 'disentangled', 'ae')
        if autoencoder_kind not in supported_kinds:
            raise ValueError(
                "autoencoder_type must be one of 'unimodal', 'multimodal', "
                f"'disentangled' or 'ae', got {autoencoder_type!r}"
            )

        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.d_lm = d_lm
        self.is_fitted = False

        self.numerical_transformation = numerical_transformation

        # Default vectorizer mapping if not provided
        if vectorizer_mapping is None:
            vectorizer_mapping = {
                "numerical": QuantileEmbeddingVectorizer if numerical_transformation == "quantile" else PLEVectorizer,
                "categorical": CategoricalVectorizer,
                "text": TextVectorizer,
                "datetime": DateTimeVectorizer,
            }

        # Default vectorizer configs if not provided
        if vectorizer_configs is None:
            vectorizer_configs = {
                "numerical": {
                    "output_dim": 1024,
                    "input_dim": 32
                },
                "categorical": {
                    "model_name": "Alibaba-NLP/gte-large-en-v1.5",
                    "projection_dim": 128
                },
                "text": {},  # Add default text configs if needed
                "datetime": {}  # Add default datetime configs if needed
            }

        # Initialize TableVectorizer with configs
        self.table_vectorizer = TableVectorizer(
            transformer_mapping=vectorizer_mapping,
            output_dim=output_dim,
            transformer_configs=vectorizer_configs
        )

        # Model hyper-parameters; kept on the instance because get_checkpoint()
        # stores them alongside the weights.
        self.encoder_params = {
            'dim': d_lm,
            'dim_latent': d_latent_width,
            'depth': encoder_depth,
            'dim_head': encoder_dim_head,
            'num_latents': d_latent_len,
            'max_seq_len': max_n_cols,
            'ff_mult': encoder_ff_mult,
            'fuse_option': fuse_option,
        }
        self.decoder_params = {
            'lm_emb': d_lm,
            'aggregated_dim': d_latent_width,
            "output_dim": output_dim,
            "depth": decoder_depth,
            'num_heads': decoder_num_heads,
        }

        # Initialize autoencoder based on type
        if autoencoder_kind == 'unimodal':
            self.autoencoder = LatentAutoEncoder(self.encoder_params, self.decoder_params)
        elif autoencoder_kind == 'multimodal':
            self.autoencoder = MultiModalLatentAutoEncoder(
                self.encoder_params, self.decoder_params, num_modalities=4,
                combination_method=combination_method
            )
        elif autoencoder_kind == 'disentangled':
            self.autoencoder = DisentangledMultiModalLatentAutoEncoder(
                self.encoder_params, self.decoder_params, num_modalities=4,
                combination_method=combination_method
            )
        else:  # 'ae' -- the only remaining value after the validation above
            self.autoencoder = SimpleAutoEncoder(self.encoder_params, self.decoder_params)

        # Move models to device
        self.table_vectorizer = self.table_vectorizer.to(self.device)
        self.autoencoder = self.autoencoder.to(self.device)

    def prepare_data(
        self,
        df_folder=None,
        config_folder=None,
        output_folder=None,
        batch_size=100,
        shuffle=True,
        split_ratio=None,  # tuple of (train, val, test) ratios
        use_lmdb=False,
        lmdb_path=None,
        csv_log_path=None,
        num_workers=8
    ):
        """Build the dataset and wrap it in one or three DataLoaders.

        The dataset is either a ``Dataset_lmdb`` (``use_lmdb=True``) or a
        ``ParquetDataset`` built from the output of ``process_datasets``.
        Every DataLoader is created with ``batch_size=1``; the ``batch_size``
        argument is only forwarded to ``process_datasets``.

        Args:
            df_folder: folder with parquet files (Parquet path only).
            config_folder: folder with config files (Parquet path only).
            output_folder: folder receiving processed files (Parquet path only).
            batch_size: batch size forwarded to ``process_datasets``.
            shuffle: shuffle flag for the single loader / the training loader.
            split_ratio: optional (train, val, test) ratios that must sum to 1.
            use_lmdb: read from an LMDB database instead of Parquet files.
            lmdb_path: LMDB database path (required when ``use_lmdb`` is True).
            csv_log_path: CSV log path (required when ``use_lmdb`` is True).
            num_workers: worker count given to every DataLoader.

        Returns:
            A single DataLoader when ``split_ratio`` is None, otherwise the
            tuple ``(train_loader, val_loader, test_loader)``.

        Raises:
            ValueError: when arguments required by the chosen source are
                missing, or when the split ratios do not sum to 1.
        """
        if use_lmdb:
            lmdb_args_missing = lmdb_path is None or csv_log_path is None
            if lmdb_args_missing:
                raise ValueError("lmdb_path and csv_log_path must be provided when use_lmdb=True")
            dataset = Dataset_lmdb(lmdb_path, csv_log_path)
        else:
            parquet_args = (df_folder, config_folder, output_folder)
            if any(folder is None for folder in parquet_args):
                raise ValueError("df_folder, config_folder, and output_folder must be provided when use_lmdb=False")

            # Pre-process the raw tables, then index the produced batch files.
            use_quantile = self.numerical_transformation == "quantile"
            batch_file_to_config = process_datasets(
                df_folder=df_folder,
                config_folder=config_folder,
                new_folder=output_folder,
                batch_size=batch_size,
                data_transformer_class=DataTransformer,
                quantile_transform=use_quantile
            )
            dataset = ParquetDataset(batch_file_to_config)

        def make_loader(data, shuffle_flag):
            # All loaders share every setting except the data and the shuffle flag.
            return DataLoader(
                data,
                batch_size=1,
                collate_fn=collate_fn,
                shuffle=shuffle_flag,
                num_workers=num_workers
            )

        if split_ratio is None:
            return make_loader(dataset, shuffle)

        # Unpacking also enforces that exactly three ratios were supplied.
        train_ratio, val_ratio, _test_ratio = split_ratio
        ratios_sum_to_one = abs(sum(split_ratio) - 1.0) < 1e-6
        if not ratios_sum_to_one:
            raise ValueError("Split ratios must sum to 1")

        n_total = len(dataset)
        n_train = int(train_ratio * n_total)
        n_val = int(val_ratio * n_total)
        # The test split absorbs whatever the two truncations left over.
        n_test = n_total - n_train - n_val

        # Fixed seed so the split is reproducible across calls.
        seeded = torch.Generator()
        seeded.manual_seed(42)
        subsets = torch.utils.data.random_split(
            dataset,
            [n_train, n_val, n_test],
            generator=seeded
        )
        train_subset, val_subset, test_subset = subsets

        train_loader = make_loader(train_subset, shuffle)
        # Validation and test sets are never shuffled.
        val_loader = make_loader(val_subset, False)
        test_loader = make_loader(test_subset, False)
        return train_loader, val_loader, test_loader

    def train(
        self,
        train_dataloader,
        val_dataloader=None,
        num_epochs=None,
        max_steps=None,
        learning_rate=1e-4,
        weight_decay=0.01,
        scheduler_type='cosine_warmup',
        scheduler_patience=3,
        scheduler_factor=0.1,
        warmup_percentage=0.08,
        min_lr=1e-5,
        init_beta=0.0,
        max_beta=1.0,
        max_beta_steps=1000,
        vectorizer_warmup_epochs=2,
        checkpoint_path="best_model.pth",
        save_interval=50,
        validation_interval=None,
        distributed=False,
        rank=0,
        world_size=1,
        early_stop_patience=20,
        interval_type="epoch",
        scheduler_interval="epoch",
        resume_from_checkpoint=False,
        skip_iters=0,
        total_scheduler_steps=None,
        gradient_accumulation_steps=1,
        mask_ratio=0.0,
        contrastive_weight=0.0,
        contrastive_temperature=0.07,
        base_contrastive_temperature=0.07,
        contrastive_dim=128,
        load_scheduler_state=True,
    ):
        """Create optimizer, LR scheduler and VAETrainer, then run training.

        Args:
            train_dataloader: training loader; its length is used to derive the
                step horizon when only ``num_epochs`` is given.
            val_dataloader: optional validation loader, forwarded to the trainer.
            num_epochs: forwarded to ``VAETrainer.train``.
            max_steps: forwarded to ``VAETrainer.train``; takes precedence over
                ``num_epochs`` when deriving the scheduler horizon.
            learning_rate: AdamW learning rate.
            weight_decay: AdamW weight decay.
            scheduler_type: ``'cosine_warmup'`` selects linear warmup followed by
                cosine decay; any other value selects ``ReduceLROnPlateau``.
            scheduler_patience: ``ReduceLROnPlateau`` patience.
            scheduler_factor: ``ReduceLROnPlateau`` factor.
            warmup_percentage: fraction of the scheduler horizon used for warmup.
                When no horizon is known, 100 warmup steps are used.
            min_lr: lower bound of the learning rate for the cosine schedule.
            total_scheduler_steps: explicit scheduler horizon; overrides the one
                derived from ``max_steps`` / ``num_epochs``.
            checkpoint_path: forwarded to ``VAETrainer.train``.
            resume_from_checkpoint: forwarded to ``VAETrainer.train``.
            skip_iters: forwarded to ``VAETrainer.train``.

        All remaining keyword arguments (``init_beta``, ``max_beta``,
        ``max_beta_steps``, ``vectorizer_warmup_epochs``, ``save_interval``,
        ``validation_interval``, ``distributed``, ``rank``, ``world_size``,
        ``early_stop_patience``, ``interval_type``, ``scheduler_interval``,
        ``gradient_accumulation_steps``, ``mask_ratio``, ``contrastive_weight``,
        ``contrastive_temperature``, ``base_contrastive_temperature``,
        ``contrastive_dim``, ``load_scheduler_state``) are passed unchanged to
        the ``VAETrainer`` constructor.

        Returns:
            Whatever ``VAETrainer.train`` returns (stored as ``loss_history``).
        """
        # In debugging mode only the autoencoder is optimized.
        if DEBUGGING:
            combined_params = list(self.autoencoder.parameters())
        else:
            combined_params = list(self.autoencoder.parameters()) + list(self.table_vectorizer.parameters())

        optimizer = torch.optim.AdamW(
            combined_params,
            lr=learning_rate,
            weight_decay=weight_decay
        )

        # Derive the step horizon from the epoch count when no explicit
        # max_steps is given.
        total_steps = max_steps
        if num_epochs is not None and max_steps is None:
            total_steps = len(train_dataloader) * num_epochs

        if scheduler_type == 'cosine_warmup':
            # An explicit scheduler horizon wins over the derived one.
            scheduler_total_steps = total_scheduler_steps if total_scheduler_steps is not None else total_steps

            # Configure warmup steps based on percentage
            warmup_steps = int(warmup_percentage * scheduler_total_steps) if scheduler_total_steps else 100

            # LambdaLR multiplies the base LR by the returned factor, so the
            # floor is expressed relative to the base LR. Guard against a zero
            # base LR, which would otherwise divide by zero.
            lr_floor = min_lr / learning_rate if learning_rate else 0.0

            def lr_lambda(current_step):
                if current_step < warmup_steps:
                    # Linear warmup
                    return float(current_step) / float(max(1, warmup_steps))
                if not scheduler_total_steps:
                    # No horizon known: the cosine phase cannot be computed, so
                    # hold the base LR instead of failing on `None - int`.
                    return 1.0
                decay_steps = max(1, scheduler_total_steps - warmup_steps)
                # Clamp so the LR stays at its floor past the horizon instead
                # of climbing back up the next cosine period.
                progress = min(1.0, float(current_step - warmup_steps) / float(decay_steps))
                return max(lr_floor, 0.5 * (1.0 + math.cos(math.pi * progress)))

            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        else:  # 'reduce_on_plateau' (default fallback)
            # NOTE(review): `verbose` is deprecated in recent PyTorch releases;
            # confirm against the pinned torch version before upgrading.
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='min',
                factor=scheduler_factor,
                patience=scheduler_patience,
                verbose=True
            )

        # Initialize trainer with all parameters
        trainer = VAETrainer(
            model=self.autoencoder,
            table_vectorizer=self.table_vectorizer,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            optimizer=optimizer,
            scheduler=scheduler,
            device=self.device,
            distributed=distributed,
            rank=rank,
            world_size=world_size,
            init_beta=init_beta,
            max_beta=max_beta,
            max_beta_steps=max_beta_steps,
            vectorizer_warmup_epochs=vectorizer_warmup_epochs,
            save_interval=save_interval,
            validation_interval=validation_interval,
            early_stop_patience=early_stop_patience,
            interval_type=interval_type,
            scheduler_interval=scheduler_interval,
            gradient_accumulation_steps=gradient_accumulation_steps,
            mask_ratio=mask_ratio,
            contrastive_weight=contrastive_weight,
            contrastive_temperature=contrastive_temperature,
            base_contrastive_temperature=base_contrastive_temperature,
            contrastive_dim=contrastive_dim,
            load_scheduler_state=load_scheduler_state
        )

        # Train model
        loss_history = trainer.train(
            num_epochs=num_epochs,
            max_steps=max_steps,
            checkpoint_path=checkpoint_path,
            resume_from_checkpoint=resume_from_checkpoint,
            skip_iters=skip_iters
        )

        self.is_fitted = True
        return loss_history

    def get_checkpoint(self):
        """Collect everything needed to save this model into one dict.

        Returns:
            dict with the autoencoder state (``'vae_state'``), the vectorizer
            state (``'vectorizer_state'``), the encoder/decoder parameter dicts
            and the ``is_fitted`` flag.
        """
        checkpoint = {}
        # Weights of the two sub-models.
        checkpoint['vae_state'] = self.autoencoder.state_dict()
        checkpoint['vectorizer_state'] = self.table_vectorizer.state_dict()
        # Construction-time hyper-parameters and fit status.
        checkpoint['encoder_params'] = self.encoder_params
        checkpoint['decoder_params'] = self.decoder_params
        checkpoint['is_fitted'] = self.is_fitted
        return checkpoint

    def load_checkpoint(self, checkpoint):
        """Load model parameters"""
        # Check if checkpoint comes from VAETrainer (_get_model_state) or from TableLatentModel.get_checkpoint
        if 'model' in checkpoint and 'vectorizer' in checkpoint:
            # === VAETrainer format ===
            model_state = checkpoint['model']
            vectorizer_state = checkpoint['vectorizer']

            # Strip potential DistributedDataParallel prefixes
            if any(k.startswith('module.') for k in model_state.keys()):
                model_state = {k.replace('module.', ''): v for k, v in model_state.items()}
            if any(k.startswith('module.') for k in vectorizer_state.keys()):
                vectorizer_state = {k.replace('module.', ''): v for k, v in vectorizer_state.items()}

            # Load states – use strict=False for resilience to minor architectural changes
            missing, unexpected = self.autoencoder.load_state_dict(model_state, strict=False)
            if missing or unexpected:
                print("[TableLatentModel] Warning while loading autoencoder state: ",
                      f"missing={missing}, unexpected={unexpected}")

            # Vectorizer may evolve frequently – always load with strict=False
            self.table_vectorizer.load_state_dict(vectorizer_state, strict=False)

            self.is_fitted = True

            # Store minimal trainer-related metadata if present (optional – may be absent in early checkpoints)
            #self.trainer_state = {
            #    'optimizer_state': checkpoint.get('optimizer', None),
            #    'epoch': checkpoint.get('epoch', 0),
            #    'loss': checkpoint.get('loss', None) or checkpoint.get('best_loss', None),
            #    'beta': checkpoint.get('beta', None)
            #}
        else:
            # === Original TableLatentModel format ===
            vae_state = checkpoint['vae_state']
            vectorizer_state = checkpoint['vectorizer_state']

            if any(k.startswith('module.') for k in vae_state.keys()):
                vae_state = {k.replace('module.', ''): v for k, v in vae_state.items()}
            if any(k.startswith('module.') for k in vectorizer_state.keys()):
                vectorizer_state = {k.replace('module.', ''): v for k, v in vectorizer_state.items()}

            self.autoencoder.load_state_dict(vae_state, strict=False)
            self.table_vectorizer.load_state_dict(vectorizer_state, strict=False)

            # Preserve meta-information
            self.encoder_params = checkpoint['encoder_params']
            self.decoder_params = checkpoint['decoder_params']
            self.is_fitted = True
            #self.trainer_state = None

    def table_to_latent(self, df, config, batch_size=32, preprocessed=True, inference=False):
        """Convert table to latent representation

        Args:
            df: pandas DataFrame
            config: table config dictionary
            batch_size: number of rows per batch
            preprocessed: if True, ``df`` is sliced into batches directly; if
                False, ``df`` and ``config`` are written to temporary folders
                and run through prepare_data() first
            inference: forwarded to the encoder as ``deterministic``

        Returns:
            latent_codes: per-batch encoder outputs concatenated along dim 0
                (documented upstream as shape (n_rows, n_latents, d_latent))

        Raises:
            RuntimeError: if the model has not been fitted or loaded.
        """
        if not self.is_fitted:
            raise RuntimeError("Model must be fitted before converting table to latent")

        # Temporary folders created for the non-preprocessed path; removed in
        # the `finally` below so they do not leak when a batch fails.
        tmp_folders = []
        try:
            if preprocessed:
                # Split df into batches and create simple dataloader
                n_samples = len(df)
                dataloader = [
                    {
                        'config': config,
                        'df_batch': df.iloc[start:start + batch_size]
                    }
                    for start in range(0, n_samples, batch_size)
                ]
            else:
                # Register each folder as soon as it exists so a failure in a
                # later mkdtemp still cleans up the earlier ones.
                for _ in range(3):
                    tmp_folders.append(tempfile.mkdtemp())
                tmp_df_folder, tmp_config_folder, tmp_output_folder = tmp_folders

                # Save df and config temporarily
                df.to_parquet(os.path.join(tmp_df_folder, "temp.parquet"))
                with open(os.path.join(tmp_config_folder, "temp.json"), "w") as f:
                    json.dump(config, f)

                # Prepare dataloader
                dataloader = self.prepare_data(
                    tmp_df_folder,
                    tmp_config_folder,
                    tmp_output_folder,
                    batch_size=batch_size,
                    shuffle=False
                )

            all_latents = []

            # Pure inference: disable autograd for vectorization as well as
            # encoding, so no graph is kept alive for tensors that are never
            # back-propagated through.
            with torch.no_grad():
                for batch in dataloader:
                    config_batch = batch['config']
                    df_batch = batch['df_batch']
                    # Actual row count of this batch (the last one may be short).
                    n_rows = len(df_batch)

                    # Vectorize table
                    table_tensor = self.table_vectorizer.vectorize(df_batch, config_batch).to(self.device)
                    n_cols = table_tensor.shape[1]

                    # Encode metadata and column information, then tile the
                    # table-level encodings once per row.
                    meta, column_names, dtypes, dist = self.table_vectorizer.encode_meta(config_batch)
                    meta = meta.unsqueeze(0).repeat(n_rows, 1).to(self.device)
                    column_names = column_names.unsqueeze(0).repeat(n_rows, 1, 1).to(self.device)
                    dtypes = dtypes.unsqueeze(0).repeat(n_rows, 1, 1).to(self.device)
                    dist = dist.unsqueeze(0).repeat(n_rows, 1, 1).to(self.device)

                    # Every column is attended to.
                    attention_mask = torch.ones(n_rows, n_cols, dtype=torch.bool, device=self.device)

                    # Get latent representation
                    latent, _, _ = self.autoencoder.encode(
                        table_tensor, column_names, dtypes, meta, attention_mask, deterministic=inference,
                        dist=dist
                    )
                    if DEBUGGING:
                        latent = table_tensor  # DEBUGGING PURPOSES ONLY, skipping VAE.
                    all_latents.append(latent)

            # Stack all latents
            return torch.cat(all_latents, dim=0)
        finally:
            # Cleanup temporary folders (no-op on the preprocessed path).
            # ignore_errors keeps a cleanup failure from masking the real error.
            for folder in tmp_folders:
                shutil.rmtree(folder, ignore_errors=True)

    def latent_to_table(self, latent_code, config, batch_size=32):
        """Convert latent representation back to table

        Args:
            latent_code: tensor whose first dimension indexes rows (documented
                upstream as shape (n_rows, n_latents, d_latent))
            config: table config dictionary
            batch_size: number of rows decoded per chunk

        Returns:
            reconstructed_df: pandas DataFrame with the per-chunk results
                concatenated and re-indexed

        Raises:
            RuntimeError: if the model has not been fitted or loaded.
        """
        if not self.is_fitted:
            raise RuntimeError("Model must be fitted before converting latent to table")

        # The metadata encodings depend only on `config`, so compute them once
        # instead of once per chunk. Skipped in debugging mode, where the
        # decoder is bypassed entirely.
        if not DEBUGGING:
            with torch.no_grad():
                meta, column_names, dtypes, dist = self.table_vectorizer.encode_meta(config)
            meta = meta.unsqueeze(0).to(self.device)
            column_names = column_names.unsqueeze(0).to(self.device)
            dtypes = dtypes.unsqueeze(0).to(self.device)
            dist = dist.unsqueeze(0).to(self.device)

        all_dfs = []

        # Split latent code into batches
        for batch_latent in torch.split(latent_code, batch_size):
            # Actual row count of this chunk (the last one may be short).
            n_rows = batch_latent.shape[0]

            if DEBUGGING:
                decoded_embedding = batch_latent
            else:
                # Decode latent representation; the latent is moved to the
                # model device so CPU-resident codes can be decoded too.
                with torch.no_grad():
                    decoded_embedding = self.autoencoder.decode(
                        batch_latent.to(self.device),
                        column_names.repeat(n_rows, 1, 1),
                        dtypes.repeat(n_rows, 1, 1),
                        meta.repeat(n_rows, 1),
                        dist=dist.repeat(n_rows, 1, 1)
                    )

            # Convert back to dataframe
            batch_df = self.table_vectorizer.inverse_vectorize(
                decoded_embedding, config, mode="inference"
            )
            all_dfs.append(batch_df)

        # Concatenate all dataframes
        return pd.concat(all_dfs, axis=0, ignore_index=True)