"""
Few-Shot Learning Analysis for TAN
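Loads a pre-trained TAN checkpoint and fine-tunes it on small, class-balanced
few-shot subsets of the GoEmotions dataset (several shot counts, several trials
each), recording the macro-F1 of every run.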
"""
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, List, Tuple
from dataclasses import dataclass
import json
import logging
from tqdm import tqdm
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader, Dataset, Subset
from transformers import AutoTokenizer
from torch.optim import AdamW
import matplotlib.pyplot as plt
import seaborn as sns
from copy import deepcopy
import warnings

# Import the official Hugging Face datasets library
from datasets import load_dataset


from tan_architecture import TANForMultiLabelClassification, TANConfig

# Silence every warning category globally so library chatter does not clutter the run log.
warnings.filterwarnings(action='ignore')

# Timestamped, level-tagged log lines for the whole script.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# --- 1. DATA HANDLING ---
class GoEmotionsHFDataset(Dataset):
    """
    Wrap a Hugging Face GoEmotions split and tokenize examples on access.

    Each item of ``hf_dataset`` is read as a mapping with a ``'text'`` string and
    a ``'labels'`` list of class indices. ``__getitem__`` turns those indices into
    a multi-hot float vector of length ``num_labels``.

    Args:
        hf_dataset: Indexable split, e.g. ``dataset_dict['train']``.
        tokenizer: Tokenizer whose ``encode_plus`` encodes ``'text'``.
        max_len: Sequences are padded/truncated to exactly this many tokens.
        num_labels: Length of the multi-hot label vector. The default of 28 is
            for GoEmotions: 27 emotions + 'neutral'.
    """
    def __init__(self, hf_dataset, tokenizer, max_len: int = 128, num_labels: int = 28):
        self.hf_dataset = hf_dataset
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.num_labels = num_labels

    def __len__(self):
        return len(self.hf_dataset)

    def _multi_hot(self, label_indices) -> torch.Tensor:
        """Return a float vector of length ``num_labels`` with 1s at ``label_indices``."""
        labels = torch.zeros(self.num_labels)
        # An explicit long tensor keeps an empty index list a clean no-op.
        labels[torch.as_tensor(list(label_indices), dtype=torch.long)] = 1
        return labels

    def __getitem__(self, idx):
        """Return ``input_ids``/``attention_mask`` of length ``max_len`` plus multi-hot ``labels``."""
        item = self.hf_dataset[idx]

        encoding = self.tokenizer.encode_plus(
            item['text'],
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return {
            # return_tensors='pt' adds a leading batch dim of 1; flatten drops it.
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': self._multi_hot(item['labels']),
        }

    def get_all_labels(self) -> np.ndarray:
        """
        Return a ``(len(self), num_labels)`` multi-hot matrix of every example's labels.

        The matrix is used for stratified few-shot sampling. When the wrapped dataset
        supports column access (a Hugging Face ``Dataset``), the whole ``'labels'``
        column is read in one call instead of decoding every row. Otherwise, e.g. for
        a plain list of dicts, it falls back to per-row access.
        """
        try:
            label_lists = self.hf_dataset['labels']
        except (TypeError, KeyError, IndexError):
            label_lists = [self.hf_dataset[i]['labels'] for i in range(len(self))]

        all_labels = np.zeros((len(self), self.num_labels))
        for row, label_indices in enumerate(label_lists):
            all_labels[row, label_indices] = 1
        return all_labels

# --- 2. CONFIGURATION ---
@dataclass
class FewShotConfig:
    """
    Configuration for few-shot learning experiments.

    ``shots_per_class`` falls back to ``[1, 5, 10, 25, 50, 100]`` when left as None.
    Constructing an instance creates ``results_dir``, including any missing parents.
    """
    tan_checkpoint: Path = Path('goemotion_best_model.pt')  # pre-trained TAN weights
    results_dir: Path = Path('few_shot_results')  # output directory
    shots_per_class: List[int] = None  # None -> default sweep set in __post_init__
    num_classes: int = 28  # GoEmotions: 27 emotions + 'neutral'
    num_trials: int = 3
    max_epochs: int = 15
    early_stopping_patience: int = 3
    batch_size: int = 16
    learning_rate: float = 2e-4
    # Evaluated once, when the class body runs.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __post_init__(self):
        # Accept plain strings too; callers use Path methods such as .exists() and '/'.
        self.tan_checkpoint = Path(self.tan_checkpoint)
        self.results_dir = Path(self.results_dir)
        if self.shots_per_class is None:
            self.shots_per_class = [1, 5, 10, 25, 50, 100]  # Kept small for faster execution
        # parents=True so nested output paths such as 'runs/exp1' also work.
        self.results_dir.mkdir(parents=True, exist_ok=True)

# --- 3. DATA SAMPLING ---
class FewShotSampler:
    """
    Samples few-shot data with a balanced class distribution.

    For each class that still has enough unused examples, the sampler draws
    ``n_shots`` training and ``n_val`` validation examples carrying that label.
    Every example is used at most once, so train and val never overlap. Labels
    are multi-hot, so an example drawn for one class may also carry other labels.
    Per-class counts in the result can therefore exceed ``n_shots``.
    """
    def __init__(self, config: 'FewShotConfig'):
        self.config = config

    def create_few_shot_split(self, full_dataset: Dataset, labels: np.ndarray, n_shots: int,
                              seed: int, n_val: int = 5) -> Tuple[Subset, Subset]:
        """
        Create training and validation subsets for a few-shot trial.

        Args:
            full_dataset: Dataset the returned subsets index into.
            labels: Multi-hot matrix with one row per example of ``full_dataset``.
            n_shots: Training examples drawn per class.
            seed: Seed for this split; the same seed reproduces the same split.
            n_val: Validation examples drawn per class.

        Returns:
            ``(train_subset, val_subset)`` with sorted, disjoint indices. Classes
            with fewer than ``n_shots + n_val`` unused examples are skipped, so
            either subset may be empty.
        """
        # A local generator keeps splits reproducible without reseeding numpy's global RNG.
        rng = np.random.default_rng(seed)
        used = np.zeros(len(labels), dtype=bool)
        train_indices, val_indices = [], []

        for class_idx in range(self.config.num_classes):
            candidates = np.flatnonzero((labels[:, class_idx] == 1) & ~used)
            if len(candidates) < n_shots + n_val:
                continue  # Not enough unused examples for both train and val.

            chosen = rng.permutation(candidates)[: n_shots + n_val]
            train_indices.extend(chosen[:n_shots].tolist())
            val_indices.extend(chosen[n_shots:].tolist())
            used[chosen] = True

        return Subset(full_dataset, sorted(train_indices)), Subset(full_dataset, sorted(val_indices))

# --- 4. TRAINING & EVALUATION ---
class FewShotTrainer:
    """
    Handles the training and evaluation loop for a single few-shot trial.

    ``config`` must provide ``device``, ``learning_rate``, ``max_epochs`` and
    ``early_stopping_patience``.
    """
    def __init__(self, config: 'FewShotConfig'):
        self.config = config

    @staticmethod
    def _snapshot_state(model: nn.Module) -> Dict:
        """Return a CPU copy of ``model``'s state dict that later training cannot mutate."""
        state = model.state_dict()
        # Copying to CPU avoids keeping a second full set of weights on the accelerator.
        # copy=True also matters on CPU, where state_dict() tensors share storage with the
        # live parameters.
        snapshot = type(state)((k, v.detach().to('cpu', copy=True)) for k, v in state.items())
        # Preserve the per-module version metadata that load_state_dict consults.
        metadata = getattr(state, '_metadata', None)
        if metadata is not None:
            snapshot._metadata = deepcopy(metadata)
        return snapshot

    def evaluate(self, model: nn.Module, dataloader: DataLoader, threshold: float = 0.5) -> Dict:
        """
        Compute macro-F1 of thresholded sigmoid predictions over ``dataloader``.

        Args:
            model: Called as ``model(input_ids=..., attention_mask=...)``; its output
                must expose ``'logits'``.
            dataloader: Yields dicts with ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.
            threshold: Sigmoid probability above which a label counts as predicted.

        Returns:
            ``{'f1_macro': score}``. Classes with no predicted or true positives score 0.

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        model.eval()
        device = self.config.device
        pred_batches, label_batches = [], []
        with torch.no_grad():
            for batch in dataloader:
                outputs = model(
                    input_ids=batch['input_ids'].to(device),
                    attention_mask=batch['attention_mask'].to(device)
                )
                probs = torch.sigmoid(outputs['logits'])
                pred_batches.append((probs > threshold).float().cpu().numpy())
                label_batches.append(batch['labels'].cpu().numpy())

        if not pred_batches:
            raise ValueError("Cannot evaluate on an empty dataloader.")

        all_preds = np.vstack(pred_batches)
        all_labels = np.vstack(label_batches)
        return {'f1_macro': f1_score(all_labels, all_preds, average='macro', zero_division=0)}

    def train(self, model: nn.Module, train_loader: DataLoader, val_loader: DataLoader) -> nn.Module:
        """
        Fine-tune ``model`` with early stopping on validation macro-F1.

        Training stops after ``early_stopping_patience`` consecutive epochs without
        improvement, or after ``max_epochs``. The same model object is returned with
        the best-scoring weights restored. If no epoch beats a macro-F1 of 0, the
        starting weights are restored.
        """
        device = self.config.device
        trainable = [p for p in model.parameters() if p.requires_grad]
        optimizer = AdamW(trainable, lr=self.config.learning_rate)
        best_val_f1 = 0
        patience_counter = 0
        best_model_state = self._snapshot_state(model)

        for epoch in range(self.config.max_epochs):
            model.train()
            for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.config.max_epochs}", leave=False):
                optimizer.zero_grad()
                outputs = model(
                    input_ids=batch['input_ids'].to(device),
                    attention_mask=batch['attention_mask'].to(device),
                    labels=batch['labels'].to(device)
                )
                loss = outputs['loss']
                loss.backward()
                optimizer.step()

            val_f1 = self.evaluate(model, val_loader)['f1_macro']
            logger.info(f"Epoch {epoch+1}: Validation F1-Macro = {val_f1:.4f}")

            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                best_model_state = self._snapshot_state(model)
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= self.config.early_stopping_patience:
                    logger.info("Early stopping triggered.")
                    break

        # load_state_dict copies the CPU snapshot back onto the model's own device.
        model.load_state_dict(best_model_state)
        return model

# --- 5. VISUALIZATION ---
def create_visualizations(results: Dict, save_dir: Path):
    """
    Plot mean macro-F1 (shaded ± one std across trials) against shots per class.

    Args:
        results: Maps shot count -> ``{'scores': [per-trial F1, ...]}``.
        save_dir: Directory where ``tan_few_shot_performance.png`` is written.
    """
    if not results:
        logger.warning("No few-shot results to plot; skipping visualization.")
        return

    shots = sorted(results.keys())
    per_shot_scores = [np.asarray(results[s]['scores'], dtype=float) for s in shots]
    mean_f1 = np.array([scores.mean() for scores in per_shot_scores])
    std_f1 = np.array([scores.std() for scores in per_shot_scores])

    # 'seaborn-v0_8-whitegrid' only exists in newer matplotlib; fall back instead of
    # crashing after all training is done. style.context also keeps the style change
    # from leaking into any later plots (plt.style.use would set it globally).
    style_name = 'seaborn-v0_8-whitegrid'
    style = style_name if style_name in plt.style.available else 'default'
    save_path = Path(save_dir) / 'tan_few_shot_performance.png'

    with plt.style.context(style):
        fig, ax = plt.subplots(figsize=(10, 7))
        try:
            ax.plot(shots, mean_f1, marker='o', linestyle='-', label='TAN')
            ax.fill_between(shots, mean_f1 - std_f1, mean_f1 + std_f1, alpha=0.2)

            ax.set_xlabel('Shots per Class', fontsize=12)
            ax.set_ylabel('F1-Macro Score', fontsize=12)
            ax.set_title('TAN Few-Shot Learning Performance on GoEmotions', fontsize=14, fontweight='bold')
            ax.set_xscale('log')  # shot counts span orders of magnitude
            ax.legend()
            ax.grid(True, which="both", ls="--", c='0.7')
            fig.tight_layout()
            fig.savefig(save_path, dpi=300)
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)

    logger.info(f"Visualization saved to {save_path}")

# --- 6. MAIN EXECUTION ---
def _load_pretrained_state_dict(checkpoint_path: Path) -> Dict:
    """Read the TAN checkpoint from disk and return its parameter state dict.

    Accepts a raw state dict or a checkpoint that nests it under
    ``'model_state_dict'`` or ``'state_dict'``.
    """
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
    # only point this at checkpoints you produced or otherwise trust.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    return checkpoint.get('model_state_dict', checkpoint.get('state_dict', checkpoint))


def _build_pretrained_model(vocab_size: int, num_labels: int, pretrained_state: Dict,
                            report_key_mismatch: bool):
    """Instantiate a fresh TAN classifier and copy the pre-trained weights into it."""
    model = TANForMultiLabelClassification(TANConfig(vocab_size=vocab_size), num_labels=num_labels)
    # strict=False skips keys missing on either side. A checkpoint whose key names do not
    # match would therefore load nothing without complaint, so surface any mismatch.
    incompatible = model.load_state_dict(pretrained_state, strict=False)
    if report_key_mismatch:
        if incompatible.missing_keys:
            logger.warning(f"{len(incompatible.missing_keys)} model keys not found in checkpoint "
                           f"(left at initial values), e.g. {incompatible.missing_keys[:5]}")
        if incompatible.unexpected_keys:
            logger.warning(f"{len(incompatible.unexpected_keys)} checkpoint keys unused by model, "
                           f"e.g. {incompatible.unexpected_keys[:5]}")
    return model


def main():
    """
    Run the few-shot sweep.

    For each shot count, fine-tune fresh copies of the pre-trained TAN model on
    several sampled splits and record each trial's macro-F1.
    """
    config = FewShotConfig()
    logger.info("=" * 60)
    logger.info("Starting TAN Few-Shot Learning Analysis")
    logger.info(f"Using device: {config.device}")
    logger.info("=" * 60)

    # 1. Load pre-trained weights once, before the slow dataset download. A missing
    #    checkpoint then fails fast, and the file is not re-read for every trial.
    if not config.tan_checkpoint.exists():
        logger.error(f"Checkpoint not found at {config.tan_checkpoint}. Aborting.")
        return
    pretrained_state = _load_pretrained_state_dict(config.tan_checkpoint)
    logger.info(f"Loaded weights from {config.tan_checkpoint}")

    # 2. Automatically download and load the GoEmotions dataset
    try:
        logger.info("Downloading and preparing GoEmotions dataset...")
        dataset_dict = load_dataset("go_emotions", "simplified")
    except Exception as e:
        logger.error(f"Failed to download dataset. Check internet connection or firewall. Error: {e}")
        return

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

    # Few-shot train/val splits are sampled from 'train'. Final scores come from the
    # held-out 'test' split: the small val split already drives early stopping and
    # best-weights selection, so re-scoring on it would overstate performance.
    train_dataset = GoEmotionsHFDataset(dataset_dict['train'], tokenizer)
    train_labels = train_dataset.get_all_labels()
    test_loader = DataLoader(GoEmotionsHFDataset(dataset_dict['test'], tokenizer),
                             batch_size=config.batch_size)

    # 3. Run Experiments
    sampler = FewShotSampler(config)
    trainer = FewShotTrainer(config)
    all_results = {}
    report_key_mismatch = True  # log checkpoint/model key mismatches once, not per trial

    for n_shots in config.shots_per_class:
        logger.info(f"\n--- Running Experiment: {n_shots} Shots per Class ---")
        trial_scores = []

        for trial in range(config.num_trials):
            logger.info(f"--- Trial {trial + 1}/{config.num_trials} ---")
            seed = 42 + trial

            # a. Create data splits first, so a skipped trial never builds a model.
            train_subset, val_subset = sampler.create_few_shot_split(train_dataset, train_labels, n_shots, seed=seed)
            if len(train_subset) == 0 or len(val_subset) == 0:
                logger.warning(f"Skipping trial for {n_shots}-shots due to insufficient data for sampling.")
                continue

            # b. Seed torch so the initial values of weights the checkpoint lacks and the
            #    batch shuffling are reproducible per trial, then load the pre-trained model.
            torch.manual_seed(seed)
            model = _build_pretrained_model(tokenizer.vocab_size, config.num_classes,
                                            pretrained_state, report_key_mismatch)
            report_key_mismatch = False
            model.to(config.device)

            train_loader = DataLoader(train_subset, batch_size=config.batch_size, shuffle=True)
            val_loader = DataLoader(val_subset, batch_size=config.batch_size)

            # c. Fine-tune (early stopping on val), then score on the held-out test split
            fine_tuned_model = trainer.train(model, train_loader, val_loader)
            final_metrics = trainer.evaluate(fine_tuned_model, test_loader)
            logger.info(f"Trial {trial + 1} Test F1-Macro: {final_metrics['f1_macro']:.4f}")
            trial_scores.append(final_metrics['f1_macro'])

        if trial_scores:
            all_results[n_shots] = {'scores': trial_scores}
            logger.info(f"Average F1-Macro for {n_shots}-shots: {np.mean(trial_scores):.4f} ± {np.std(trial_scores):.4f}")

    # 4. Save and Visualize Results
    if all_results:
        results_path = config.results_dir / 'tan_few_shot_results.json'
        with open(results_path, 'w') as f:
            json.dump(all_results, f, indent=2)
        logger.info(f"\nFull results saved to {results_path}")

        create_visualizations(all_results, config.results_dir)

    logger.info("="*60)
    logger.info("Few-Shot Analysis Complete.")
    logger.info("="*60)


if __name__ == '__main__':
    # Run the few-shot sweep only when executed as a script, not on import.
    main()