import json
import torch
import time
import os
import re
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, classification_report
from transformers import AutoModelForCausalLM, AutoTokenizer

def save_checkpoint(processed_results, checkpoint_path):
    """Save ``processed_results`` as pretty-printed UTF-8 JSON at ``checkpoint_path``.

    The payload is first written to ``<checkpoint_path>.tmp`` and then moved
    over the destination with ``os.replace``. An interruption or a
    serialization error mid-write therefore leaves any previous file intact
    instead of a truncated, unparseable one.

    Args:
        processed_results: JSON-serializable object to persist.
        checkpoint_path: Destination file path; its directory must exist.

    Raises:
        OSError: if the file cannot be written or moved into place.
        TypeError: if ``processed_results`` is not JSON-serializable.
    """
    tmp_path = f"{checkpoint_path}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(processed_results, f, indent=2, ensure_ascii=False)
        os.replace(tmp_path, checkpoint_path)
    except BaseException:
        # Clean up the partial temp file, then propagate the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    print(f"Saved checkpoint to {checkpoint_path}")

def predict_readmission(example, model, tokenizer):
    """Score one readmission example (binary label 0/1) with the model.

    The first two entries of ``example['messages']`` (system and user turns)
    form the prompt. The content of the third entry is parsed with ``int``
    as the reference label.

    Returns:
        dict with ``predicted_label``, ``true_label``, ``model_response``
        (decoded, stripped generation) and ``correct`` (bool).
    """
    conversation = example['messages']
    prompt_turns = conversation[:2]

    prompt_text = tokenizer.apply_chat_template(
        prompt_turns,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    encoded = tokenizer([prompt_text], return_tensors="pt").to(model.device)
    prompt_len = len(encoded.input_ids[0])

    # Deterministic decoding (do_sample=False); only a short label is expected.
    with torch.no_grad():
        sequences = model.generate(**encoded, max_new_tokens=5, do_sample=False)

    # Drop the echoed prompt tokens and keep only the newly generated ones.
    new_tokens = sequences[0][prompt_len:].tolist()
    answer = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    prediction = parse_binary_prediction(answer)
    reference = int(conversation[2]['content'])
    return {
        "predicted_label": prediction,
        "true_label": reference,
        "model_response": answer,
        "correct": prediction == reference,
    }

def predict_mortality(example, model, tokenizer):
    """Score one mortality example (binary label 0/1) with the model.

    Prompt = system + user turns (``example['messages'][:2]``). The gold
    label is ``int`` of the third turn's content.

    Returns:
        dict with ``predicted_label``, ``true_label``, ``model_response`` and
        ``correct``.
    """
    chat = example['messages']

    rendered = tokenizer.apply_chat_template(
        chat[:2],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    batch = tokenizer([rendered], return_tensors="pt").to(model.device)

    with torch.no_grad():
        out = model.generate(
            **batch,
            max_new_tokens=10,
            do_sample=False,
        )

    n_prompt = len(batch.input_ids[0])
    reply = tokenizer.decode(
        out[0][n_prompt:].tolist(), skip_special_tokens=True
    ).strip()

    guess = parse_binary_prediction(reply)
    gold = int(chat[2]['content'])

    outcome = {}
    outcome["predicted_label"] = guess
    outcome["true_label"] = gold
    outcome["model_response"] = reply
    outcome["correct"] = guess == gold
    return outcome

def predict_period(example, model, tokenizer):
    """Score one length-of-stay example (regression, float target).

    Prompt = system + user turns (``example['messages'][:2]``). The gold
    value is ``float`` of the third turn's content.

    Returns:
        dict with ``predicted_value``, ``true_value``, ``model_response`` and
        ``error`` (absolute difference).
    """
    turns = example['messages']

    chat_prompt = tokenizer.apply_chat_template(
        turns[:2],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = tokenizer([chat_prompt], return_tensors="pt").to(model.device)

    # NOTE(review): 5 new tokens may truncate longer numbers if the tokenizer
    # splits digits individually -- confirm against typical target values.
    with torch.no_grad():
        generated = model.generate(**inputs, max_new_tokens=5, do_sample=False)

    start = len(inputs.input_ids[0])
    text_out = tokenizer.decode(generated[0][start:].tolist(), skip_special_tokens=True)
    text_out = text_out.strip()

    estimate = parse_float_prediction(text_out)
    target = float(turns[2]['content'])
    return dict(
        predicted_value=estimate,
        true_value=target,
        model_response=text_out,
        error=abs(estimate - target),
    )

def parse_binary_prediction(response):
    """Map a free-text model response to a binary label (0 or 1).

    Strategy, in order:
      1. If the response contains a digit, return the first one when it is
         0 or 1.
      2. Otherwise, return 1 when a positive cue appears as a whole word:
         "yes", "true", "one", "positive" or "1". Matching is case-insensitive
         and uses word boundaries, so "none" or "someone" are not read as
         "one".
      3. Otherwise return 0. This covers explicit negatives ("no", "false",
         "zero", "negative", "0") and unclear answers alike.

    Args:
        response: Decoded model output.

    Returns:
        int: 0 or 1.
    """
    text = response.strip().lower()

    first_digit = re.search(r'\d', text)
    if first_digit:
        value = int(first_digit.group())
        if value in (0, 1):
            return value

    if re.search(r'\b(?:yes|true|one|positive|1)\b', text):
        return 1

    # Negative cues and unparseable answers both map to 0.
    return 0

def parse_float_prediction(response):
    """Extract the first non-negative decimal number from a model response.

    Accepts forms such as "3", "3.5", "3." (read as 3) and ".5". The first
    match in the text wins. Returns 0.0 when no number is present.
    Non-numeric tokens such as "nan" or "inf" are deliberately not accepted:
    a non-finite prediction would make the MAE/RMSE aggregates non-finite.

    Args:
        response: Decoded model output.

    Returns:
        float: The parsed value, or 0.0 if nothing numeric was found.
    """
    match = re.search(r'\d+(?:\.\d+)?|\.\d+', response.strip())
    if match:
        return float(match.group())
    return 0.0

def calculate_metrics(results, task_type):
    """Compute summary metrics for a list of per-sample results.

    Args:
        results: Per-sample dicts from the ``predict_*`` functions.
            Classification tasks need ``true_label`` and ``predicted_label``;
            the ``period`` task needs ``error``.
        task_type: 'readmission' / 'mortality' (binary classification) or
            'period' (regression).

    Returns:
        dict: For classification tasks, ``accuracy``, ``f1_binary``,
        ``f1_macro`` and ``classification_report`` (text). For 'period',
        ``mae``, ``rmse``, ``accuracy_within_1day`` and ``mean_error``
        (an alias of ``mae``). An empty dict for empty ``results`` or an
        unknown task type.
    """
    if not results:
        print("Warning: No results to calculate metrics for")
        return {}

    if task_type in ['readmission', 'mortality']:
        true_labels = [r['true_label'] for r in results]
        predicted_labels = [r['predicted_label'] for r in results]

        # zero_division=0 yields the same values as sklearn's default "warn"
        # but without a warning on batches where a class is never predicted.
        accuracy = accuracy_score(true_labels, predicted_labels)
        f1 = f1_score(true_labels, predicted_labels, average='binary', zero_division=0)
        f1_macro = f1_score(true_labels, predicted_labels, average='macro', zero_division=0)

        # Pin labels to [0, 1] so the report always has two rows matching
        # target_names, even when only one class appears in this batch.
        report = classification_report(
            true_labels,
            predicted_labels,
            labels=[0, 1],
            target_names=['0', '1'],
            zero_division=0,
        )

        return {
            "accuracy": accuracy,
            "f1_binary": f1,
            "f1_macro": f1_macro,
            "classification_report": report,
        }

    if task_type == 'period':
        errors = [r['error'] for r in results]
        n = len(errors)
        mae = sum(errors) / n
        rmse = (sum(e ** 2 for e in errors) / n) ** 0.5

        # Fraction of predictions within a 1-day absolute error.
        tolerance = 1.0
        accuracy = sum(1 for e in errors if e <= tolerance) / n

        return {
            "mae": mae,
            "rmse": rmse,
            "accuracy_within_1day": accuracy,
            "mean_error": mae,
        }

    return {}

def _ensure_parent_dir(path):
    """Create the directory that will contain ``path``, if it names one.

    ``os.path.dirname`` returns "" for a bare filename, and ``os.makedirs("")``
    raises, so that case is skipped.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _load_checkpoint(checkpoint_path, run_info):
    """Return ``(results, start_idx)`` from a compatible checkpoint, else ``([], 0)``.

    ``run_info`` holds identifying fields (task type, absolute input path,
    model name) that ``main`` also writes into every checkpoint. A checkpoint
    whose recorded value for any of these keys differs is ignored, so a stale
    file at the shared default path cannot mix another run's results into
    this one. Keys missing from older checkpoints are not checked.
    """
    if not os.path.exists(checkpoint_path):
        return [], 0

    try:
        with open(checkpoint_path, 'r', encoding='utf-8') as f:
            checkpoint = json.load(f)
        results = checkpoint["results"]
        last_index = checkpoint["last_processed_index"]
        if not isinstance(results, list) or not isinstance(last_index, int):
            raise TypeError("malformed checkpoint")
    except (OSError, ValueError, KeyError, TypeError):
        # json.JSONDecodeError is a subclass of ValueError.
        print("Failed to load checkpoint, starting from beginning")
        return [], 0

    for key, expected in run_info.items():
        if key in checkpoint and checkpoint[key] != expected:
            print(f"Ignoring checkpoint {checkpoint_path}: it was written for "
                  f"{key}={checkpoint[key]!r}, but this run uses {expected!r}")
            return [], 0

    start_idx = last_index + 1
    print(f"Resuming from checkpoint at index {start_idx}")
    return results, start_idx


def _build_checkpoint(results, last_processed_index, run_info, error=None):
    """Assemble the JSON-serializable checkpoint payload written by ``main``."""
    checkpoint = {
        "results": results,
        "last_processed_index": last_processed_index,
        **run_info,
    }
    if error is not None:
        checkpoint["error"] = error
    checkpoint["timestamp"] = time.strftime("%Y-%m-%d %H:%M:%S")
    return checkpoint


def _report_progress(results, task_type):
    """Print the running accuracy; a failure here never aborts inference."""
    try:
        metrics = calculate_metrics(results, task_type)
    except Exception as e:  # interim metrics are informational only
        print(f"Could not compute interim metrics: {e}")
        return
    if 'accuracy' in metrics:
        print(f"Current accuracy: {metrics['accuracy']:.2%}")
    elif 'accuracy_within_1day' in metrics:
        print(f"Current accuracy (within 1 day): {metrics['accuracy_within_1day']:.2%}")


def _print_metrics(metrics):
    """Pretty-print a metrics dict returned by ``calculate_metrics``."""
    for metric, value in metrics.items():
        if isinstance(value, float):
            if 'accuracy' in metric.lower():
                print(f"{metric}: {value:.2%}")
            else:
                print(f"{metric}: {value:.4f}")
        elif metric == 'classification_report':
            print(f"\nClassification Report:\n{value}")


def main(model_name="Qwen/Qwen3-0.6B", json_file_path=None, output_path=None,
         task_type="readmission", checkpoint_path="./outputs/clinical_checkpoint.json", checkpoint_frequency=10):
    """
    Process clinical dataset and make predictions

    Args:
        model_name: Name of the QWEN model to use
        json_file_path: Path to the JSON file containing SFT format data
        output_path: Path to save output results
            (default: ./outputs/{input_filename}_{model_name_short}.json)
        task_type: Type of task ('readmission', 'mortality', or 'period')
        checkpoint_path: Path to save checkpoints
        checkpoint_frequency: Save a checkpoint every N processed samples (>= 1)

    Returns:
        dict: The final output, which is also written to ``output_path``.

    Raises:
        ValueError: if ``json_file_path`` is missing, ``task_type`` is unknown,
            or ``checkpoint_frequency`` < 1.
    """
    # Validate cheap arguments before spending time loading the model.
    if json_file_path is None:
        raise ValueError("json_file_path must be provided")
    predictors = {
        'readmission': predict_readmission,
        'mortality': predict_mortality,
        'period': predict_period,
    }
    if task_type not in predictors:
        raise ValueError(f"Unknown task type: {task_type}")
    if checkpoint_frequency < 1:
        raise ValueError(f"checkpoint_frequency must be >= 1, got {checkpoint_frequency}")
    predict_fn = predictors[task_type]

    print(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto"
    ).eval()
    print("Model loaded successfully.")

    with open(json_file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print(f"Dataset loaded: {len(data)} samples")
    print(f"Task type: {task_type}")

    # Derive the output filename from the input filename if not provided.
    if output_path is None:
        input_filename = os.path.splitext(os.path.basename(json_file_path))[0]
        model_name_short = model_name.split('/')[-1]
        output_path = f"./outputs/{input_filename}_{model_name_short}.json"
        print(f"Output path set to: {output_path}")

    # The checkpoint may live in a different directory than the output.
    _ensure_parent_dir(output_path)
    _ensure_parent_dir(checkpoint_path)

    run_info = {
        "task_type": task_type,
        "json_file_path": os.path.abspath(json_file_path),
        "model_name": model_name,
    }
    results, start_idx = _load_checkpoint(checkpoint_path, run_info)
    # Clamp in case the checkpoint index lies outside the current dataset.
    start_idx = min(max(start_idx, 0), len(data))

    with tqdm(total=len(data) - start_idx, desc=f"Processing {task_type} samples") as progress_bar:
        for i in range(start_idx, len(data)):
            try:
                result = predict_fn(data[i], model, tokenizer)
            except Exception as e:
                # The failed sample is skipped. The checkpoint points at i - 1,
                # so a run interrupted before the next save retries sample i.
                print(f"\nError processing sample {i}: {str(e)}")
                save_checkpoint(
                    _build_checkpoint(results, i - 1, run_info, error=str(e)),
                    checkpoint_path,
                )
                print("Checkpoint saved due to error")
            else:
                # Only now is sample i in `results`, so index i is consistent.
                results.append(result)
                if (i + 1) % checkpoint_frequency == 0 or i == len(data) - 1:
                    save_checkpoint(_build_checkpoint(results, i, run_info), checkpoint_path)
                    print(f"\nCheckpoint saved at index {i}")
                    _report_progress(results, task_type)

            # Release cached GPU memory after every sample, including failed ones.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            progress_bar.update(1)

    metrics = calculate_metrics(results, task_type)

    final_output = {
        "task_type": task_type,
        "model_name": model_name,
        "json_file_path": json_file_path,
        "summary": {
            "total_samples": len(results),
            **metrics
        },
        "results": results
    }

    save_checkpoint(final_output, output_path)

    print(f"\nProcessing completed!")
    print(f"Task: {task_type}")
    print(f"Total samples: {len(results)}")
    _print_metrics(metrics)
    print(f"Results saved to {output_path}")

    # The run finished and results are saved, so the checkpoint is obsolete.
    if os.path.exists(checkpoint_path):
        try:
            os.remove(checkpoint_path)
            print(f"Removed checkpoint file: {checkpoint_path}")
        except OSError as e:
            print(f"Could not remove checkpoint file {checkpoint_path}: {e}")

    return final_output

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Run clinical inference with QWEN model')

    parser.add_argument('--model_name', type=str, default='Qwen/Qwen3-0.6B',
                        help='Name of the QWEN model to use')
    parser.add_argument('--json_file_path', type=str, required=True,
                        help='Path to the JSON file containing SFT format data')
    parser.add_argument('--output_path', type=str, default=None,
                        help='Path to save output results (default: ./outputs/{input_filename}_{modelname}.json)')
    parser.add_argument('--task_type', type=str, choices=['readmission', 'mortality', 'period'], required=True,
                        help='Type of clinical task')
    parser.add_argument('--checkpoint_path', type=str, default='./outputs/clinical_checkpoint.json',
                        help='Path to save checkpoints')
    parser.add_argument('--checkpoint_frequency', type=int, default=10,
                        help='How often to save checkpoints')

    args = parser.parse_args()

    # Create the output directory up front. os.path.dirname() is "" for a bare
    # filename such as "results.json", and os.makedirs("") raises, so guard it.
    if args.output_path is not None:
        output_dir = os.path.dirname(args.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    main(
        model_name=args.model_name,
        json_file_path=args.json_file_path,
        output_path=args.output_path,
        task_type=args.task_type,
        checkpoint_path=args.checkpoint_path,
        checkpoint_frequency=args.checkpoint_frequency
    )
