from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DPOConfig, DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
import wandb
import logging
import torch
import time
from accelerate import Accelerator
import os
import torch.distributed as dist
import json
import argparse
import shutil
from sklearn.metrics import accuracy_score, f1_score, classification_report
import re
from tqdm import tqdm


def parse_binary_prediction(response):
    """Parse a binary label (0 or 1) from a free-text model response.

    Resolution order:
      1. The first digit anywhere in the response, if it is 0 or 1.
      2. Whole-token keywords: any positive token ('yes', 'true', '1', 'one',
         'positive') yields 1.
      3. Everything else (negative words or an unclear answer) yields 0.

    Keywords are matched as whole tokens, not substrings, so words that merely
    contain a keyword (e.g. 'none' contains 'one', 'yesterday' contains 'yes')
    are not misread as a positive answer.

    Args:
        response: Decoded model output.

    Returns:
        int: 0 or 1.
    """
    text = response.strip().lower()

    # The most direct answer format: the first digit in the text ("1", "Answer: 0").
    first_digit = re.search(r'\d', text)
    if first_digit:
        value = int(first_digit.group())
        if value in (0, 1):
            return value

    # Keyword fallback on whole tokens (letter runs and single digits).
    tokens = set(re.findall(r'[a-z]+|\d', text))
    positive_tokens = {'yes', 'true', '1', 'one', 'positive'}
    # Negative words ('no', 'false', '0', 'zero', 'negative') and unclear answers
    # both map to 0, so only positive evidence needs to be detected.
    return 1 if tokens & positive_tokens else 0


def parse_float_prediction(response):
    """Parse the first decimal number from a free-text model response.

    Accepts an optional sign and the forms '3', '3.5', '3.' and '.5'. A sign
    is only taken when it is not glued to a preceding word character, so a
    hyphenated token like 'COVID-19' yields 19.0 rather than -19.0.

    Words such as 'nan' or 'inf' are deliberately not parsed. A non-finite
    prediction would turn every downstream error metric into nan/inf.

    Args:
        response: Decoded model output.

    Returns:
        float: The parsed number, or 0.0 when no number is found.
    """
    match = re.search(r'(?:(?<![\w.])[-+])?(?:\d+(?:\.\d*)?|\.\d+)', response.strip())
    return float(match.group()) if match else 0.0


