import torch
import torch.nn as nn
import pandas as pd
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
import os
import copy
import torch.nn.functional as F
from typing import List, Dict, Optional, Tuple, Union
import time

from .base import ColumnVectorizer

from latent.vae.perceive.perceiveResampler import PerceiverAutoEncoder

from tqdm import tqdm

# Simple training loop function for testing
def test_training_loop(*, num_epochs=2500, checkpoint_every=500, save_dir="model_checkpoints"):
    """Overfit a Perceiver-backed TextVectorizer on a handful of sentences.

    This is a long-running manual smoke script (it downloads BART and trains
    for ``num_epochs`` full-batch steps), not a fast unit test, despite its
    ``test_`` prefix.

    Args:
        num_epochs (int): number of optimisation steps over the sample texts.
        checkpoint_every (int): write a checkpoint (and the best state so far)
            to ``save_dir`` every this many epochs; ``<= 0`` disables periodic
            checkpoints.
        save_dir (str): directory for checkpoints and ``best_model.pt``.
    """
    print("\n" + "="*50)
    print("RUNNING TRAINING LOOP TEST WITH PERCEIVER")
    print("="*50)

    # Create sample data
    texts = pd.Series([
        "Hello world!",
        "This is a test.",
        "BART is a powerful model.",
        "Let's see how well it works.",
        "Text reconstruction in action."
    ])

    # Initialize vectorizer with Perceiver
    output_dim = 1024
    vectorizer = TextVectorizer(output_dim=output_dim, use_perceiver=True, num_latents=8)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    vectorizer.to(device)

    # Only optimise parameters that are not frozen
    optimizer = torch.optim.Adam([p for p in vectorizer.parameters() if p.requires_grad], lr=5e-5)

    # Halve the LR when the loss plateaus. The scheduler's `verbose=` argument
    # is deprecated in recent PyTorch, so reductions are reported manually below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',           # Reduce LR when loss stops decreasing
        factor=0.5,           # Multiply LR by this factor when reducing
        patience=200,         # Epochs with no improvement before reducing
        min_lr=1e-6,          # Lower bound on the learning rate
    )

    losses = []
    best_loss = float('inf')
    best_model_state = None
    os.makedirs(save_dir, exist_ok=True)
    best_model_path = os.path.join(save_dir, "best_model.pt")

    pbar = tqdm(range(num_epochs), desc="Training")
    for epoch in pbar:
        # Forward pass: encode -> decode latents -> BART reconstruction loss
        embeddings = vectorizer._vectorize(texts, config={})
        train_output = vectorizer._inverse_vectorize(embeddings, config={}, mode='train')
        loss = vectorizer._compute_loss(train_output, texts, config={})

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        current_loss = loss.item()
        losses.append(current_loss)

        # Step the plateau scheduler with a plain float and report reductions
        previous_lr = optimizer.param_groups[0]['lr']
        scheduler.step(current_loss)
        current_lr = optimizer.param_groups[0]['lr']
        if current_lr < previous_lr:
            tqdm.write(f"Epoch {epoch + 1}: reducing learning rate to {current_lr:.2e}")

        if current_loss < best_loss:
            best_loss = current_loss
            # Deep copy: state_dict() tensors alias the live parameters, which
            # the optimizer keeps updating in place.
            best_model_state = copy.deepcopy(vectorizer.state_dict())

        # Periodic checkpoint (every `checkpoint_every` epochs)
        if checkpoint_every > 0 and (epoch + 1) % checkpoint_every == 0:
            checkpoint_path = os.path.join(save_dir, f"checkpoint_epoch_{epoch+1}.pt")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': vectorizer.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': current_loss,
                'best_loss': best_loss,
            }, checkpoint_path)

            # Also save best model so far
            if best_model_state is not None:
                torch.save(best_model_state, best_model_path)

        recent = losses[-10:]
        pbar.set_postfix({
            "loss": f"{current_loss:.4f}",
            "best_loss": f"{best_loss:.4f}",
            "avg_loss_last_10": f"{sum(recent) / len(recent):.4f}",
            "lr": f"{current_lr:.2e}"
        })

    # Restore (and persist) the best model at the end of training, even if the
    # last improvement happened after the final periodic checkpoint.
    print("\nTraining completed. Restoring best model with loss:", best_loss)
    if best_model_state is not None:
        vectorizer.load_state_dict(best_model_state)
        torch.save(best_model_state, best_model_path)

    # Test reconstruction with best model; no gradients are needed here
    print("\nReconstruction with best model:")
    with torch.no_grad():
        embeddings = vectorizer._vectorize(texts, config={})
        reconstructed_texts = vectorizer._inverse_vectorize(embeddings, config={}, mode='inference')

    for original, reconstructed in zip(texts, reconstructed_texts):
        print(f"Original: {original:30} | Reconstructed: {reconstructed}")

