from sentence_transformers import InputExample, CrossEncoder
from torch.utils.data import DataLoader, Dataset
import torch
import logging
import os
import wandb
import argparse
from datetime import datetime
from datasets import load_dataset
import random
import numpy as np
from tqdm.auto import tqdm
from transformers import TrainingArguments, Trainer
from typing import Dict, List

# Disable PyTorch 2.0 Compiler optimizations
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.cache_size_limit = 0

logging.basicConfig(format='%(asctime)s - %(message)s',
                   datefmt='%Y-%m-%d %H:%M:%S',
                   level=logging.INFO)
logger = logging.getLogger(__name__)

os.environ["DISABLE_MLFLOW_INTEGRATION"] = "TRUE"
os.environ["DISABLE_TELEMETRY"] = "TRUE"
os.environ["HF_METRICS_ONLINE"] = "FALSE"

class CrossEncoderDataset(Dataset):
    """Map-style dataset that tokenizes text pairs lazily, one item per access.

    Each sample must expose ``texts`` (a two-element sequence: first and
    second text of the pair) and ``label`` (a number). ``model`` must expose
    a callable ``tokenizer`` and a ``max_length`` attribute; both are read on
    every ``__getitem__`` call.
    """

    def __init__(self, samples: "List[InputExample]", model: "CrossEncoder"):
        # Annotations are strings so the class can be defined without
        # evaluating the project types at definition time.
        self.samples = samples
        self.model = model

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Tokenize sample ``idx`` and return its tensors plus a float label.

        Returns:
            A dict with ``input_ids``, ``attention_mask`` and ``labels``;
            ``token_type_ids`` is included as well when the tokenizer
            produces it.
        """
        sample = self.samples[idx]
        # Every item is padded to max_length, so all items share one length.
        encoded = self.model.tokenizer(
            text=sample.texts[0],
            text_pair=sample.texts[1],
            truncation=True,
            max_length=self.model.max_length,
            padding='max_length',
            return_tensors='pt'
        )
        # return_tensors='pt' adds a leading batch dimension of size 1;
        # index [0] strips it so each item is a single un-batched example.
        item = {
            'input_ids': encoded['input_ids'][0],
            'attention_mask': encoded['attention_mask'][0],
            'labels': torch.tensor(sample.label, dtype=torch.float)
        }
        # Some tokenizers also emit segment ids that tell the two texts
        # apart; forward them instead of silently dropping them.
        if 'token_type_ids' in encoded:
            item['token_type_ids'] = encoded['token_type_ids'][0]
        return item

def parse_args(argv=None):
    """Parse command-line options and create the output directories.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which matches the original
            behavior.

    Returns:
        The populated ``argparse.Namespace``.

    Side effects:
        Creates ``args.output_dir`` and its ``logs`` subdirectory if they do
        not exist yet.

    Exits via argparse (``SystemExit``) when ``--dataset_path`` or
    ``--score_columns`` is missing.
    """
    parser = argparse.ArgumentParser(description='Train a cross-encoder model')

    # Model parameters
    parser.add_argument('--model_name', type=str, default='distilroberta-base',
                      help='Base model to use')
    parser.add_argument('--num_epochs', type=int, default=3,
                      help='Number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=16,
                      help='Batch size for training')
    parser.add_argument('--max_length', type=int, default=512,
                      help='Maximum sequence length')
    parser.add_argument('--learning_rate', type=float, default=2e-5,
                      help='Learning rate for training')

    # Dataset parameters
    parser.add_argument('--dataset_path', type=str, required=True,
                      help='HuggingFace dataset path')
    # Required: the data loader iterates over these columns unconditionally,
    # so omitting them used to crash later with an opaque TypeError.
    parser.add_argument('--score_columns', type=str, nargs='+', required=True,
                      help='Columns to use for scores')
    parser.add_argument('--max_rows', type=int, default=None,
                      help='Maximum number of rows to process')

    # Output settings
    parser.add_argument('--output_dir', type=str, default='checkpoints',
                      help='Output directory for checkpoints')
    parser.add_argument('--use_wandb', action='store_true',
                      help='Whether to use wandb logging')

    # Other settings
    parser.add_argument('--seed', type=int, default=42,
                      help='Random seed')

    args = parser.parse_args(argv)

    # Set up output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Set up logging directory
    log_dir = os.path.join(args.output_dir, 'logs')
    os.makedirs(log_dir, exist_ok=True)

    return args

def load_data(args):
    """Load the dataset and turn it into shuffled train/eval sample lists.

    Reads ``args.dataset_path``, ``args.score_columns`` and ``args.max_rows``.
    For each of the first ``max_rows`` rows of the ``"data"`` split, the
    row's ``instruction`` is paired with every entry of its ``samples`` list.
    The label of each pair is the matching entry of the score column,
    averaged element-wise when several score columns are given.

    Returns:
        ``(train_samples, eval_samples)``: two lists of ``InputExample``
        obtained by shuffling all pairs and splitting them 90% / 10%.

    Raises:
        ValueError: if ``args.score_columns`` is empty or ``None``.
    """
    # Fail before the (potentially slow) dataset load instead of crashing
    # later while iterating over None.
    if not args.score_columns:
        raise ValueError(
            "No score columns given; pass --score_columns with at least one column name"
        )

    logger.info("Loading dataset...")
    dataset = load_dataset(args.dataset_path)
    data = dataset["data"]

    # Clamp the requested row count to what is available so the index loop
    # below cannot run past the end of the sliced columns.
    available_rows = len(data)
    if args.max_rows is None:
        max_rows = available_rows
    else:
        max_rows = max(0, min(args.max_rows, available_rows))
    logger.info(f"Processing {max_rows} rows...")

    all_samples = []
    instructions = data['instruction'][:max_rows]
    all_samples_list = data['samples'][:max_rows]

    # Keep score columns as plain per-row sequences. Rows may hold different
    # numbers of samples, so packing a whole column into one array is unsafe.
    score_rows = {
        col: data[col][:max_rows]
        for col in args.score_columns
    }

    # Create samples
    total_samples = sum(len(samples) for samples in all_samples_list)
    logger.info(f"Creating {total_samples} samples...")

    for idx in tqdm(range(max_rows)):
        instruction = instructions[idx]
        samples = all_samples_list[idx]

        # One score vector per configured column for this row.
        scores = [np.asarray(score_rows[column][idx]) for column in args.score_columns]

        # Average scores if multiple columns
        final_scores = np.mean(scores, axis=0) if len(scores) > 1 else scores[0]

        # Pair the instruction with each candidate text and its score.
        all_samples.extend([
            InputExample(texts=[instruction, sample], label=float(score))
            for sample, score in zip(samples, final_scores)
        ])

    # Split data: shuffle with the global `random` state, then 90% / 10%.
    random.shuffle(all_samples)
    train_size = int(0.9 * len(all_samples))
    train_samples = all_samples[:train_size]
    eval_samples = all_samples[train_size:]

    logger.info(f"Created {len(train_samples)} training and {len(eval_samples)} evaluation samples")
    return train_samples, eval_samples

def main():
    """Run the full training pipeline from the command line.

    Steps: parse arguments, seed the random number generators, attach a
    per-run log file, build the train/eval datasets, fine-tune the
    cross-encoder's underlying model with ``Trainer``, and save the final
    model under ``<output_dir>/<dataset_name>/<run_name>.ckpt``.

    Any exception raised while training is logged with its traceback and
    re-raised; the wandb run (if one was started) is always finished.
    """
    args = parse_args()

    # Seed every RNG this script touches. numpy is seeded too so that
    # numpy-based randomness is reproducible alongside torch and `random`.
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # Set up run name and logging
    run_name = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(args.output_dir, 'logs', f'training_{run_name}.log')
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
    logger.addHandler(file_handler)

    # Format dataset name for checkpoint path ('/' would create nested dirs)
    dataset_name = args.dataset_path.replace('/', '->')
    checkpoint_dir = os.path.join(args.output_dir, dataset_name)
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Load data
    train_samples, eval_samples = load_data(args)

    # Initialize model
    model = CrossEncoder(
        args.model_name,
        num_labels=1,
        max_length=args.max_length,
        activation_fn=None  # Use raw outputs for regression
    )

    # Create datasets
    train_dataset = CrossEncoderDataset(train_samples, model)
    eval_dataset = CrossEncoderDataset(eval_samples, model)

    # Define training arguments
    # NOTE(review): newer transformers releases appear to rename
    # `evaluation_strategy` to `eval_strategy` -- confirm against the
    # installed version.
    training_args = TrainingArguments(
        output_dir=checkpoint_dir,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        warmup_ratio=0.1,
        logging_dir=os.path.join(args.output_dir, 'logs'),
        logging_steps=10,
        save_strategy="epoch",
        evaluation_strategy="epoch",
        report_to=["wandb"] if args.use_wandb else [],
        run_name=run_name,
        # Forward --seed; otherwise the Trainer falls back to its own
        # default seed and the command-line option has no effect on it.
        seed=args.seed,
        # Disable compiler optimizations
        torch_compile=False,
        bf16=False,
        fp16=False
    )

    # Initialize wandb if requested
    if args.use_wandb:
        wandb.init(project="cross-encoder-training", name=run_name)

    try:
        # Create trainer and train
        trainer = Trainer(
            model=model.model,  # Get the underlying transformer model
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset
        )

        logger.info("Starting training...")
        trainer.train()

        # Save the final checkpoint with the specified format
        final_checkpoint = os.path.join(checkpoint_dir, f"{run_name}.ckpt")
        model.save(final_checkpoint)
        logger.info(f"Training completed. Final model checkpoint saved to: {final_checkpoint}")

    except Exception as e:
        # logger.exception keeps the same message and adds the traceback.
        logger.exception(f"Error during training: {str(e)}")
        raise
    finally:
        if args.use_wandb:
            wandb.finish()

# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()