import pandas as pd
import numpy as np
from datasets import Dataset, concatenate_datasets
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.model_selection import KFold
import torch
import json
import wandb
from typing import Dict, List
from data_utils import load_halawi_data
from accelerate import Accelerator
import transformers

def load_data(file_path):
    """Load ``{'text', 'label'}`` records from a JSON file.

    The file must hold a JSON array of objects, each with a ``text`` and a
    ``resolution`` key. ``resolution`` is copied unchanged into ``label``;
    any other keys are ignored.

    Args:
        file_path: path (str or path-like) of the JSON file.

    Returns:
        A list of ``{'text': ..., 'label': ...}`` dicts, in file order.

    Raises:
        KeyError: if an object lacks ``text`` or ``resolution``.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
    with open(file_path, encoding='utf-8') as f:
        data = json.load(f)
    return [{'text': item['text'], 'label': item['resolution']} for item in data]

def compute_metrics(pred):
    """Compute binary classification metrics for ``Trainer`` evaluation.

    Args:
        pred: an object with ``predictions`` (logits, or a tuple whose first
            element is the logits) and ``label_ids`` attributes, such as the
            ``EvalPrediction`` passed by ``transformers.Trainer``.

    Returns:
        Dict with ``accuracy``, ``f1``, ``precision`` and ``recall``; the last
        three use sklearn's ``average='binary'`` (positive label 1).
    """
    logits = pred.predictions
    # Models that return extra outputs make ``predictions`` a tuple; the
    # logits are expected to be its first element.
    if isinstance(logits, tuple):
        logits = logits[0]
    labels = pred.label_ids
    preds = np.asarray(logits).argmax(-1)
    # zero_division=0: when a class is never predicted (common early in
    # training) the score is 0.0 -- sklearn's default value too -- but without
    # an UndefinedMetricWarning on every evaluation.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='binary', zero_division=0
    )
    return {
        'accuracy': accuracy_score(labels, preds),
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }

# Dataset columns never used as auxiliary model features: the target columns
# plus metadata columns this pipeline treats as problematic.
_EXCLUDED_FEATURES = frozenset({
    'label', 'resolution', 'is_resolved', 'data_source', 'community_predictions',
    'gpt_3p5_category', 'question_type', 'url', 'extracted_urls',
    'date_resolve_at', 'date_created_at',
})

# Seed applied before the model is built; also passed to TrainingArguments.
_SEED = 42

# Output width of the numeric-feature MLP and dimension of each categorical
# embedding.
_NUMERIC_HIDDEN_SIZE = 64
_CATEGORICAL_EMBEDDING_SIZE = 16


def _looks_numeric(value) -> bool:
    """Heuristically decide whether a sample value belongs to a numeric feature.

    True for ints/floats (bools too, since ``bool`` subclasses ``int``) and for
    non-empty strings made of digits with at most one '.'. Strings with a sign
    or an exponent are NOT recognised as numeric.
    """
    if isinstance(value, (int, float)):
        return True
    return isinstance(value, str) and bool(value) and value.replace('.', '', 1).isdigit()


def _split_feature_types(first_example, feature_names: List[str], text_column: str):
    """Split *feature_names* into ``(numeric, categorical)`` name lists.

    Types are inferred from *first_example* (one row as a mapping) only. The
    text column is skipped; names missing from the row are reported and dropped.
    """
    numeric, categorical = [], []
    for feature in feature_names:
        if feature == text_column:
            continue
        if feature not in first_example:
            print(f"Warning: Feature '{feature}' not found in dataset")
            continue
        if _looks_numeric(first_example[feature]):
            numeric.append(feature)
        else:
            categorical.append(feature)
    return numeric, categorical


def _to_float(value) -> float:
    """Convert *value* to float; ``None`` and unparseable values become 0.0.

    Needed because numeric types are inferred from the first row only, so later
    rows may hold values ``float()`` rejects.
    """
    if value is None:
        return 0.0
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def _build_categorical_vocabs(dataset, features: List[str]) -> Dict[str, Dict[str, int]]:
    """Build a deterministic ``{value: index}`` vocabulary per categorical feature.

    ``dataset[feature]`` must return that column's values (a ``datasets.Dataset``
    or a plain ``{name: list}`` mapping both work). Values are stringified,
    ``None`` is ignored, and indices start at 1 in sorted order; index 0 is
    reserved for unknown/missing values. Features that are absent or have no
    non-None values are reported and left out of the result.
    """
    vocabs = {}
    for feature in features:
        try:
            column = dataset[feature]
        except KeyError as e:
            print(f"Warning: Could not create embedding for feature '{feature}': {e}")
            continue
        values = sorted({str(v) for v in column if v is not None})
        if not values:
            print(f"Warning: No values found for feature '{feature}', skipping")
            continue
        vocabs[feature] = {value: idx for idx, value in enumerate(values, start=1)}
    return vocabs


def _balanced_class_weights(labels, num_labels: int = 2) -> torch.Tensor:
    """Return inverse-frequency class weights ``n_samples / (k * count_c)``.

    ``k`` is the number of bins (at least *num_labels*). A class with no samples
    is treated as having one, so every weight is finite and the tensor always
    has at least *num_labels* entries.
    """
    counts = np.bincount(np.asarray(list(labels), dtype=np.int64), minlength=num_labels)
    weights = counts.sum() / (len(counts) * np.maximum(counts, 1))
    return torch.tensor(weights, dtype=torch.float32)


class _FeatureCollator:
    """Collate tokenized rows plus auxiliary numeric/categorical features.

    Text fields are padded to the longest sequence in the batch via
    ``tokenizer.pad``. Numeric features become a float tensor
    ``(batch, len(numeric_features))`` (missing values -> 0.0). Categorical
    features become a dict of long tensors, indexed with the fixed vocabularies
    built from the training set (unknown values -> 0). ``label`` becomes
    ``labels``.
    """

    def __init__(self, tokenizer, numeric_features, categorical_vocabs):
        self.tokenizer = tokenizer
        self.numeric_features = list(numeric_features)
        self.categorical_vocabs = categorical_vocabs

    def __call__(self, features):
        text_keys = ['input_ids', 'attention_mask']
        if 'token_type_ids' in features[0]:
            text_keys.append('token_type_ids')
        batch = self.tokenizer.pad(
            {key: [f[key] for f in features] for key in text_keys},
            padding='longest',
            return_tensors='pt',
        )

        if self.numeric_features:
            batch['numeric_features'] = torch.tensor(
                [[f.get(f"numeric_{name}", 0.0) for name in self.numeric_features] for f in features],
                dtype=torch.float,
            )

        if self.categorical_vocabs:
            # Same vocabulary for every batch, so an index always means the same value.
            batch['categorical_features'] = {
                name: torch.tensor(
                    [vocab.get(f.get(f"categorical_{name}"), 0) for f in features],
                    dtype=torch.long,
                )
                for name, vocab in self.categorical_vocabs.items()
            }

        if 'label' in features[0]:
            batch['labels'] = torch.tensor([f['label'] for f in features], dtype=torch.long)
        return batch


class _CombinedFeaturesModel(torch.nn.Module):
    """Sequence classifier on the transformer's first-token hidden state.

    Only ``self.transformer.base_model`` is run; the loaded model's own
    classification head is bypassed in favour of ``self.classifier``.

    NOTE: encoders for the auxiliary numeric and categorical features are built
    so the feature pipeline can be wired in later, but ``forward`` does not use
    them yet -- only the first-token hidden state reaches the classifier.
    """

    def __init__(self, model_name, num_labels, numeric_feature_count, categorical_vocabs):
        super().__init__()
        self.num_labels = num_labels
        self.transformer = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels
        )
        hidden_size = self.transformer.config.hidden_size

        self.categorical_embedding_size = _CATEGORICAL_EMBEDDING_SIZE
        if numeric_feature_count > 0:
            self.numeric_processor = torch.nn.Sequential(
                torch.nn.Linear(numeric_feature_count, _NUMERIC_HIDDEN_SIZE),
                torch.nn.ReLU(),
                torch.nn.Dropout(0.1),
            )

        self.categorical_embedders = torch.nn.ModuleDict()
        for feature, vocab in categorical_vocabs.items():
            try:
                # +1 because index 0 is reserved for unknown/missing values.
                self.categorical_embedders[feature] = torch.nn.Embedding(
                    len(vocab) + 1, self.categorical_embedding_size, padding_idx=0
                )
            except KeyError as e:  # ModuleDict rejects names such as ones containing '.'
                print(f"Warning: Could not create embedding for feature '{feature}': {e}")

        combined_size = hidden_size
        if numeric_feature_count > 0:
            combined_size += _NUMERIC_HIDDEN_SIZE
        combined_size += len(self.categorical_embedders) * self.categorical_embedding_size
        print(f"Combined feature size (not yet used by the classifier): {combined_size}")

        # The classifier deliberately takes only the transformer output.
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(256, num_labels),
        )

        # Exposed for Trainer utilities that look for ``model.config``.
        self.config = self.transformer.config

    def forward(self, input_ids, attention_mask, token_type_ids=None, numeric_features=None,
                categorical_features=None, labels=None):
        # numeric_features / categorical_features are accepted (the collator
        # produces them) but are not consumed yet -- see the class NOTE.
        transformer_outputs = self.transformer.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        pooled_output = transformer_outputs.last_hidden_state[:, 0, :]  # first ([CLS]) position
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return transformers.modeling_outputs.SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


def train_fold(
    train_dataset,
    val_dataset,
    model_name: str,
    fold_idx: int,
    base_output_dir: str,
    max_length: int = 1024,
) -> Dict[str, float]:
    """Fine-tune *model_name* on *train_dataset* and evaluate on *val_dataset*.

    Args:
        train_dataset: ``datasets.Dataset`` with a ``text`` column, an integer
            ``label`` column and optional extra columns, which become auxiliary
            numeric/categorical features unless listed in ``_EXCLUDED_FEATURES``.
        val_dataset: dataset with the same columns; used for the periodic and
            final evaluation.
        model_name: Hugging Face model id or local path.
        fold_idx: tag used in the output directory and W&B run name (any value
            with a meaningful ``str()``; callers also pass strings).
        base_output_dir: parent directory for Trainer outputs.
        max_length: number of tokens kept per text (longer texts are truncated).

    Returns:
        The metrics dict from ``Trainer.evaluate()`` (keys prefixed ``eval_``).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    text_column = 'text'  # the main text column

    feature_names = [name for name in train_dataset.features if name not in _EXCLUDED_FEATURES]
    # Feature types are inferred from the first row only (fetched once).
    first_example = train_dataset[0] if len(train_dataset) > 0 else {}
    numeric_features, categorical_features = _split_feature_types(
        first_example, feature_names, text_column
    )

    print(f"Text column: {text_column}")
    print(f"Numeric features: {numeric_features}")
    print(f"Categorical features: {categorical_features}")

    def prepare_features(examples):
        """Tokenize a batch of rows and add encoded auxiliary feature columns."""
        # Truncate only: the collator pads each batch to its longest sequence,
        # rather than padding every example to max_length up front.
        tokenized = tokenizer(examples[text_column], truncation=True, max_length=max_length)
        for feature in numeric_features:
            tokenized[f"numeric_{feature}"] = [_to_float(v) for v in examples[feature]]
        for feature in categorical_features:
            tokenized[f"categorical_{feature}"] = [
                "" if v is None else str(v) for v in examples[feature]
            ]
        return tokenized

    def preprocess(dataset):
        """Apply prepare_features, keeping only model-input columns plus 'label'."""
        raw_columns = [name for name in dataset.column_names if name != 'label']
        return dataset.map(prepare_features, batched=True, remove_columns=raw_columns)

    processed_train = preprocess(train_dataset)
    processed_val = preprocess(val_dataset)

    categorical_vocabs = _build_categorical_vocabs(train_dataset, categorical_features)

    # Seed before the model is built: its new head and embeddings are randomly
    # initialised here, before the Trainer exists.
    transformers.set_seed(_SEED)
    custom_model = _CombinedFeaturesModel(
        model_name=model_name,
        num_labels=2,
        numeric_feature_count=len(numeric_features),
        categorical_vocabs=categorical_vocabs,
    )

    class_weights = _balanced_class_weights(train_dataset['label'], num_labels=2)

    training_args = TrainingArguments(
        output_dir=f"{base_output_dir}/{fold_idx}",
        learning_rate=2e-6,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=8,
        num_train_epochs=40,
        weight_decay=0.01,
        eval_strategy="steps",
        eval_steps=100,
        metric_for_best_model="f1",
        report_to="wandb",
        save_strategy="no",
        run_name=f"{fold_idx}",
        warmup_ratio=0.1,
        logging_steps=4,
        lr_scheduler_type="cosine",
        seed=_SEED,
        # The datasets were already reduced to model inputs + label. Without
        # this, Trainer would drop the numeric_*/categorical_* columns because
        # forward() has no parameters with those names.
        remove_unused_columns=False,
    )

    class WeightedFeaturesTrainer(Trainer):
        """Trainer whose loss is class-weighted cross-entropy."""

        def __init__(self, class_weights, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.class_weights = class_weights.to(self.args.device)

        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            # **kwargs absorbs extra arguments newer Trainer versions pass
            # (e.g. num_items_in_batch).
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            logits = outputs.logits
            loss_fct = torch.nn.CrossEntropyLoss(weight=self.class_weights)
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            return (loss, outputs) if return_outputs else loss

    trainer = WeightedFeaturesTrainer(
        class_weights=class_weights,
        model=custom_model,
        args=training_args,
        train_dataset=processed_train,
        eval_dataset=processed_val,
        compute_metrics=compute_metrics,
        data_collator=_FeatureCollator(tokenizer, numeric_features, categorical_vocabs),
    )

    trainer.train()
    return trainer.evaluate()

def _binary_label(resolution) -> int:
    """Map a question resolution to a binary label: 1 if it equals 1, else 0.

    Accepts ints, floats and numeric strings (``1``, ``1.0`` and ``"1"`` give 1).
    Values ``float()`` cannot parse (e.g. ``None``) raise TypeError/ValueError.
    """
    return 1 if float(resolution) == 1.0 else 0


def _balance_binary(dataset):
    """Downsample the majority label so labels 0 and 1 are equally frequent.

    Keeps the first ``min(#0s, #1s)`` rows of the majority label (in current
    row order) plus every minority row. The kept majority rows come first; on a
    tie, all 0-rows come first. Rows whose label is neither 0 nor 1 are dropped.
    """
    labels = dataset['label']
    ones = [i for i, label in enumerate(labels) if label == 1]
    zeros = [i for i, label in enumerate(labels) if label == 0]
    if len(ones) > len(zeros):
        keep = ones[:len(zeros)] + zeros
    else:
        keep = zeros[:len(ones)] + ones
    return dataset.select(keep)


def cross_validate(n_folds: int = 5):
    """Train on the train+validation splits and evaluate on the test split.

    Despite the name, no K-fold loop runs: ``n_folds`` is currently unused and
    is kept only so existing callers keep working. Both the training pool and
    the test split are shuffled (seed 42), label-balanced by downsampling the
    majority label, and reshuffled before ``train_fold`` runs once.

    Returns:
        The evaluation metrics dict from ``train_fold``, or None if training
        raised an exception.
    """
    import traceback

    accelerator = Accelerator()

    train_dataset = concatenate_datasets([
        load_halawi_data(split="train"),
        load_halawi_data(split="validation"),
    ])
    test_dataset = load_halawi_data(split="test")

    def add_label(example):
        # label is 1 if resolution is 1 or 1.0, 0 otherwise
        return {"label": _binary_label(example["resolution"])}

    train_dataset = train_dataset.map(add_label).shuffle(seed=42)
    test_dataset = test_dataset.map(add_label).shuffle(seed=42)

    # NOTE(review): balancing the *test* split discards test examples; this
    # keeps the existing evaluation protocol.
    train_dataset = _balance_binary(train_dataset).shuffle(seed=42)
    test_dataset = _balance_binary(test_dataset).shuffle(seed=42)

    if accelerator.is_main_process:
        print(f"Train label counts: {np.bincount(train_dataset['label'])}")
        print(f"Test label counts: {np.bincount(test_dataset['label'])}")
        wandb.init(project="forecasting-classifier", name="halawi")

    model_name = "microsoft/deberta-v3-large"
    final_results = None
    try:
        if accelerator.is_main_process:
            print("Training final model on all training data...")
        final_results = train_fold(
            train_dataset,
            test_dataset,
            model_name,
            fold_idx="metaculus",
            base_output_dir="./results",
        )

        if accelerator.is_main_process and final_results is not None:
            wandb.log({
                "final_accuracy": final_results['eval_accuracy'],
                "final_f1": final_results['eval_f1'],
                "final_precision": final_results['eval_precision'],
                "final_recall": final_results['eval_recall']
            })
    except Exception as e:
        # Best effort: report the failure with its traceback, return None.
        print(f"Error in training: {e}")
        traceback.print_exc()
        final_results = None
    finally:
        # Close the W&B run even if training was interrupted.
        if accelerator.is_main_process:
            wandb.finish()

    return final_results

if __name__ == "__main__":
    # cuDNN settings: turn off benchmark-driven algorithm autotuning and ask for
    # deterministic kernels, so repeated runs select the same algorithms.
    # NOTE: these flags do not seed any random number generator.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    test_metrics = cross_validate(n_folds=5)
    if test_metrics:
        print("\nFinal Model Performance on Test Set:", test_metrics, sep="\n")