class TextVectorizer(ColumnVectorizer):
    """Vectorize free-text columns with BART (optionally compressed by a Perceiver).

    Texts are tokenized and run through BART's encoder; the per-token hidden
    states (width ``d_model``) are padded/truncated to ``MAX_LENGTH`` positions.
    Without the Perceiver they are flattened directly; with it they are
    compressed into ``num_latents`` latent vectors which are then flattened.
    Inverse vectorization decodes the latents (if any) and feeds them to BART's
    decoder, either for a reconstruction loss (``mode='train'``) or for text
    generation (any other mode).
    """

    MODEL_NAME = "facebook/bart-base"
    MAX_LENGTH = 64

    def __init__(self, output_dim=None, use_perceiver=True, num_latents=8, finetune_bart=False):
        """
        Initialize the text vectorizer using BART.

        Args:
            output_dim (int, optional): Length of the flattened vector per text.
                With ``use_perceiver=True`` it defaults to BART's ``d_model`` and
                must be divisible by ``num_latents`` (each latent then has
                ``output_dim // num_latents`` features). With
                ``use_perceiver=False`` it must equal ``MAX_LENGTH * d_model``.
            use_perceiver (bool): Whether to use PerceiverAE for dimensionality reduction.
            num_latents (int): Number of latents for PerceiverAE.
            finetune_bart (bool): If True BART's parameters stay trainable;
                otherwise they are frozen.

        Raises:
            ValueError: if ``output_dim`` / ``num_latents`` are incompatible
                with the selected mode.
        """
        # Load BART first: its config provides d_model, which sizes the
        # Perceiver and determines the valid output_dim.
        bart = BartForConditionalGeneration.from_pretrained(self.MODEL_NAME)
        tokenizer = BartTokenizer.from_pretrained(self.MODEL_NAME)
        self.finetune_bart = finetune_bart

        d_model = bart.config.d_model

        if use_perceiver:
            # Each of the num_latents latent vectors has flat_dim // num_latents
            # features, so flat_dim must split evenly across the latents.
            flat_dim = output_dim if output_dim is not None else d_model
            if num_latents <= 0 or flat_dim % num_latents != 0:
                raise ValueError(
                    f"output_dim ({flat_dim}) must be divisible by a positive "
                    f"num_latents (got {num_latents})"
                )
            perceiver_dim = d_model  # Perceiver consumes BART-width hidden states

            perceiver = PerceiverAutoEncoder(
                dim_lm=d_model,                       # BART's hidden dimension
                dim_ae=flat_dim // num_latents,       # Features per latent
                depth=1,                              # Number of self-attention layers
                dim_head=64,                          # Attention head dimension
                num_encoder_latents=num_latents,
                num_decoder_latents=self.MAX_LENGTH,  # Reconstruct full sequence length
                max_seq_len=self.MAX_LENGTH,
                ff_mult=4
            )
        else:
            flat_dim = self.MAX_LENGTH * d_model
            if output_dim is not None and output_dim != flat_dim:
                raise ValueError(
                    f"output_dim must be MAX_LENGTH * d_model = {flat_dim}, "
                    f"but got {output_dim}"
                )

        # Call parent's init before registering submodules
        super().__init__(output_dim=flat_dim, accepted_dtype="object")

        self.bart = bart
        self.tokenizer = tokenizer
        self.d_model = d_model
        self.flat_dim = flat_dim
        self.use_perceiver = use_perceiver
        self.num_latents = num_latents
        if self.use_perceiver:
            self.perceiver = perceiver
            self.perceiver_dim = perceiver_dim

        # Freeze BART unless fine-tuning was requested
        for param in self.bart.parameters():
            param.requires_grad = finetune_bart

    def to(self, device):
        """Move the vectorizer, BART and (if used) the Perceiver to ``device``."""
        super().to(device)
        # BART is used by every code path (encoding, loss, generation), so it
        # must follow the device whether or not it is being fine-tuned. This is
        # a no-op if the parent already moved it.
        self.bart = self.bart.to(device)
        if self.use_perceiver:
            self.perceiver = self.perceiver.to(device)
        return self

    def is_trainable(self):
        return True

    def _vectorize(self, column, config):
        """
        Transform a text column into fixed-size vectors.

        Args:
            column (pandas.Series): texts; values are converted with ``str``.
            config (dict): unused.

        Returns:
            torch.Tensor of shape ``(batch_size, flat_dim)``: the flattened,
            ``MAX_LENGTH``-padded BART hidden states, or the flattened Perceiver
            latents when ``use_perceiver`` is set.
        """
        inputs = self.tokenizer(
            column.astype(str).tolist(),
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.MAX_LENGTH
        ).to(self.device)

        # Only BART is wrapped: when frozen there is no point building its
        # autograd graph. The Perceiver encoder below must run with gradients
        # enabled, otherwise it never receives a training signal.
        with torch.set_grad_enabled(self.finetune_bart and torch.is_grad_enabled()):
            hidden_states = self.bart.model.encoder(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                return_dict=True
            ).last_hidden_state  # (batch_size, seq_len, d_model)

        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states[:, :self.MAX_LENGTH, :]
        seq_len = hidden_states.shape[1]

        # Zero-pad the sequence axis up to MAX_LENGTH
        padded = F.pad(hidden_states, (0, 0, 0, self.MAX_LENGTH - seq_len))

        if not self.use_perceiver:
            return padded.reshape(batch_size, -1)

        # Mask out both the MAX_LENGTH padding and the in-batch padding tokens
        # of shorter texts (taken from the tokenizer's attention mask).
        attention_mask = torch.zeros((batch_size, self.MAX_LENGTH), device=self.device)
        attention_mask[:, :seq_len] = inputs["attention_mask"][:, :seq_len]

        # Encode with Perceiver, then flatten the latents
        encoded = self.perceiver.encode(padded, attention_mask)
        return encoded.reshape(batch_size, -1)

    def _inverse_vectorize(self, column_vector, config, mode='inference'):
        """
        Convert column vectors back to a text representation.

        Args:
            column_vector (torch.Tensor): ``(batch_size, flat_dim)`` vectors as
                produced by ``_vectorize``.
            config (dict): unused.
            mode (str): ``'train'`` returns decoder inputs for the loss; any
                other value generates text.

        Returns:
            For ``'train'``: ``{'encoder_hidden_states': tensor}`` of shape
            ``(batch_size, MAX_LENGTH, d_model)`` when the Perceiver decoder
            reconstructs the full sequence.
            Otherwise: ``pandas.Series`` of generated texts.
        """
        batch_size = column_vector.shape[0]

        if self.use_perceiver:
            latent_dim = self.flat_dim // self.num_latents
            latents = column_vector.reshape(batch_size, self.num_latents, latent_dim)
            # Decode latents to the full MAX_LENGTH sequence
            encoder_hidden_states = self.perceiver.decode(latents)
        else:
            encoder_hidden_states = column_vector.reshape(batch_size, self.MAX_LENGTH, self.d_model)

        if mode == 'train':
            # Loss computation is handled in _compute_loss
            return {
                'encoder_hidden_states': encoder_hidden_states
            }

        # Start every sequence with BART's decoder start token
        decoder_input_ids = torch.full(
            (batch_size, 1),
            self.bart.config.decoder_start_token_id,
            dtype=torch.long,
            device=self.device
        )

        # generate() expects encoder outputs wrapped in a BaseModelOutput
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_hidden_states
        )

        generated_ids = self.bart.generate(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            max_length=self.MAX_LENGTH,
            num_beams=4,
            early_stopping=True
        )

        generated_texts = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True
        )
        return pd.Series(generated_texts)

    def _compute_loss(self, vectorized_column, target_column, config=None):
        """
        Compute BART's token-level reconstruction loss for the target texts.

        Args:
            vectorized_column (dict): output of ``_inverse_vectorize(mode='train')``
                containing ``'encoder_hidden_states'``.
            target_column (pandas.Series): target texts; values are converted
                with ``str``.
            config (dict, optional): unused.

        Returns:
            torch.Tensor: scalar mean cross-entropy over non-padding target tokens.
        """
        encoder_hidden_states = vectorized_column['encoder_hidden_states']

        target_encodings = self.tokenizer(
            target_column.astype(str).tolist(),
            padding='max_length',
            truncation=True,
            max_length=self.MAX_LENGTH,
            return_tensors='pt'
        ).to(self.device)

        # Label padding positions with -100 so BART's cross-entropy ignores them.
        # (BART's shift_tokens_right maps -100 back to pad_token_id when it builds
        # the decoder inputs from the labels.)
        labels = target_encodings.input_ids.masked_fill(
            target_encodings.attention_mask == 0, -100
        )

        # No encoder attention mask is passed: the target-text mask does not
        # describe the (decoded) encoder positions, and inference via generate()
        # attends to all encoder positions, so training does the same.
        outputs = self.bart(
            encoder_outputs=BaseModelOutput(last_hidden_state=encoder_hidden_states),
            labels=labels,
            return_dict=True
        )

        return outputs.loss

    def _vectorize_batch(self, columns: List[pd.Series], configs: List[Dict]) -> List[torch.Tensor]:
        """
        Vectorize several text columns with a single encoder pass.

        Args:
            columns (List[pandas.Series]): The text columns to transform.
            configs (List[dict]): The configuration dictionaries (not used for text).

        Returns:
            List[torch.Tensor]: one ``(len(column), flat_dim)`` tensor per column.
        """
        if not columns:
            return []

        column_sizes = [len(column) for column in columns]

        # Concatenate all columns, vectorize once, then split back per column
        all_texts = pd.Series([text for column in columns for text in column.astype(str).tolist()])
        all_embeddings = self._vectorize(all_texts, config={})

        return list(torch.split(all_embeddings, column_sizes, dim=0))

    def _inverse_vectorize_batch(self, tensors: List[torch.Tensor], configs: List[Dict], mode: str) -> List:
        """
        Inverse-transform several column tensors with a single decoder pass.

        Args:
            tensors (List[torch.Tensor]): The tensors to inverse transform.
            configs (List[dict]): The configuration dictionaries (not used for text).
            mode (str): 'inference' or 'train'

        Returns:
            List: per input tensor, either a ``{'encoder_hidden_states': ...}``
            dict (``'train'``) or a ``pandas.Series`` of decoded texts with a
            fresh 0-based index.
        """
        if not tensors:
            return []

        tensor_sizes = [tensor.shape[0] for tensor in tensors]
        all_output = self._inverse_vectorize(torch.cat(tensors, dim=0), config={}, mode=mode)

        if mode == 'train':
            chunks = torch.split(all_output['encoder_hidden_states'], tensor_sizes, dim=0)
            return [{'encoder_hidden_states': chunk} for chunk in chunks]

        # Inference: split the Series of decoded texts positionally
        result = []
        start_idx = 0
        for size in tensor_sizes:
            result.append(all_output.iloc[start_idx:start_idx + size].reset_index(drop=True))
            start_idx += size
        return result

    def _compute_batch_loss(self, reconstructed_values_list: List[Dict],
                           target_columns: List[pd.Series],
                           configs: List[Dict]) -> torch.Tensor:
        """
        Compute one combined loss for multiple text columns.

        The columns are concatenated and scored in a single ``_compute_loss``
        call, so the result is one mean over all target tokens rather than a
        sum of per-column losses.

        Args:
            reconstructed_values_list (List[Dict]): List of reconstruction outputs for each column
            target_columns (List[pd.Series]): List of target text columns
            configs (List[Dict]): List of configurations (not used for text)

        Returns:
            torch.Tensor: The combined loss from all columns

        Raises:
            ValueError: if an entry lacks ``'encoder_hidden_states'``.
        """
        for i, output in enumerate(reconstructed_values_list):
            if not isinstance(output, dict) or 'encoder_hidden_states' not in output:
                raise ValueError(f"Invalid reconstructed output at index {i}. Expected dict with 'encoder_hidden_states' key.")

        all_hidden_states = torch.cat([
            output['encoder_hidden_states'] for output in reconstructed_values_list
        ], dim=0)

        all_target_texts = pd.Series([
            text for column in target_columns for text in column.astype(str).tolist()
        ])

        combined_output = {
            'encoder_hidden_states': all_hidden_states
        }

        return self._compute_loss(combined_output, all_target_texts, {})

