import os
import pandas as pd
from datasets import Dataset, load_dataset
from typing import Any, Dict, List, Optional, Union
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
import torch
from dataclasses import dataclass
from typing import Dict, List, Any, Optional
from transformers.utils import PaddingStrategy
import time
import shutil
from pathlib import Path
from tqdm import tqdm
import json
# Add accelerate imports
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# ---------------------------
# Configuration
# ---------------------------
model_name = "meta-llama/Llama-3.2-1B-Instruct"  # Example model name

# ---------------------------
# Accelerator Setup
# ---------------------------
# DDP handler asking the wrapper to search for parameters that were not used
# in the forward pass.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

# Mixed precision is explicitly turned off for this accelerator.
accelerator = Accelerator(mixed_precision="no", kwargs_handlers=[ddp_kwargs])

# Every process reports where it runs (one line per fact, same order as before).
print(
    f"Device: {accelerator.device}",
    f"Distributed type: {accelerator.distributed_type}",
    f"Local process rank: {accelerator.local_process_index}",
    f"Num processes: {accelerator.num_processes}",
    sep="\n",
)

# Flag consulted throughout the module so that only one process logs details.
is_main_process = accelerator.is_main_process

# ---------------------------
# Load Model and Tokenizer
# ---------------------------
def setup_model_and_tokenizer(model_name):
    """Build the tokenizer and a one-label sequence-classification model.

    Args:
        model_name: Identifier or path handed to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer truncates from the left,
        has ``model_max_length`` set to 4096 and is guaranteed to have a pad
        token; the model's ``pad_token_id`` and embedding size are aligned
        with that tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

    # Truncation settings are independent of the pad-token handling below.
    tokenizer.truncation_side = "left"
    tokenizer.model_max_length = 4096

    # Fall back to EOS for padding rather than adding a brand-new token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    vocab_size = len(tokenizer)
    if is_main_process:
        print(f"EOS token ID: {tokenizer.eos_token_id}")
        print(f"Vocabulary size: {vocab_size}")

    # num_labels=1 gives a single logit per sequence.
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)
    model.config.pad_token_id = tokenizer.pad_token_id
    model.resize_token_embeddings(vocab_size)

    if is_main_process:
        print("Using full model fine-tuning...")

    return model, tokenizer

# ---------------------------
# Data Preprocessing for ShareGPT Format Conversations
# ---------------------------
# ShareGPT exports name the speakers "human"/"gpt"; the rest of this function
# works with the chat-style names "user"/"assistant".
_SHAREGPT_SPEAKER_TO_ROLE = {"human": "user", "gpt": "assistant"}

# Fixed output schema so that an empty result still carries the expected columns.
_CONVERSATION_COLUMNS = ["conversation_text", "label", "conversation_id", "num_steps"]


def _message_role(msg):
    """Return the role of one message, reading ``role`` or ShareGPT's ``from``."""
    if "role" in msg:
        return msg["role"]
    speaker = msg["from"]
    if isinstance(speaker, str):
        return _SHAREGPT_SPEAKER_TO_ROLE.get(speaker, speaker)
    return speaker


def _message_text(msg):
    """Return the text of one message, reading ``content`` or ShareGPT's ``value``."""
    return msg["content"] if "content" in msg else msg["value"]


def process_sharegpt_conversations(dataset):
    """Turn raw conversations into one labelled row per conversation.

    Each item of ``dataset`` is either a mapping with a ``"conversations"``
    list (and optionally an ``"id"``) or the list of messages itself. Messages
    may use ``role``/``content`` keys or the ShareGPT ``from``/``value`` keys
    (``human``/``gpt`` speakers are treated as ``user``/``assistant``).

    The text of a conversation is all user messages joined by a space. The
    label comes from the last assistant message: ``"+"`` gives 1, ``"-"``
    gives 0, and anything else is logged and also labelled 0. Conversations
    without an assistant message, or that raise while being read, are skipped.

    Args:
        dataset: Sized iterable of conversations as described above.

    Returns:
        A ``pandas.DataFrame`` with the columns ``conversation_text``,
        ``label``, ``conversation_id`` and ``num_steps`` (number of user
        messages). The columns are present even when no row was produced.
    """
    processed_conversations = []

    for idx, row in tqdm(enumerate(dataset), desc="Processing conversations",
                         total=len(dataset), disable=not is_main_process):
        # Deliberately broad best-effort: one malformed row must not abort the run.
        try:
            # Either the row wraps the messages, or the row *is* the message list.
            conversation = row["conversations"] if "conversations" in row else row
            conversation_id = row["id"] if "id" in row else f"conversation_{idx}"

            # Resolve each role once so both selections below agree.
            tagged = [(_message_role(msg), msg) for msg in conversation]

            user_messages = [_message_text(msg) for role, msg in tagged if role == "user"]
            conversation_text = " ".join(user_messages)

            assistant_messages = [msg for role, msg in tagged if role == "assistant"]
            if not assistant_messages:
                if is_main_process:
                    print(f"Warning: No assistant messages found in conversation {conversation_id}")
                continue

            final_assistant_msg = _message_text(assistant_messages[-1]).strip()

            if final_assistant_msg == "+":
                label = 1  # Unbiased
            elif final_assistant_msg == "-":
                label = 0  # Biased
            else:
                # Unclear ratings are kept and conservatively counted as biased.
                if is_main_process:
                    print(f"Warning: Unclear final rating '{final_assistant_msg}' in conversation {conversation_id}")
                label = 0

            processed_conversations.append({
                "conversation_text": conversation_text,
                "label": label,
                "conversation_id": conversation_id,
                "num_steps": len(user_messages)
            })

        except Exception as e:
            if is_main_process:
                print(f"Error processing conversation {idx}: {e}")
            continue

    if is_main_process:
        print(f"Successfully processed {len(processed_conversations)} conversations")
    # Explicit columns keep the schema stable when nothing was processed.
    return pd.DataFrame(processed_conversations, columns=_CONVERSATION_COLUMNS)

