#!/usr/bin/env python3
"""
Hypothesis Composition Case Study
Demonstrates the difference between base model and LoRA-finetuned model
on hypothesis composition tasks.
"""

import argparse
import json
import os
import sys
from datetime import datetime
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# PEFT is an optional dependency: when it cannot be imported the script still
# runs, and PEFT_AVAILABLE records that LoRA checkpoints cannot be loaded.
try:
    from peft import PeftModel
    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False
    print("Warning: PEFT not installed. LoRA evaluation will not be available.")

# Make the project-local helper modules importable. parent_dir is the
# directory one level above the folder that contains this file; it and two of
# its Preprocessing sub-folders are put at the front of sys.path.
# These inserts must run before the project imports below.
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, parent_dir)
sys.path.insert(0, os.path.join(parent_dir, 'Preprocessing', 'sft_data_preparation'))
sys.path.insert(0, os.path.join(parent_dir, 'Preprocessing', 'paper_decomposition'))

# Project-local helpers used to pull the hypothesis text out of a raw
# model response (resolved through the sys.path entries added above).
from paper_decomposition_utils import extract_answer_content, extract_between_markers

# Project-local prompt templates, looked up by name via instruction_prompts().
from prompt_store import instruction_prompts


class HypothesisCompositionCaseStudy:
    """Compare a base causal LM with its LoRA-finetuned variant on one example.

    The class loads the base model and, when a LoRA checkpoint is given and
    PEFT is installed, a second copy of the base model with the adapter merged
    in. It then builds a hypothesis-composition prompt, samples one response
    from each model, and reports simple comparison statistics.
    """

    # Keys every case-study input must provide. "previous_hypothesis" is
    # optional and replaced by a placeholder sentence when missing or empty.
    _REQUIRED_INPUT_KEYS = (
        "research_question",
        "background_survey",
        "inspiration_title",
        "inspiration_abstract",
    )

    def __init__(
        self,
        base_model_path: str,
        lora_path: Optional[str] = None,
        device: str = "cuda",
        load_in_8bit: bool = False,
        max_new_tokens: int = 4096,
        # Generation parameters
        temperature: float = 0.6,
        top_p: float = 0.9,
        repetition_penalty: float = 1.2
    ):
        """
        Initialize case study with base model and optional LoRA.

        Args:
            base_model_path: Path to the base model
            lora_path: Optional path to LoRA checkpoint
            device: Preferred device (cuda/cpu). Replaced by "cpu" when CUDA
                is unavailable. Model placement itself is decided by
                ``device_map="auto"``; this value is only the fallback target
                for input tensors when a model exposes no ``device``.
            load_in_8bit: Whether to use 8-bit quantization
            max_new_tokens: Maximum tokens to generate (default 4096)
            temperature: Generation temperature (default 0.6)
            top_p: Top-p sampling parameter (default 0.9)
            repetition_penalty: Penalty for repetition (default 1.2)
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.repetition_penalty = repetition_penalty
        self.base_model_path = base_model_path
        self.lora_path = lora_path

        print(f"Loading base model from {base_model_path}")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_model_path,
            trust_remote_code=True,
            use_fast=False
        )

        # Set padding token if not set
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load base model
        self.base_model = self._load_base_model(base_model_path, load_in_8bit)
        self.base_model.eval()
        print("Base model loaded successfully")

        # Load LoRA model if provided
        self.lora_model = None
        if lora_path and not PEFT_AVAILABLE:
            # Make the skipped comparison explicit instead of silently
            # ignoring the checkpoint the caller asked for.
            print(
                f"Warning: LoRA path '{lora_path}' was provided but PEFT is not "
                "installed; continuing with the base model only."
            )
        elif lora_path:
            print(f"Loading LoRA weights from {lora_path}")

            # A second, independent copy of the base model is loaded so that
            # merging the adapter below cannot alter self.base_model.
            lora_base = self._load_base_model(base_model_path, load_in_8bit)

            # Apply LoRA weights
            peft_model = PeftModel.from_pretrained(
                lora_base,
                lora_path,
                torch_dtype=torch.bfloat16
            )

            # Merge for faster inference
            self.lora_model = peft_model.merge_and_unload()
            self.lora_model.eval()
            print("LoRA model loaded and merged successfully")

    @staticmethod
    def _load_base_model(base_model_path: str, load_in_8bit: bool):
        """Load one instance of the base causal LM (8-bit or bfloat16)."""
        # NOTE(review): newer transformers releases prefer
        # quantization_config=BitsAndBytesConfig(load_in_8bit=True) over the
        # bare load_in_8bit flag - confirm against the pinned version.
        if load_in_8bit:
            precision_kwargs = {"load_in_8bit": True}
        else:
            precision_kwargs = {"torch_dtype": torch.bfloat16}
        return AutoModelForCausalLM.from_pretrained(
            base_model_path,
            device_map="auto",
            trust_remote_code=True,
            **precision_kwargs
        )

    def generate_response(self, model, prompt: str) -> str:
        """
        Generate a response from the specified model.

        The prompt should already be formatted with chat template.

        Args:
            model: The model to use for generation
            prompt: Formatted prompt (with chat template applied)

        Returns:
            The decoded newly generated text only (prompt tokens are
            stripped and special tokens are skipped).
        """
        # A chat-templated prompt usually already begins with the BOS token.
        # Letting the tokenizer add special tokens again would duplicate it,
        # so they are only added when the prompt does not start with BOS.
        bos_token = getattr(self.tokenizer, "bos_token", None)
        prompt_has_bos = bool(bos_token) and prompt.startswith(bos_token)

        # With device_map="auto" the model decides where its weights live, so
        # inputs must follow the model rather than the requested device.
        target_device = getattr(model, "device", None)
        if target_device is None:
            target_device = self.device

        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            max_length=16384,
            truncation=True,
            add_special_tokens=not prompt_has_bos
        ).to(target_device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                temperature=self.temperature,
                do_sample=True,
                top_p=self.top_p,
                repetition_penalty=self.repetition_penalty,
                num_beams=1,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

        # Decode only the tokens produced after the prompt
        prompt_length = inputs['input_ids'].shape[1]
        generated_tokens = outputs[0][prompt_length:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    def _build_prompt(
        self,
        research_question: str,
        background_survey: str,
        previous_hypothesis: str,
        inspiration_title: str,
        inspiration_abstract: str
    ) -> str:
        """Assemble the chat-formatted prompt shared by both models."""
        # v2: delta hypothesis format (Inspiration/Motivation/Mechanism/Methodology).
        # The template is a list of text segments interleaved with the inputs.
        gen_prompts = instruction_prompts("prepare_HC_sft_data_to_go_comprehensive_v2_delta")
        user_content = (
            gen_prompts[0] + research_question +
            gen_prompts[1] + background_survey +
            gen_prompts[2] + previous_hypothesis +
            gen_prompts[3] + inspiration_title +
            gen_prompts[4] + inspiration_abstract +
            gen_prompts[5]
        )

        # R1-Distill Native Format:
        # - No system prompt (DeepSeek R1-Distill was trained without system prompts)
        # - Training data includes <think>\n at start of assistant content
        # - Use add_generation_prompt=False and manually add <｜Assistant｜>
        # - Model generates: <think>\n[reasoning]\n</think>\n\n[hypothesis]
        messages = [
            {"role": "user", "content": user_content},
        ]

        # add_generation_prompt=False avoids adding <think>\n to the prompt
        # (model learned to generate <think>\n as first token)
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
        # Manually add <｜Assistant｜> - model will generate <think>\n itself
        return prompt + "<｜Assistant｜>"

    @staticmethod
    def _extract_hypothesis(raw_response: str) -> str:
        """Pull the hypothesis text out of a raw response; never returns None.

        Tries the v2 delta markers first, then the generic answer extractor,
        and finally falls back to the raw response itself. An empty string is
        returned when nothing at all was generated, so callers can safely
        take ``len()`` of the result.
        """
        print(f"[DEBUG] Raw response length: {len(raw_response)} chars")
        if len(raw_response) < 100:
            print(f"[DEBUG] Full raw response: {repr(raw_response)}")

        # First try v2 delta format markers
        hypothesis = extract_between_markers(raw_response, r'Delta\s*Hypothesis')
        if not hypothesis:
            # Fallback to existing extraction
            print("[DEBUG] Delta extraction failed, falling back to extract_answer_content...")
            hypothesis = extract_answer_content(raw_response)
        # If extraction fails, use the raw response
        if not hypothesis and raw_response:
            print("[DEBUG] extract_answer_content returned empty, using raw response")
            hypothesis = raw_response
        return hypothesis or ""

    @staticmethod
    def _word_overlap_similarity(text_a: str, text_b: str) -> float:
        """Return the Jaccard overlap of the lower-cased word sets (0 if both empty)."""
        words_a = set(text_a.lower().split())
        words_b = set(text_b.lower().split())
        total_unique = len(words_a | words_b)
        return len(words_a & words_b) / total_unique if total_unique > 0 else 0

    def run_case_study(self, case_study_input: Optional[dict] = None) -> dict:
        """
        Run a case study comparing base and LoRA models.

        Args:
            case_study_input: Dict with the keys "research_question",
                "background_survey", "inspiration_title",
                "inspiration_abstract" and optionally "previous_hypothesis".
                The built-in default example is used when None.

        Returns:
            Dict with "input", "base_hypothesis", "lora_hypothesis" and
            "comparison"; the last two are None when no LoRA model is loaded.

        Raises:
            KeyError: If a required input key is missing.
        """
        # Default case study if none provided
        if case_study_input is None:
            case_study_input = self.get_default_case_study()

        # Report every missing key at once instead of failing on the first.
        missing = [key for key in self._REQUIRED_INPUT_KEYS if key not in case_study_input]
        if missing:
            raise KeyError(
                f"case_study_input is missing required key(s): {', '.join(missing)}"
            )

        # Extract input components
        research_question = case_study_input["research_question"]
        background_survey = case_study_input["background_survey"]
        previous_hypothesis = (
            case_study_input.get("previous_hypothesis") or "No previous hypothesis."
        )
        inspiration_title = case_study_input["inspiration_title"]
        inspiration_abstract = case_study_input["inspiration_abstract"]

        # Same prompt for both models for a fair comparison
        prompt = self._build_prompt(
            research_question,
            background_survey,
            previous_hypothesis,
            inspiration_title,
            inspiration_abstract
        )

        print("\n" + "="*80)
        print("HYPOTHESIS COMPOSITION CASE STUDY")
        print("="*80)

        print("\n📚 INPUT TEMPLATE:")
        print("-" * 40)
        print(f"Research Question: {research_question[:200]}...")
        print(f"Background Survey: {background_survey[:200]}...")
        print(f"Previous Hypothesis: {previous_hypothesis[:200]}...")
        print(f"Inspiration Title: {inspiration_title}")
        print(f"Inspiration Abstract: {inspiration_abstract[:200]}...")

        # Generate with base model
        print("\n" + "="*80)
        print("🤖 BASE MODEL RESPONSE:")
        print("-" * 40)
        base_raw_response = self.generate_response(self.base_model, prompt)
        base_hypothesis = self._extract_hypothesis(base_raw_response)
        print("Generated Hypothesis:")
        print(base_hypothesis)

        if self.lora_model is None:
            print("\n⚠️  No LoRA model loaded - only base model results shown")
            return {
                "input": case_study_input,
                "base_hypothesis": base_hypothesis,
                "lora_hypothesis": None,
                "comparison": None
            }

        # Generate with LoRA model
        print("\n" + "="*80)
        print("🚀 LORA MODEL RESPONSE:")
        print("-" * 40)
        lora_raw_response = self.generate_response(self.lora_model, prompt)
        lora_hypothesis = self._extract_hypothesis(lora_raw_response)
        print("Generated Hypothesis:")
        print(lora_hypothesis)

        # Compare outputs
        print("\n" + "="*80)
        print("📊 COMPARISON:")
        print("-" * 40)
        print(f"Base model length: {len(base_hypothesis)} characters")
        print(f"LoRA model length: {len(lora_hypothesis)} characters")

        # Simple similarity check (word overlap)
        similarity = self._word_overlap_similarity(base_hypothesis, lora_hypothesis)
        print(f"Word overlap similarity: {similarity:.2%}")

        return {
            "input": case_study_input,
            "base_hypothesis": base_hypothesis,
            "lora_hypothesis": lora_hypothesis,
            "comparison": {
                "base_length": len(base_hypothesis),
                "lora_length": len(lora_hypothesis),
                "word_similarity": similarity
            }
        }

    def get_default_case_study(self) -> dict:
        """
        Provide a default case study example.
        """
        return {
            "research_question": "How can we develop more efficient methods for training large language models that reduce computational costs while maintaining performance?",

            "background_survey": "Current approaches to training large language models require significant computational resources. Methods like mixed precision training, gradient checkpointing, and model parallelism have been proposed to reduce memory usage. Recent work has explored techniques such as LoRA (Low-Rank Adaptation) and QLoRA for parameter-efficient fine-tuning. However, there remains a need for methods that can further reduce training costs without sacrificing model quality.",

            "previous_hypothesis": "We hypothesize that by combining selective layer freezing with adaptive learning rate scheduling, we can reduce the computational requirements of LLM training by 30% while maintaining 95% of the original performance.",

            "inspiration_title": "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness",

            "inspiration_abstract": "Transformers are slow and memory-hungry on long sequences, since the time and memory complexity of self-attention are quadratic in sequence length. Approximate attention methods have attempted to address this problem by trading off model quality to reduce the compute complexity, but often do not achieve wall-clock speedup. We argue that a missing principle is making attention algorithms IO-aware -- accounting for reads and writes between levels of GPU memory. We propose FlashAttention, an IO-aware exact attention algorithm that uses tiling to reduce the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM. We analyze the IO complexity of FlashAttention, showing that it requires fewer HBM accesses than standard attention, and is optimal for a range of SRAM sizes. We also extend FlashAttention to block-sparse attention, yielding an approximate attention algorithm that is faster than any existing approximate attention method. FlashAttention trains Transformers faster than existing baselines: 15% end-to-end wall-clock speedup on BERT-large (seq. length 512) compared to the MLPerf 1.1 training speed record, 3× speedup on GPT-2 (seq. length 1K), and 2.4× speedup on long-range arena (seq. length 1K-4K). FlashAttention and block-sparse FlashAttention enable longer context in Transformers, yielding higher quality models (0.7 better perplexity on GPT-2 and 6.4 points of lift on long-document classification) and entirely new capabilities: the first Transformers to achieve better-than-chance performance on the Path-X challenge (seq. length 16K, 61.4% accuracy) and Path-256 (seq. length 64K, 63.1% accuracy)."
        }

    def run_custom_case_study(self, input_file: str):
        """
        Run case study with custom input from JSON file.

        Args:
            input_file: Path to a JSON file holding the case-study input dict.

        Returns:
            The result dict produced by run_case_study.
        """
        # Read as UTF-8 explicitly so the result does not depend on the locale.
        with open(input_file, 'r', encoding="utf-8") as f:
            case_study_input = json.load(f)

        return self.run_case_study(case_study_input)

    def save_results(self, results: dict, output_file: str):
        """
        Save case study results to JSON file.

        Note: a "metadata" entry is added to ``results`` in place before
        writing.

        Args:
            results: Result dict as returned by run_case_study.
            output_file: Destination path; missing parent directories are
                created.
        """
        # Record the sampling settings too: generation is stochastic, so the
        # saved output is only interpretable together with them.
        results["metadata"] = {
            "timestamp": datetime.now().isoformat(),
            "base_model": self.base_model_path,
            "lora_model": self.lora_path,
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "repetition_penalty": self.repetition_penalty
        }

        # Avoid losing a finished (and expensive) run to a missing directory.
        output_dir = os.path.dirname(os.path.abspath(output_file))
        os.makedirs(output_dir, exist_ok=True)

        # ensure_ascii=False keeps non-ASCII text readable in the file.
        with open(output_file, 'w', encoding="utf-8") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)

        print(f"\n✅ Results saved to: {output_file}")


def main(argv=None):
    """Command-line entry point: parse arguments, run the case study, save results.

    Args:
        argv: Optional list of argument strings to parse. When None,
            argparse reads the process command line, which is the behaviour
            of the plain ``main()`` call used by the script guard.
    """
    parser = argparse.ArgumentParser(description='Run hypothesis composition case study')

    # Model configuration
    parser.add_argument("--base_model_path", type=str, required=True,
                       help="Path to base model")
    parser.add_argument("--lora_path", type=str, default=None,
                       help="Path to LoRA checkpoint (optional)")

    # Generation settings
    parser.add_argument("--load_in_8bit", action="store_true",
                       help="Load model in 8-bit precision")
    parser.add_argument("--max_new_tokens", type=int, default=4096,
                       help="Maximum new tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.6,
                       help="Generation temperature")
    parser.add_argument("--top_p", type=float, default=0.9,
                       help="Top-p sampling parameter")
    # The same penalty is passed to the case study for every generation,
    # so the help text must not claim it is LoRA-specific.
    parser.add_argument("--repetition_penalty", type=float, default=1.2,
                       help="Repetition penalty applied to both models")

    # Input/Output
    parser.add_argument("--input_file", type=str, default=None,
                       help="JSON file with custom case study input")
    parser.add_argument("--output_file", type=str, default=None,
                       help="JSON file to save results")

    args = parser.parse_args(argv)

    # Initialize case study
    case_study = HypothesisCompositionCaseStudy(
        base_model_path=args.base_model_path,
        lora_path=args.lora_path,
        device="cuda",
        load_in_8bit=args.load_in_8bit,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty
    )

    # Run case study: custom JSON input when given, built-in example otherwise
    if args.input_file:
        results = case_study.run_custom_case_study(args.input_file)
    else:
        results = case_study.run_case_study()

    # Save results if output file specified
    if args.output_file:
        case_study.save_results(results, args.output_file)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