if __name__ == "__main__":
    # Manual demo of TextVectorizer. In Perceiver mode the Perceiver is randomly
    # initialised here, so its reconstructions are not expected to be meaningful
    # without training (see test_training_loop).
    texts = pd.Series([
        "Hello world!",
        "This is a test.",
        "BART is a powerful model.",
        "Let's see how well it works.",
        "Text reconstruction in action."
    ])
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Test both regular and perceiver versions
    for use_perceiver in [False, True]:
        print(f"\n{'='*50}")
        print(f"Testing with {'Perceiver' if use_perceiver else 'Regular'} mode")
        print(f"{'='*50}")

        output_dim = 1024 if use_perceiver else None
        vectorizer = TextVectorizer(output_dim=output_dim, use_perceiver=use_perceiver, num_latents=8)
        vectorizer.to(device)

        print("Original texts:")
        print(texts)
        print("\n")

        print("Vectorizing texts...")
        embeddings = vectorizer.vectorize(texts, config={})
        print(f"Embedding shape: {embeddings.shape}")
        print(f"Embedding sample (first 5 dimensions):\n{embeddings[0, :5]}")
        print("\n")

        print("Testing training mode reconstruction...")
        _, loss = vectorizer.inverse_vectorize(embeddings, config={}, mode='train', target_column=texts)
        print(f"Reconstruction loss: {loss.item()}")

        print("\nTesting inference mode reconstruction...")
        reconstructed_texts = vectorizer.inverse_vectorize(embeddings, config={}, mode='inference')

        print("\nReconstruction results:")
        for original, reconstructed in zip(texts, reconstructed_texts):
            print(f"Original: {original:30} | Reconstructed: {reconstructed}")

    # Test training loop if enabled (slow: trains for thousands of steps)
    test_train = False
    if test_train:
        test_training_loop()

    # Test batch processing
    print("\nTest: Batch text processing")
    columns = [
        pd.Series(["Hello world!", "This is a test."]),
        pd.Series(["BART is powerful.", "Let's see how it works."]),
        pd.Series(["Text reconstruction in action.", "Batch processing is efficient."])
    ]
    configs = [{}, {}, {}]

    vectorizer = TextVectorizer(output_dim=1024, use_perceiver=True, num_latents=8)
    vectorizer.to(device)

    print("Testing batch vectorization...")
    start_time = time.time()
    batch_vectors = vectorizer.vectorize_batch(columns, configs)
    batch_time = time.time() - start_time

    print(f"Number of output tensors: {len(batch_vectors)}")
    for i, tensor in enumerate(batch_vectors):
        print(f"Tensor {i} shape: {tensor.shape}")

    print("\nComparing with sequential processing...")
    start_time = time.time()
    sequential_vectors = [vectorizer.vectorize(col, {}) for col in columns]
    sequential_time = time.time() - start_time

    print(f"Sequential processing time: {sequential_time:.4f}s")
    print(f"Batch processing time: {batch_time:.4f}s")
    if batch_time > 0:
        print(f"Speedup: {sequential_time / batch_time:.2f}x")
    else:
        print("Speedup: n/a (batch time too small to measure)")

    print("\nTesting batch inverse vectorization (inference mode)...")
    decoded_columns = vectorizer.inverse_vectorize_batch(batch_vectors, configs, mode='inference')

    print("Decoded texts:")
    for i, col in enumerate(decoded_columns):
        print(f"Column {i}:")
        for j, text in enumerate(col):
            print(f"  {j}: {text}")

    print("\nTesting batch inverse vectorization (training mode)...")
    reconstructed_values, batch_loss = vectorizer.inverse_vectorize_batch(
        batch_vectors, configs, mode='train', target_columns=columns)

    print(f"Number of reconstructed outputs: {len(reconstructed_values)}")
    print(f"Batch loss: {batch_loss.item()}")

    print("\nComparing with sequential inverse vectorization...")
    individual_losses = []
    for i, (vec, col) in enumerate(zip(sequential_vectors, columns)):
        _, loss = vectorizer.inverse_vectorize(vec, {}, mode='train', target_column=col)
        individual_losses.append(loss.item())
        print(f"Column {i} loss: {individual_losses[-1]}")

    # The batch loss is a single token-averaged cross-entropy over the
    # concatenated columns, so compare it with the MEAN of the per-column
    # losses (not their sum). They agree exactly only when every column
    # contributes the same number of loss tokens.
    mean_individual_loss = sum(individual_losses) / len(individual_losses)
    print(f"Mean of individual losses: {mean_individual_loss}")
    print(f"Batch loss: {batch_loss.item()}")
    print(f"Difference: {abs(batch_loss.item() - mean_individual_loss):.6f}")