def process_conversation_data(row):
    """Tokenize one preprocessed conversation row.

    Uses the module-level ``tokenizer``. The text is truncated to at most 4096
    tokens and is not padded here.

    Args:
        row: Mapping with ``"conversation_text"`` and ``"label"`` entries.

    Returns:
        A dict with ``input_ids``, ``attention_mask`` and the unchanged
        ``label`` (0 for biased "-", 1 for unbiased "+").
    """
    encoding = tokenizer(
        row["conversation_text"],
        truncation=True,
        max_length=4096,
        padding=False
    )

    features = {key: encoding[key] for key in ("input_ids", "attention_mask")}
    features["label"] = row["label"]
    return features

# ---------------------------
# Load and Process Dataset
# ---------------------------
def load_and_process_dataset(data_path=None, dataset_name=None, dataset_config=None):
    """Load raw conversations, label them, split train/eval and tokenize both.

    Exactly one source is used: ``data_path`` takes precedence over
    ``dataset_name``.

    Args:
        data_path: Path to a JSON file holding either a list of conversations
            or a dict with a ``"conversations"`` list.
        dataset_name: Name passed to ``datasets.load_dataset``; all splits of
            the loaded dataset are concatenated.
        dataset_config: Optional config name passed along with ``dataset_name``.

    Returns:
        ``(tokenized_train, tokenized_eval)``, or ``(None, None)`` when no
        source was given, loading failed, or preprocessing left no usable
        conversation.
    """
    if data_path:
        if is_main_process:
            print(f"Loading conversations from local file: {data_path}")
        # Broad catch is intentional: any load failure maps to (None, None).
        try:
            # JSON is UTF-8 by convention; do not depend on the locale default.
            with open(data_path, 'r', encoding="utf-8") as f:
                data = json.load(f)

            if isinstance(data, dict) and 'conversations' in data:
                raw_conversations = data['conversations']
            elif isinstance(data, list):
                raw_conversations = data
            else:
                raise ValueError("Unsupported JSON format - expected list or dict with 'conversations' key")

            if is_main_process:
                print(f"Loaded {len(raw_conversations)} conversations from file")

        except Exception as e:
            if is_main_process:
                print(f"Error loading file {data_path}: {e}")
            return None, None

    elif dataset_name:
        if is_main_process:
            print(f"Loading dataset: {dataset_name}" + (f" with config: {dataset_config}" if dataset_config else ""))
        try:
            if dataset_config:
                dataset = load_dataset(dataset_name, dataset_config)
            else:
                dataset = load_dataset(dataset_name)

            # Flatten every split into one list of rows.
            raw_conversations = []
            for split in dataset.keys():
                raw_conversations.extend(dataset[split])

            if is_main_process:
                print(f"Loaded {len(raw_conversations)} conversations from HuggingFace")

        except Exception as e:
            if is_main_process:
                print(f"Error loading dataset {dataset_name}: {e}")
            return None, None
    else:
        if is_main_process:
            print("Either data_path or dataset_name must be provided")
        return None, None

    df_conversations = process_sharegpt_conversations(raw_conversations)

    # Nothing usable: report failure instead of handing empty datasets to training.
    if df_conversations.empty:
        if is_main_process:
            print("Error: No usable conversations found after preprocessing")
        return None, None

    conversation_dataset = Dataset.from_pandas(df_conversations).shuffle(seed=45)

    total_size = len(conversation_dataset)
    eval_size = max(1, int(total_size * 0.05))  # 5% for eval, but never fewer than 1 example

    if total_size <= 1:
        # A single example cannot be split; it serves as both train and eval.
        if is_main_process:
            print("Warning: Dataset too small to split into train/eval")
        train_dataset = conversation_dataset
        eval_dataset = conversation_dataset
    else:
        eval_dataset = conversation_dataset.select(range(eval_size))
        train_dataset = conversation_dataset.select(range(eval_size, total_size))

    # Parallel tokenization only when a single accelerate process is running.
    map_workers = 8 if accelerator.num_processes == 1 else 1

    def _tokenize(split):
        """Tokenize one split, keeping only the columns the model consumes."""
        return split.map(
            process_conversation_data,
            remove_columns=split.column_names,
            num_proc=map_workers,
        )

    if is_main_process:
        print("Tokenizing datasets...")
    tokenized_train = _tokenize(train_dataset)
    if is_main_process:
        print(f"Training dataset size: {len(tokenized_train)}")

    tokenized_eval = _tokenize(eval_dataset)
    if is_main_process:
        print(f"Eval dataset size: {len(tokenized_eval)}")

    return tokenized_train, tokenized_eval