def predict_readmission(example, model, tokenizer):
    """Classify one readmission example (binary 0/1) with greedy decoding.

    The first two entries of ``example['messages']`` (the system and user
    turns) form the prompt. The third entry's content is the reference label.
    At most 5 new tokens are generated, since a short answer is expected.

    Args:
        example: Record holding a ``messages`` list.
        model: Generation model; inputs are moved to ``model.device``.
        tokenizer: Tokenizer providing the chat template.

    Returns:
        dict with ``predicted_label``, ``true_label``, ``model_response``
        (decoded reply) and ``correct`` (whether the labels match).
    """
    prompt_text = tokenizer.apply_chat_template(
        example['messages'][:2],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    encoded = tokenizer([prompt_text], return_tensors="pt").to(model.device)
    prompt_len = len(encoded.input_ids[0])

    with torch.no_grad():
        sequences = model.generate(**encoded, max_new_tokens=5, do_sample=False)

    # Drop the echoed prompt and keep only the newly generated tokens.
    new_tokens = sequences[0][prompt_len:].tolist()
    answer = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    prediction = parse_binary_prediction(answer)
    gold = int(example['messages'][2]['content'])
    return {
        "predicted_label": prediction,
        "true_label": gold,
        "model_response": answer,
        "correct": prediction == gold,
    }


def predict_mortality(example, model, tokenizer):
    """Run greedy generation on one mortality example (binary 0/1) and score it.

    Args:
        example: Record whose ``messages`` list holds the prompt turns at
            indices 0-1 (system and user) and the reference label at index 2.
        model: Generation model; inputs are moved to ``model.device``.
        tokenizer: Tokenizer providing the chat template.

    Returns:
        dict with ``predicted_label``, ``true_label``, ``model_response`` and
        ``correct`` (whether the two labels match).
    """
    conversation = example['messages']
    chat_prompt = tokenizer.apply_chat_template(
        conversation[:2],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    batch = tokenizer([chat_prompt], return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(**batch, max_new_tokens=10, do_sample=False)

    # Slice off the prompt tokens before decoding the reply.
    start = len(batch.input_ids[0])
    reply = tokenizer.decode(output[0][start:].tolist(), skip_special_tokens=True).strip()

    label_pred = parse_binary_prediction(reply)
    label_true = int(conversation[2]['content'])
    return {
        "predicted_label": label_pred,
        "true_label": label_true,
        "model_response": reply,
        "correct": label_pred == label_true,
    }


def predict_period(example, model, tokenizer):
    """Predict the length-of-stay value for one example (regression).

    The first two entries of ``example['messages']`` (system and user turns)
    are rendered with the chat template and decoded greedily for at most 5 new
    tokens. The third entry's content is the reference value.

    Returns:
        dict with ``predicted_value`` (float parsed from the reply),
        ``true_value``, ``model_response`` and ``error`` (absolute difference).
    """
    prompt = tokenizer.apply_chat_template(
        example['messages'][:2],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    n_prompt_tokens = len(inputs.input_ids[0])

    with torch.no_grad():
        generated = model.generate(**inputs, max_new_tokens=5, do_sample=False)

    continuation = generated[0][n_prompt_tokens:].tolist()
    text = tokenizer.decode(continuation, skip_special_tokens=True).strip()

    predicted_value = parse_float_prediction(text)
    true_value = float(example['messages'][2]['content'])
    result = dict(
        predicted_value=predicted_value,
        true_value=true_value,
        model_response=text,
    )
    result["error"] = abs(predicted_value - true_value)
    return result


def calculate_metrics(results, task_type):
    """Summarise per-sample prediction results into task-level metrics.

    Args:
        results: List of dicts produced by the ``predict_*`` functions. Binary
            tasks read ``true_label``/``predicted_label``; ``'period'`` reads ``error``.
        task_type: ``'readmission'``, ``'mortality'`` or ``'period'``.

    Returns:
        dict. For binary tasks it holds ``accuracy``, ``f1_binary`` (positive
        class 1), ``f1_macro`` and a text ``classification_report``. For
        ``'period'`` it holds ``mae``, ``rmse``, ``accuracy_within_1day``
        (share of absolute errors <= 1.0) and ``mean_error`` (same value as
        ``mae``). The dict is empty when *results* is empty or *task_type* is
        unrecognised.
    """
    if not results:
        print("Warning: No results to calculate metrics for")
        return {}

    if task_type in ['readmission', 'mortality']:
        # Binary classification tasks
        true_labels = [r['true_label'] for r in results]
        predicted_labels = [r['predicted_label'] for r in results]

        accuracy = accuracy_score(true_labels, predicted_labels)
        f1 = f1_score(true_labels, predicted_labels, average='binary')
        f1_macro = f1_score(true_labels, predicted_labels, average='macro')

        # Pin the label set to {0, 1}. With a small or imbalanced test split,
        # only one class may occur in both label lists. Without `labels`,
        # classification_report then rejects the two target_names with a ValueError.
        report = classification_report(
            true_labels, predicted_labels, labels=[0, 1], target_names=['0', '1']
        )

        return {
            "accuracy": accuracy,
            "f1_binary": f1,
            "f1_macro": f1_macro,
            "classification_report": report,
        }

    if task_type == 'period':
        # Regression task: MAE, RMSE, and accuracy within a fixed tolerance.
        errors = [r['error'] for r in results]
        count = len(errors)
        mae = sum(errors) / count
        rmse = (sum(e ** 2 for e in errors) / count) ** 0.5

        # Accuracy within 1 day tolerance
        tolerance = 1.0
        accuracy = sum(1 for e in errors if e <= tolerance) / count

        return {
            "mae": mae,
            "rmse": rmse,
            "accuracy_within_1day": accuracy,
            "mean_error": mae,
        }

    return {}


def split_dataset(dataset_path, test_size=200):
    """Split a JSON-array dataset into a head test split and a tail train split.

    The first ``test_size`` records become the test split and the remainder
    the train split. No shuffling is done, so the split follows file order.

    Args:
        dataset_path: Path to a JSON file whose top level is a list of records.
        test_size: Number of leading records to hold out for testing.

    Returns:
        ``(train_path, test_path)``. The train split is written next to the
        input as ``<name>_train<ext>`` (``.json`` if the input has no
        extension). The test split is written to ``./test_data/<name>_test.json``.

    Raises:
        ValueError: If *test_size* is negative.
    """
    logger = logging.getLogger(__name__)

    # A negative size would silently swap the meaning of the two slices below.
    if test_size < 0:
        raise ValueError(f"test_size must be non-negative, got {test_size}")

    with open(dataset_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    logger.info(f"Loaded dataset with {len(data)} samples")

    test_data = data[:test_size]
    train_data = data[test_size:]

    logger.info(f"Split dataset: {len(train_data)} train samples, {len(test_data)} test samples")
    if not train_data:
        logger.warning("Train split is empty: test_size covers the whole dataset")

    os.makedirs('./test_data', exist_ok=True)

    root, ext = os.path.splitext(dataset_path)
    input_filename = os.path.basename(root)

    test_path = f'./test_data/{input_filename}_test.json'
    with open(test_path, 'w', encoding='utf-8') as f:
        json.dump(test_data, f, indent=2, ensure_ascii=False)

    logger.info(f"Test dataset saved to {test_path}")

    # Derive the train filename from the real extension of the last path
    # component. A path-wide string replace would also rewrite '.json' inside
    # directory names. For inputs without '.json' it would return the input
    # path itself, overwriting the full dataset with the train subset.
    train_path = f"{root}_train{ext or '.json'}"
    with open(train_path, 'w', encoding='utf-8') as f:
        json.dump(train_data, f, indent=2, ensure_ascii=False)

    logger.info(f"Train dataset saved to {train_path}")

    return train_path, test_path


def test_model(model, tokenizer, test_data_path, task_type, args, logger):
    """Evaluate *model* on a held-out JSON test file and save per-sample results.

    Args:
        model: Causal LM used for generation by the task's ``predict_*`` function.
        tokenizer: Matching tokenizer with a chat template.
        test_data_path: Path to a JSON array of examples with a ``messages`` list.
        task_type: One of ``'readmission'``, ``'mortality'`` or ``'period'``.
        args: Run hyperparameters; only used to build the output filename.
        logger: Logger for progress and metric output.

    Returns:
        dict with ``task_type``, ``test_data_path``, ``summary`` (sample
        counts plus metrics) and ``results`` (one dict per processed sample).
        It is also written to ``./outputs/<name>_test_results_<hyperparams>.json``.

    Raises:
        ValueError: If *task_type* is not a supported task.
    """
    predictors = {
        'readmission': predict_readmission,
        'mortality': predict_mortality,
        'period': predict_period,
    }
    # Fail fast on an unsupported task. Otherwise the same error would be
    # logged for every sample and an empty result file saved.
    if task_type not in predictors:
        raise ValueError(f"Unknown task type: {task_type}")
    predict = predictors[task_type]

    logger.info(f"Starting model testing on {test_data_path}")

    with open(test_data_path, 'r', encoding='utf-8') as f:
        test_data = json.load(f)

    logger.info(f"Loaded test dataset with {len(test_data)} samples")

    results = []
    failed_samples = 0

    # Iterating through tqdm advances the bar for failed samples too.
    for i, sample in enumerate(tqdm(test_data, desc=f"Testing {task_type} samples")):
        try:
            results.append(predict(sample, model, tokenizer))
        except Exception as e:
            # Best-effort evaluation: a bad sample is logged with its traceback
            # and counted, but does not abort the run.
            failed_samples += 1
            logger.exception(f"Error processing test sample {i}: {str(e)}")
            continue

        # Clear GPU cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    metrics = calculate_metrics(results, task_type)

    input_filename = os.path.splitext(os.path.basename(test_data_path))[0]
    hyperparams_suffix = f"r{args.lora_r}_lr{args.sft_lr}_ep{args.sft_epochs}_bs{args.batch_size}_gas{args.gradient_accumulation_steps}"
    # Include DPO settings (same scheme as the wandb run name) so an SFT+DPO
    # run does not overwrite the results of an SFT-only run with equal SFT settings.
    if getattr(args, 'enable_dpo', False):
        hyperparams_suffix += f"_dpolr{args.dpo_lr}_dpoep{args.dpo_epochs}"
    os.makedirs('./outputs', exist_ok=True)
    test_output_path = f'./outputs/{input_filename}_test_results_{hyperparams_suffix}.json'

    final_output = {
        "task_type": task_type,
        "test_data_path": test_data_path,
        "summary": {
            "total_samples": len(results),
            "failed_samples": failed_samples,
            **metrics,
        },
        "results": results,
    }

    with open(test_output_path, 'w', encoding='utf-8') as f:
        json.dump(final_output, f, indent=2, ensure_ascii=False)

    logger.info(f"Test results saved to {test_output_path}")

    logger.info(f"\nTest Results for {task_type}:")
    logger.info(f"Total samples: {len(results)}")
    if failed_samples:
        logger.warning(f"Failed samples: {failed_samples}")

    for metric, value in metrics.items():
        if isinstance(value, float):
            if 'accuracy' in metric.lower():
                logger.info(f"{metric}: {value:.2%}")
            else:
                logger.info(f"{metric}: {value:.4f}")
        elif metric == 'classification_report':
            logger.info(f"\nClassification Report:\n{value}")

    return final_output


def _latest_checkpoint(model_dir, logger):
    """Return the newest ``checkpoint-<step>`` subdirectory of *model_dir*, or *model_dir* itself.

    Only directories named exactly ``checkpoint-<integer>`` are considered.
    Stray entries (files, ``checkpoint-best``, temporary directories) therefore
    cannot break the numeric sort.
    """
    if not os.path.isdir(model_dir):
        return model_dir
    candidates = []
    for entry in os.listdir(model_dir):
        match = re.fullmatch(r"checkpoint-(\d+)", entry)
        if match and os.path.isdir(os.path.join(model_dir, entry)):
            candidates.append((int(match.group(1)), entry))
    if not candidates:
        return model_dir
    _, latest = max(candidates)
    checkpoint_path = os.path.join(model_dir, latest)
    logger.info(f"Loading model from checkpoint: {checkpoint_path}")
    return checkpoint_path


def main(args):
    """Fine-tune a causal LM with LoRA SFT, optionally DPO, and optionally evaluate it.

    Steps:
      1. Start a wandb run on the main process.
      2. Optionally split ``args.dataset_path`` into train/test JSON files.
      3. Load ``args.model_name`` in bfloat16 and wrap it with a LoRA adapter.
      4. Run SFT (TRL ``SFTTrainer``) on the training file.
      5. If ``args.enable_dpo``, run DPO (TRL ``DPOTrainer``) on ``args.dpo_dataset_path``.
      6. If a test split exists and ``args.test_model`` is set, reload the
         latest saved checkpoint on the main process and evaluate it.

    Args:
        args: Parsed command-line namespace.
    """
    # Setup logging
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)

    # LOCAL_RANK picks this process's GPU. RANK (global process index, set by
    # torchrun-style launchers; 0 when absent) selects the single process that
    # owns one-off work such as the wandb run and evaluation.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    is_main_process = int(os.environ.get("RANK", 0)) == 0

    if is_main_process:
        hyperparams_str = f"r{args.lora_r}_lr{args.sft_lr}_ep{args.sft_epochs}_bs{args.batch_size}_gas{args.gradient_accumulation_steps}"
        if args.enable_dpo:
            hyperparams_str += f"_dpolr{args.dpo_lr}_dpoep{args.dpo_epochs}"
        run_name = f"{args.task_type}_{hyperparams_str}"
        wandb.init(project="qwen3-0.6B-sft-dpo", name=run_name)

    # Split dataset if requested
    if args.split_dataset:
        # NOTE(review): every launched process runs this split and rewrites the
        # same files, so with several ranks a reader may race a writer.
        # Consider splitting once before launching.
        train_path, test_path = split_dataset(args.dataset_path, args.test_size)
        dataset_path = train_path
    else:
        dataset_path = args.dataset_path
        test_path = None
        if args.test_model:
            logger.warning("--test_model requires --split_dataset; no test split exists, so evaluation is skipped.")

    # Load model and tokenizer
    logger.info("Loading model and tokenizer...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype=torch.bfloat16,
        device_map={"": local_rank},
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # LoRA on all attention and MLP projections, with alpha = 2 * r.
    lora_config = LoraConfig(
        r=args.lora_r, lora_alpha=args.lora_r * 2,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_dropout=args.lora_dropout, bias="none", task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)

    # SFT Training
    logger.info("Starting SFT training...")
    sft_dataset = load_dataset("json", data_files=dataset_path, split="train")

    # Output directory is named after the dataset and hyperparameters.
    input_filename = os.path.splitext(os.path.basename(dataset_path))[0]
    hyperparams_suffix = f"r{args.lora_r}_lr{args.sft_lr}_ep{args.sft_epochs}_bs{args.batch_size}_gas{args.gradient_accumulation_steps}"
    output_dir = f"./trained_models/{input_filename}_Qwen3-0.6B-SFT_{hyperparams_suffix}"

    sft_config = SFTConfig(
        output_dir=output_dir,
        run_name="sft_training_run",
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        num_train_epochs=args.sft_epochs,
        learning_rate=args.sft_lr,
        bf16=True,
        logging_steps=10,
        save_steps=500,
        ddp_find_unused_parameters=True,
    )
    sft_trainer = SFTTrainer(model=model, args=sft_config, processing_class=tokenizer, train_dataset=sft_dataset)
    sft_trainer.train()
    logger.info("SFT training completed!")

    # Drop the SFT trainer and dataset to free memory.
    del sft_trainer, sft_dataset
    torch.cuda.empty_cache()

    # DPO Training (if enabled)
    if args.enable_dpo:
        logger.info("Waiting 5 seconds before starting DPO training...")
        time.sleep(5)
        logger.info("Starting DPO training...")

        dpo_hyperparams_suffix = f"r{args.lora_r}_sftlr{args.sft_lr}_dpolr{args.dpo_lr}_sftep{args.sft_epochs}_dpoep{args.dpo_epochs}_bs{args.batch_size}_gas{args.gradient_accumulation_steps}"
        dpo_output_dir = f"./trained_models/{input_filename}_Qwen3-0.6B-SFT-DPO_{dpo_hyperparams_suffix}"
        dpo_config = DPOConfig(
            output_dir=dpo_output_dir,
            run_name="dpo_training_run",
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            num_train_epochs=args.dpo_epochs,
            learning_rate=args.dpo_lr,
            bf16=True,
            logging_steps=10,
            save_steps=500,
            optim="adamw_torch",
            warmup_ratio=0.1,
            ddp_find_unused_parameters=True,
            label_names=[],
        )

        dpo_dataset = load_dataset("json", data_files=args.dpo_dataset_path, split="train")
        # Seed the shuffle with the trainer's seed. An unseeded shuffle draws a
        # different permutation in every process, so ranks would disagree on
        # the dataset order that the distributed sampler partitions by index.
        dpo_dataset = dpo_dataset.shuffle(seed=dpo_config.seed)

        dpo_trainer = DPOTrainer(model=model, args=dpo_config, processing_class=tokenizer, train_dataset=dpo_dataset)
        dpo_trainer.train()
        logger.info("DPO training completed!")

        del dpo_trainer, dpo_dataset
        torch.cuda.empty_cache()

    # Evaluate once, on the main process only. This keeps ranks from
    # duplicating generation work or racing on the same results file.
    if test_path and args.test_model and is_main_process:
        logger.info("Starting model testing...")

        model_dir = dpo_output_dir if args.enable_dpo else output_dir
        model_path = _latest_checkpoint(model_dir, logger)

        # Evaluate in bfloat16, the dtype the model was trained in (bf16=True above).
        trained_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        ).eval()

        trained_tokenizer = AutoTokenizer.from_pretrained(model_path)
        if trained_tokenizer.pad_token is None:
            trained_tokenizer.pad_token = trained_tokenizer.eos_token

        test_model(trained_model, trained_tokenizer, test_path, args.task_type, args, logger)
        logger.info("Model testing completed!")

    if is_main_process:
        wandb.finish()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train QWEN model with SFT and optional DPO')
    
    # Dataset arguments
    parser.add_argument('--dataset_path', type=str, required=True,
                      help='Path to the training dataset JSON file')
    parser.add_argument('--split_dataset', action='store_true',
                      help='Split dataset into train/test sets')
    parser.add_argument('--test_size', type=int, default=200,
                      help='Number of samples to use for testing (default: 200)')
    parser.add_argument('--task_type', type=str, choices=['readmission', 'mortality', 'period'], 
                      default='readmission', help='Type of clinical task')
    
    # Model arguments
    parser.add_argument('--model_name', type=str, default='Qwen/Qwen3-0.6B',
                      help='Name of the QWEN model to use')
    
    # LoRA hyperparameters
    parser.add_argument('--lora_r', type=int, default=16,
                      help='LoRA rank (default: 16)')
    parser.add_argument('--lora_dropout', type=float, default=0.05,
                      help='LoRA dropout rate (default: 0.05)')
    
    # SFT hyperparameters
    parser.add_argument('--sft_lr', type=float, default=4e-5,
                      help='SFT learning rate (default: 4e-5)')
    parser.add_argument('--sft_epochs', type=int, default=1,
                      help='Number of SFT training epochs (default: 1)')
    parser.add_argument('--batch_size', type=int, default=1,
                      help='Per device batch size (default: 1)')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=8,
                      help='Gradient accumulation steps (default: 8)')
    
    # DPO arguments
    parser.add_argument('--enable_dpo', action='store_true',
                      help='Enable DPO training after SFT')
    parser.add_argument('--dpo_dataset_path', type=str, default=None,
                      help='Path to DPO dataset (required if enable_dpo is True)')
    parser.add_argument('--dpo_lr', type=float, default=4e-5,
                      help='DPO learning rate (default: 4e-5)')
    parser.add_argument('--dpo_epochs', type=int, default=1,
                      help='Number of DPO training epochs (default: 1)')
    
    # Testing arguments
    parser.add_argument('--test_model', action='store_true',
                      help='Test the trained model on test set')
    
    args = parser.parse_args()
    
    # Validate arguments
    if args.enable_dpo and args.dpo_dataset_path is None:
        raise ValueError("dpo_dataset_path must be provided when enable_dpo is True")
    
    # Create necessary directories
    os.makedirs('./trained_models', exist_ok=True)
    os.makedirs('./test_data', exist_ok=True)
    os.makedirs('./outputs', exist_ok=True)
    
    # Run training
    main(args)