# ---------------------------
# Data Collator
# ---------------------------
@dataclass
class RewardDataCollatorWithPadding:
    """Pad a list of tokenized examples into one batch and attach float labels.

    Each feature must provide ``input_ids``, ``attention_mask`` and ``label``.
    The token fields are padded through ``tokenizer.pad``; the labels are
    returned under ``"labels"`` as a tensor of the resolved label dtype.
    """

    # Tokenizer whose ``pad`` method builds the padded batch.
    tokenizer: AutoTokenizer
    # Padding strategy forwarded to ``tokenizer.pad``.
    padding: Union[bool, str, PaddingStrategy] = True
    # Forwarded to ``tokenizer.pad`` unchanged.
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"
    # Explicit dtype for the labels; when None it is derived from the
    # module-level ``model`` if one is set, else bfloat16.
    label_dtype: Optional[torch.dtype] = None

    def _resolve_label_dtype(self) -> torch.dtype:
        """Pick the label dtype: explicit field, then global model, then bfloat16."""
        if self.label_dtype is not None:
            return self.label_dtype
        # Looked up through globals() so a missing ``model`` falls through to
        # the default instead of raising NameError.
        active_model = globals().get("model")
        if hasattr(active_model, 'dtype'):
            return active_model.dtype
        if hasattr(active_model, 'module') and hasattr(active_model.module, 'dtype'):
            return active_model.module.dtype
        return torch.bfloat16

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Collate ``features`` into a padded batch with a ``labels`` tensor."""
        labels = [f["label"] for f in features]

        # Only the token fields go through the tokenizer's padding.
        batch_features = [
            {
                "input_ids": f["input_ids"],
                "attention_mask": f["attention_mask"],
            }
            for f in features
        ]

        batch = self.tokenizer.pad(
            batch_features,
            padding=self.padding,  # honour the configured strategy (default True)
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )

        batch["labels"] = torch.tensor(labels, dtype=self._resolve_label_dtype())
        return batch

# ---------------------------
# Custom Trainer for Reward Modeling
# ---------------------------
class RewardTrainer(Trainer):
    """Trainer that fits the single-logit head with binary cross-entropy."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute mean binary cross-entropy between the model logit and the label.

        Args:
            model: Model called with ``input_ids`` and ``attention_mask``; its
                output must expose ``logits`` with a trailing dimension of 1.
            inputs: Batch dict with ``input_ids``, ``attention_mask`` and
                ``labels`` (0/1 targets).
            return_outputs: When true, also return the model outputs.
            **kwargs: Accepted for compatibility with the base-class call and
                otherwise unused.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs`` is set.
        """
        labels = inputs["labels"]

        # Labels are deliberately not passed to the model; the loss is computed here.
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"]
        )

        logits = outputs.logits.squeeze(-1)  # Shape: (batch_size,)
        # Match the logits exactly so the loss never mixes dtypes or devices.
        labels = labels.to(device=logits.device, dtype=logits.dtype)

        # Logits-based BCE is the numerically stable form: sigmoid followed by
        # log(p + eps) loses its gradient once the sigmoid saturates.
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels)

        return (loss, outputs) if return_outputs else loss

# ---------------------------
# Training with Accelerate
# ---------------------------
def train_with_accelerate(model, train_dataset, eval_dataset, tokenizer, args):
    """Train the classifier with ``RewardTrainer`` and optionally push it to the Hub.

    Args:
        model: Model to fine-tune.
        train_dataset: Tokenized training examples.
        eval_dataset: Tokenized evaluation examples.
        tokenizer: Tokenizer used by the collator and pushed alongside the model.
        args: Namespace-like object. The optional attributes ``num_epochs``,
            ``batch_size``, ``learning_rate``, ``gradient_accumulation_steps``
            and ``output_model`` are read; each missing one falls back to the
            default used below.

    Returns:
        The trainer after ``train()`` has finished.
    """
    # Timestamped scratch directory for checkpoints of this run.
    timestamp = int(time.time())
    tmp_dir = f"/tmp/conversation_prm_{timestamp}"
    if is_main_process:
        os.makedirs(tmp_dir, exist_ok=True)

    # CLI values are honoured; the fallbacks equal the former hard-coded values.
    batch_size = getattr(args, 'batch_size', 128)

    training_args = TrainingArguments(
        output_dir=tmp_dir,
        num_train_epochs=getattr(args, 'num_epochs', 2),
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=getattr(args, 'gradient_accumulation_steps', 8),
        learning_rate=getattr(args, 'learning_rate', 2e-5),
        weight_decay=0.01,
        bf16=False,
        fp16=False,
        # NOTE(review): newer transformers releases call this `eval_strategy`;
        # confirm against the installed version.
        evaluation_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=5000,
        save_total_limit=1,
        gradient_checkpointing=True,
        logging_steps=10,
        optim="adamw_torch",
        adam_beta1=0.9,
        adam_beta2=0.95,
        lr_scheduler_type="cosine",
        warmup_ratio=0.15,
        report_to="wandb" if is_main_process else "none",  # Only main process logs to wandb
        run_name=f"conversation-prm-full-{timestamp}",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        push_to_hub=is_main_process,  # Only main process pushes to hub
        hub_model_id=getattr(args, 'output_model', "outcome_1B"),
        overwrite_output_dir=True
    )

    trainer = RewardTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=RewardDataCollatorWithPadding(tokenizer=tokenizer, max_length=4096),
    )

    # NOTE(review): a Trainer is not a model/optimizer/dataloader/scheduler;
    # confirm that Accelerator.prepare has any effect on it.
    trainer = accelerator.prepare(trainer)

    if is_main_process:
        print("Starting training...")

    trainer.train()

    if is_main_process:
        # Push only when an output repository was actually requested.
        if getattr(args, 'output_model', None):
            try:
                trainer.push_to_hub()
                tokenizer.push_to_hub(training_args.hub_model_id)
                print(f"Model successfully pushed to {training_args.hub_model_id}")
            except Exception as e:
                print(f"Error pushing to hub: {e}")

        # Best-effort removal of the scratch directory.
        try:
            if os.path.exists(tmp_dir):
                shutil.rmtree(tmp_dir)
        except Exception as e:
            print(f"Warning: Could not clean up temporary directory: {e}")

    return trainer

# ---------------------------
# Main function
# ---------------------------
def main():
    """Parse the command line, build model and data, then run training.

    The model and tokenizer are stored in module-level globals because the
    tokenization function and the data collator read them from there.
    """
    global model, tokenizer

    import argparse

    parser = argparse.ArgumentParser(description="Train a conversation-level bias classification model with Accelerate")

    # (flag, type, help, default) — registered in this order.
    option_table = (
        ("--data_path", str, "Path to local JSON data file", ""),
        ("--dataset", str, "HuggingFace dataset name (optional)", ""),
        ("--config", str, "Dataset config (optional)", ""),
        ("--output_model", str, "HuggingFace model ID for output (optional)", "outcome_1B"),
        ("--model_name", str, "Base model name (optional)", "meta-llama/Llama-3.2-3B-Instruct"),
        ("--batch_size", int, "Per-device batch size", 128),
        ("--num_epochs", int, "Number of training epochs", 2),
        ("--learning_rate", float, "Learning rate", 2e-5),
        ("--gradient_accumulation_steps", int, "Gradient accumulation steps", 8),
    )
    for flag, value_type, help_text, fallback in option_table:
        parser.add_argument(flag, type=value_type, help=help_text, default=fallback)

    args = parser.parse_args()

    model, tokenizer = setup_model_and_tokenizer(args.model_name)

    train_dataset, eval_dataset = load_and_process_dataset(
        data_path=args.data_path,
        dataset_name=args.dataset,
        dataset_config=args.config
    )

    # A missing split means loading or preprocessing failed; stop quietly.
    if train_dataset is None or eval_dataset is None:
        if is_main_process:
            print("Error: Failed to load or process dataset")
        return

    train_with_accelerate(model, train_dataset, eval_dataset, tokenizer, args)

    if is_main_process:
        print("Training completed successfully!")

if __name__ == "__main__":
    main()