"""
Utilities for creating and manipulating navigation datasets.
"""

import gzip
import json
import logging
import os
import re
from typing import Dict, List, Optional, Tuple, Union

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Module logger; on import, the root logger is also configured (via basicConfig)
# to emit INFO-level records and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level="INFO")

class MixtralInstructionRewriter:
    """Rewrite navigation instructions into different styles with a Mixtral chat model.

    The tokenizer and a 4-bit quantized model (``device_map="auto"``) are loaded
    once in ``__init__`` and reused by every ``rewrite`` call.
    """

    def __init__(self, model_name: str = "mistralai/Mixtral-8x7B-Instruct-v0.1"):
        """
        Initialize the instruction rewriter with Mixtral model.

        Args:
            model_name (str): Name of the model to use for rewriting instructions

        Raises:
            Exception: Any error raised while loading the tokenizer or model is
                logged and re-raised unchanged.
        """
        logger.info(f"Loading Mixtral model {model_name}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",  # Let it automatically choose the best device
                # 4-bit quantization to reduce memory usage. Passing load_in_4bit=True
                # straight to from_pretrained is deprecated in transformers; an explicit
                # BitsAndBytesConfig is the supported, equivalent form.
                quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            )
            logger.info("Mixtral model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load Mixtral model: {str(e)}")
            raise

    def _apply_chat_template(self, messages: List[Dict[str, str]]) -> str:
        """Render ``messages`` into prompt text using the tokenizer's chat template.

        Some chat templates refuse a leading ``system`` message
        (NOTE(review): older Mistral/Mixtral-Instruct template revisions are
        believed to do this — confirm for the tokenizer revision in use). When
        rendering fails and the first message is a system message, the system
        prompt is folded into the first user turn and rendering is retried once.
        Any other failure is re-raised.
        """
        try:
            return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        except Exception as err:  # jinja2's TemplateError is not importable without a new dependency
            if len(messages) < 2 or messages[0].get("role") != "system":
                raise
            logger.warning(f"Chat template rejected the system role ({err}); merging it into the user turn")
            merged_first_turn = {
                "role": messages[1]["role"],
                "content": messages[0]["content"] + "\n\n" + messages[1]["content"],
            }
            return self.tokenizer.apply_chat_template(
                [merged_first_turn] + list(messages[2:]), tokenize=False, add_generation_prompt=True
            )

    def rewrite(self, instruction: str, style: str) -> str:
        """
        Rewrite an instruction into a specified style using Mixtral.

        Args:
            instruction (str): The original instruction to rewrite
            style (str): The style to rewrite the instruction in; one of
                "novice", "expert", "formal", "friendly", "cot" (case-insensitive)

        Returns:
            str: The rewritten instruction. The original ``instruction`` is
            returned unchanged if generation fails or nothing usable can be
            extracted from the model output.

        Raises:
            ValueError: If the style is unsupported or the instruction is not a
                non-empty string.
        """
        # Validate style
        valid_styles = ["novice", "expert", "formal", "friendly", "cot"]
        if style.lower() not in valid_styles:
            raise ValueError(f"Invalid style. Must be one of: {valid_styles}")

        # Validate instruction
        if not instruction or not isinstance(instruction, str):
            raise ValueError("Instruction must be a non-empty string")

        # Format message in Mixtral's preferred format
        messages = [
            {"role": "system", "content": """You are an expert at rewriting navigation instructions in different styles while preserving exact meaning.
Output ONLY the rewritten instruction - no explanations, greetings, or commentary.
Your output MUST be wrapped in <instruction> tags.

CRITICAL RULES:
1. NEVER add information not present in the original instruction
2. NEVER remove or alter any spatial information
3. NEVER change distances, directions, or landmark references
4. NEVER add greetings or commentary
5. NEVER use phrases like "Here's" or "I'll"
6. Start DIRECTLY with the instruction itself"""},
            {"role": "user", "content": f"""Rewrite this navigation instruction in {style} style, preserving ALL spatial details exactly.

Here are the styles and use it to rewrite the instruction and pick one based {style} match:
1. Novice: Explain the navigation instruction in a way that is easy to understand and explaining to a kid.
2. Expert: Explain the navigation instruction in a way that is detailed and technical and explaining to an adult.
3. Formal: Explain the navigation instruction in a way that is formal and professional and explaining to a corporate person.
4. Friendly: Explain the navigation instruction in a way that is friendly and engaging and explaining to a friend casually.
5. CoT (Chain of Thoughts): Break down the navigation instruction into clear, numbered steps (Step 1:, Step 2:, etc.). Each step should represent ONE distinct action or observation. Make sure to preserve ALL spatial details and landmarks exactly.

The navigation instruction is: {instruction}

YOUR RESPONSE MUST BE IN THIS EXACT FORMAT:
<instruction>instruction text here</instruction>"""}
        ]

        try:
            # Apply chat template
            chat_text = self._apply_chat_template(messages)

            # The rendered template already begins with the BOS token; letting the
            # tokenizer add special tokens again would produce a duplicated BOS.
            inputs = self.tokenizer(chat_text, return_tensors="pt", add_special_tokens=False).to(self.model.device)

            # Generate response
            with torch.inference_mode():
                outputs = self.model.generate(
                    inputs.input_ids,
                    # Explicit mask: generate() cannot reliably infer one when the
                    # pad token falls back to the EOS token.
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=200,
                    temperature=0.1,
                    top_p=1.0,
                    repetition_penalty=1.1,
                    pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
                )

            # Extract only the generated response, not the prompt
            generated_text = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)

            # Log the raw response for debugging
            logger.debug(f"Raw Mixtral response (first 100 chars): {generated_text[:100]}...")

            # Extract and clean the instruction
            instruction_text = self.extract_instruction(generated_text)

            if not instruction_text:
                logger.warning(f"Failed to extract instruction for style '{style}'")
                # Try an alternative extraction approach for Mixtral if the regular extraction failed
                if "<instruction>" not in generated_text.lower() and "</instruction>" not in generated_text.lower():
                    # If Mixtral didn't use the tags, the whole response might be the instruction
                    # Check if it's not too long and doesn't contain obvious chat artifacts
                    if len(generated_text) < 500 and not any(phrase in generated_text.lower() for phrase in [
                        "as an ai", "i'd be happy", "i cannot", "i'm not able", "here's", "is there"
                    ]):
                        instruction_text = generated_text.strip()
                        logger.info(f"Recovered instruction without tags: {instruction_text[:50]}...")

                # If still nothing, return original
                if not instruction_text:
                    logger.warning(f"Still failed to extract instruction, returning original for style '{style}'")
                    return instruction

            return instruction_text

        except Exception as e:
            # Best-effort: any tokenization/generation failure falls back to the input
            logger.error(f"Error during instruction rewriting with Mixtral: {str(e)}")
            return instruction  # Return original instruction as fallback

    def extract_instruction(self, full_response: str) -> str:
        """
        Extract the instruction text from the model's response.

        Strategies are tried in order:
        1. content of ``<instruction>...</instruction>`` tags;
        2. content of loosely spelled ``<instruct...>...</instruct...>`` tags;
        3. a backtick span covering more than 70% of the response;
        4. ``Step N:`` lines (CoT style);
        5. a numbered list covering more than half of the response;
        6. the whole response, if it is under 300 chars and has no ``<``/``>``;
        7. heuristic cleanup: strip known chatty prefixes and surrounding
           quotes, then skip to the first line starting with a navigation verb.

        Args:
            full_response (str): The full response from the model

        Returns:
            str: The extracted instruction text, or the cleaned-up response when
            no structured pattern matched
        """
        # Only extract content between <instruction> tags
        pattern = r"<instruction>(.*?)</instruction>"
        match = re.search(pattern, full_response, re.DOTALL)
        if match:
            return match.group(1).strip()

        # Try more lenient pattern for slightly misspelled tags
        pattern = r"<instruct.*?>(.*?)</instruct.*?>"
        match = re.search(pattern, full_response, re.DOTALL)
        if match:
            return match.group(1).strip()

        # Try to find text between backticks (Mixtral sometimes uses these instead)
        pattern = r"`(.*?)`"
        match = re.search(pattern, full_response, re.DOTALL)
        if match and len(match.group(1).strip()) > 0.7 * len(full_response):
            return match.group(1).strip()

        # Check for CoT style patterns, often include "Step 1:", "Step 2:", etc.
        cot_pattern = r"(Step \d+:.*(?:\n.*)*)"
        if re.search(cot_pattern, full_response, re.IGNORECASE | re.MULTILINE):
            # Extract all steps. NOTE: with re.MULTILINE, `$` matches at every line
            # end, so each match stops at the end of its own line (continuation
            # lines of a multi-line step are not kept).
            steps = re.findall(r"Step \d+:.*(?:\n.*)*?(?=Step \d+:|$)", full_response, re.IGNORECASE | re.MULTILINE)
            if steps:
                # Join all steps with newlines
                return "\n".join(step.strip() for step in steps)

        # Check for numbered list patterns (1., 2., etc.)
        numbered_list_pattern = r"(?:\d+\.\s+.*\n?)+"
        match = re.search(numbered_list_pattern, full_response, re.MULTILINE)
        if match and len(match.group(0)) > 0.5 * len(full_response):
            return match.group(0).strip()

        # If response is short and clean, return it directly
        if len(full_response) < 300 and "<" not in full_response and ">" not in full_response:
            return full_response.strip()

        # Clean up common prefixes and suffixes
        prefixes = [
            "Here's the rewritten instruction:",
            "Here is the rewritten instruction:",
            "Rewritten instruction:",
            "Sure!",
            "Sure! ",
            "Sure, ",
            "Here's ",
            "Here is ",
            "I'll rewrite ",
            "The rewritten ",
            "I've rewritten ",
            "The navigation instruction in",
        ]

        cleaned_response = full_response
        for prefix in prefixes:
            if cleaned_response.lower().startswith(prefix.lower()):
                cleaned_response = cleaned_response[len(prefix):].strip()

        # Remove quotes if present
        if (cleaned_response.startswith('"') and cleaned_response.endswith('"')) or \
           (cleaned_response.startswith("'") and cleaned_response.endswith("'")):
            cleaned_response = cleaned_response[1:-1].strip()

        # Look for a strong pattern of instruction starts in Mixtral outputs
        style_indicators = (
            "walk ", "go ", "exit ", "enter ", "turn ", "move ", "proceed ", "head ", "continue ",
            "leave ", "make ", "pass ", "navigate ", "follow ", "take ", "open ", "step "
        )

        # Cut away any preamble by starting at the FIRST line that begins with a
        # navigation verb, and stop searching there. Offsets are tracked on the
        # original string, so the slice is exact even if lower-casing would change
        # string length or an earlier line happens to contain the same text.
        offset = 0
        for line in cleaned_response.split("\n"):
            content = line.strip()
            if content and content.lower().startswith(style_indicators):
                start = offset + len(line) - len(line.lstrip())
                cleaned_response = cleaned_response[start:].strip()
                break
            offset += len(line) + 1  # +1 for the "\n" removed by split

        return cleaned_response

class InstructionLLAMA2_Rewriter:
    """Rewrite navigation instructions into different styles with a Llama-2 chat model.

    The ``[INST] <<SYS>> ... [/INST]`` prompt is built by hand rather than via a
    chat template. Unlike the other rewriters in this module, the ``cot`` style
    is not supported.
    """

    def __init__(self, model_name: str = "meta-llama/Llama-2-13b-chat-hf"):
        """
        Initialize the instruction rewriter with a specified model.

        Args:
            model_name (str): Name of the model to use for rewriting instructions

        Raises:
            Exception: Any error raised while loading the tokenizer or model is
                logged and re-raised unchanged.
        """
        logger.info(f"Loading model {model_name}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",  # Let it automatically choose the best device
                # 4-bit quantization for memory efficiency; BitsAndBytesConfig replaces
                # the deprecated load_in_4bit=True keyword with identical settings.
                quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            )
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

    def rewrite(self, instruction: str, style: str) -> str:
        """
        Rewrite an instruction into a specified style.

        Args:
            instruction (str): The original instruction to rewrite
            style (str): The style to rewrite the instruction in; one of
                "novice", "expert", "formal", "friendly" (case-insensitive)

        Returns:
            str: The rewritten instruction, or the original ``instruction`` if
            generation fails or nothing could be extracted

        Raises:
            ValueError: If the style is not one of the supported styles, or the
                instruction is not a non-empty string
        """
        # Validate style
        valid_styles = ["novice", "expert", "formal", "friendly"]
        if style.lower() not in valid_styles:
            raise ValueError(f"Invalid style. Must be one of: {valid_styles}")

        # Validate instruction
        if not instruction or not isinstance(instruction, str):
            raise ValueError("Instruction must be a non-empty string")

        chat_prompt = f"""[INST] <<SYS>>
        You are an expert at rewriting navigation instructions in different styles while preserving exact meaning.
        The expert will be given a navigation instruction to VLN model to follow.
        Output ONLY the rewritten instruction - no explanations, greetings, or commentary.
        Your output MUST be wrapped in <instruction> tags.
        
        CRITICAL RULES:
        1. NEVER add information not present in the original instruction
        2. NEVER remove or alter any spatial information
        3. NEVER change distances, directions, or landmark references
        4. NEVER add greetings or commentary
        5. NEVER use phrases like "Here's" or "I'll"
        6. Start DIRECTLY with the instruction itself
        <</SYS>>

        Rewrite this navigation instruction in {style} style, preserving ALL spatial details exactly.

        Here are the styles and use it to rewrite the instruction and pick one based {style} match:
        1. Novice: Explain the navigation instruction in a way that is easy to understand and explaining to a kid.
        2. Expert: Explain the navigation instruction in a way that is detailed and technical and explaining to an adult.
        3. Formal: Explain the navigation instruction in a way that is formal and professional and explaining to a corporate person.
        4. Friendly: Explain the navigation instruction in a way that is friendly and engaging and explaining to a friend casually.
        
        The navigation instruction is {instruction} :
        YOUR RESPONSE MUST BE IN THIS EXACT FORMAT:
        <instruction>instruction text here</instruction>
        [/INST]"""

        try:
            inputs = self.tokenizer(chat_prompt, return_tensors="pt").to(self.model.device)

            with torch.inference_mode():
                outputs = self.model.generate(
                    inputs.input_ids,
                    # Explicit mask: pad_token_id is set to EOS below, so generate()
                    # cannot infer the mask from padding.
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=200,
                    temperature=0.1,
                    top_p=1.0,
                    repetition_penalty=1.1,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens. Slicing the decoded full text by
            # len(chat_prompt) is unreliable: decode() is not guaranteed to reproduce
            # the prompt character-for-character, so the cut could land inside the
            # answer or leave prompt text behind.
            generated_ids = outputs[0][inputs.input_ids.shape[1]:]
            generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
            instruction_text = self.extract_instruction(generated_text.strip())

            if not instruction_text:
                logger.warning(f"Failed to extract instruction for style '{style}'")
                return instruction  # Return original instruction as fallback

            return instruction_text

        except Exception as e:
            logger.error(f"Error during instruction rewriting: {str(e)}")
            return instruction  # Return original instruction as fallback

    def extract_instruction(self, full_response: str) -> str:
        """
        Extract the instruction text from the model's response.

        Tries ``<instruction>`` tags, then loosely spelled ``<instruct...>``
        tags, then returns a short (< 300 chars) tag-free response as-is, and
        otherwise strips known chatty prefixes (case-insensitively, like the
        other rewriters in this module) and surrounding quotes.

        Args:
            full_response (str): The full response from the model

        Returns:
            str: The extracted instruction text
        """
        # Try exact match first
        pattern = r"<instruction>(.*?)</instruction>"
        match = re.search(pattern, full_response, re.DOTALL)
        if match:
            return match.group(1).strip()

        # Try more lenient pattern for slightly misspelled tags
        pattern = r"<instruct.*?>(.*?)</instruct.*?>"
        match = re.search(pattern, full_response, re.DOTALL)
        if match:
            return match.group(1).strip()

        # If response is short and clean, return it directly
        if len(full_response) < 300 and "<" not in full_response and ">" not in full_response:
            return full_response.strip()

        # Clean up common prefixes and suffixes
        prefixes = [
            "Here's the rewritten instruction:",
            "Here is the rewritten instruction:",
            "Rewritten instruction:",
            "Sure!",
            "Sure! ",
            "Sure, ",
            "Here's ",
            "Here is ",
            "I'll rewrite ",
            "The rewritten ",
        ]

        cleaned_response = full_response
        for prefix in prefixes:
            if cleaned_response.lower().startswith(prefix.lower()):
                cleaned_response = cleaned_response[len(prefix):].strip()

        # Remove quotes if present
        if (cleaned_response.startswith('"') and cleaned_response.endswith('"')) or \
           (cleaned_response.startswith("'") and cleaned_response.endswith("'")):
            cleaned_response = cleaned_response[1:-1].strip()

        return cleaned_response

class Llama3InstructionRewriter:
    """Rewrite navigation instructions into different styles with a Llama-3 instruct model.

    The tokenizer and a 4-bit quantized model are loaded once in ``__init__``
    and reused by every ``rewrite`` call.
    """

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-3.1-8B-Instruct",
        device_map: Optional[Union[str, Dict[str, int]]] = None,
        attn_implementation: str = "flash_attention_2",
    ):
        """
        Initialize the instruction rewriter with Llama-3-8B model.

        Args:
            model_name (str): Name of the model to use for rewriting instructions
            device_map: Passed to ``from_pretrained``. ``None`` keeps the
                historical default ``{"": 7}`` (whole model on CUDA device 7);
                pass e.g. ``"auto"`` or ``{"": 0}`` on machines with fewer GPUs.
            attn_implementation (str): Attention backend passed to
                ``from_pretrained``. The default ``"flash_attention_2"``
                presumably requires the flash-attn package — NOTE(review):
                confirm against the installed transformers version.

        Raises:
            Exception: Any error raised while loading the tokenizer or model is
                logged and re-raised unchanged.
        """
        if device_map is None:
            # Historical default: pin the whole model to CUDA device 7
            # (the old inline comment said "GPU 2", which did not match).
            device_map = {"": 7}

        logger.info(f"Loading Llama-3 model {model_name}...")
        try:
            # Release cached CUDA memory before loading the model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16,
                device_map=device_map,
                # Equivalent to the deprecated load_in_4bit=True keyword
                quantization_config=BitsAndBytesConfig(load_in_4bit=True),
                attn_implementation=attn_implementation,
            )
            logger.info("Llama-3 model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load Llama-3 model: {str(e)}")
            raise

    def rewrite(self, instruction: str, style: str) -> str:
        """
        Rewrite an instruction into a specified style using Llama-3.

        Args:
            instruction (str): The original instruction to rewrite
            style (str): The style to rewrite the instruction in; one of
                "novice", "expert", "formal", "friendly", "cot" (case-insensitive)

        Returns:
            str: The rewritten instruction, or the original ``instruction`` if
            generation fails or nothing usable can be extracted

        Raises:
            ValueError: If the style is unsupported or the instruction is not a
                non-empty string.
        """
        # Validate style
        valid_styles = ["novice", "expert", "formal", "friendly", "cot"]
        if style.lower() not in valid_styles:
            raise ValueError(f"Invalid style. Must be one of: {valid_styles}")

        # Validate instruction
        if not instruction or not isinstance(instruction, str):
            raise ValueError("Instruction must be a non-empty string")

        # Format message in Llama-3's preferred format
        messages = [
            {"role": "system", "content": """You are an expert at rewriting navigation instructions in different styles while preserving exact meaning.
Output ONLY the rewritten instruction - no explanations, greetings, or commentary.
Your output MUST be wrapped in <instruction> tags.

CRITICAL RULES:
1. NEVER add information not present in the original instruction
2. NEVER remove or alter any spatial information
3. NEVER change distances, directions, or landmark references
4. NEVER add greetings or commentary
5. NEVER use phrases like "Here's" or "I'll"
6. Start DIRECTLY with the instruction itself"""},
            {"role": "user", "content": f"""Rewrite this navigation instruction in {style} style, preserving ALL spatial details exactly.

Here are the styles and use it to rewrite the instruction and pick one based {style} match:
1. Novice: Explain the navigation instruction in a way that is easy to understand and explaining to a kid.
2. Expert: Explain the navigation instruction in a way that is detailed and technical and explaining to an adult.
3. Formal: Explain the navigation instruction in a way that is formal and professional and explaining to a corporate person.
4. Friendly: Explain the navigation instruction in a way that is friendly and engaging and explaining to a friend casually.
5. CoT (Chain of Thoughts): Break down the navigation instruction into clear, numbered steps (Step 1:, Step 2:, etc.). Each step should represent ONE distinct action or observation. Make sure to preserve ALL spatial details and landmarks exactly.

The navigation instruction is: {instruction}

YOUR RESPONSE MUST BE IN THIS EXACT FORMAT:
<instruction>instruction text here</instruction>"""}
        ]

        try:
            # Apply chat template
            chat_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

            # The rendered template already contains the BOS token; do not let the
            # tokenizer prepend a second one.
            inputs = self.tokenizer(chat_text, return_tensors="pt", add_special_tokens=False).to(self.model.device)

            # Generate response
            with torch.inference_mode():
                outputs = self.model.generate(
                    inputs.input_ids,
                    # Explicit mask: when pad falls back to EOS, generate() cannot
                    # infer the mask from padding.
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=500,
                    temperature=0.1,
                    top_p=0.9,
                    repetition_penalty=1.1,
                    pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
                )

            # Extract only the generated response, not the prompt
            generated_text = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
            # Log the raw response for debugging (replaces a stray print to stdout)
            logger.debug(f"Raw Llama-3 response (first 100 chars): {generated_text[:100]}...")

            # Extract and clean the instruction
            instruction_text = self.extract_instruction(generated_text)

            if not instruction_text:
                logger.warning(f"Failed to extract instruction for style '{style}'")
                # If extraction failed, try a fallback approach
                if "<instruction>" not in generated_text.lower() and "</instruction>" not in generated_text.lower():
                    # If model didn't use the tags, the whole response might be the instruction
                    # Check if it's not too long and doesn't contain obvious chat artifacts
                    if len(generated_text) < 500 and not any(phrase in generated_text.lower() for phrase in [
                        "as an ai", "i'd be happy", "i cannot", "i'm not able", "here's", "is there"
                    ]):
                        instruction_text = generated_text.strip()
                        logger.info(f"Recovered instruction without tags: {instruction_text[:50]}...")

                # If still nothing, return original
                if not instruction_text:
                    logger.warning(f"Still failed to extract instruction, returning original for style '{style}'")
                    return instruction

            return instruction_text

        except Exception as e:
            # Best-effort: any tokenization/generation failure falls back to the input
            logger.error(f"Error during instruction rewriting with Llama-3: {str(e)}")
            return instruction  # Return original instruction as fallback

    def extract_instruction(self, full_response: str) -> str:
        """
        Return the content of the first ``<instruction>...</instruction>`` pair.

        Args:
            full_response (str): The generated text from the model

        Returns:
            str: The stripped tag content, or the whole stripped response when
            no complete tag pair is present
        """
        # Only extract content between <instruction> tags
        pattern = r"<instruction>(.*?)</instruction>"
        match = re.search(pattern, full_response, re.DOTALL)
        if match:
            return match.group(1).strip()

        return full_response.strip()

    def extract_instruction_2(self, full_response: str) -> str:
        """
        Extract the instruction text from the model's response (heuristic variant).

        Strategies are tried in order: ``<instruction>`` tags, loosely spelled
        ``<instruct...>`` tags, a backtick span covering more than 70% of the
        response, ``Step N:`` lines, a numbered list covering more than half of
        the response, the whole response if short (< 300 chars) and tag-free,
        and finally prefix/quote cleanup plus skipping to the first line that
        starts with a navigation verb.

        Args:
            full_response (str): The full response from the model

        Returns:
            str: The extracted instruction text, or the cleaned-up response when
            no structured pattern matched
        """
        # Only extract content between <instruction> tags
        pattern = r"<instruction>(.*?)</instruction>"
        match = re.search(pattern, full_response, re.DOTALL)
        if match:
            return match.group(1).strip()

        # Try more lenient pattern for slightly misspelled tags
        pattern = r"<instruct.*?>(.*?)</instruct.*?>"
        match = re.search(pattern, full_response, re.DOTALL)
        if match:
            return match.group(1).strip()

        # Try to find text between backticks
        pattern = r"`(.*?)`"
        match = re.search(pattern, full_response, re.DOTALL)
        if match and len(match.group(1).strip()) > 0.7 * len(full_response):
            return match.group(1).strip()

        # Check for CoT style patterns, often include "Step 1:", "Step 2:", etc.
        cot_pattern = r"(Step \d+:.*(?:\n.*)*)"
        if re.search(cot_pattern, full_response, re.IGNORECASE | re.MULTILINE):
            # Extract all steps. NOTE: under re.MULTILINE `$` matches at each line
            # end, so every match stops at the end of its own line.
            steps = re.findall(r"Step \d+:.*(?:\n.*)*?(?=Step \d+:|$)", full_response, re.IGNORECASE | re.MULTILINE)
            if steps:
                # Join all steps with newlines
                return "\n".join(step.strip() for step in steps)

        # Check for numbered list patterns (1., 2., etc.)
        numbered_list_pattern = r"(?:\d+\.\s+.*\n?)+"
        match = re.search(numbered_list_pattern, full_response, re.MULTILINE)
        if match and len(match.group(0)) > 0.5 * len(full_response):
            return match.group(0).strip()

        # If response is short and clean, return it directly
        if len(full_response) < 300 and "<" not in full_response and ">" not in full_response:
            return full_response.strip()

        # Clean up common prefixes and suffixes
        prefixes = [
            "Here's the rewritten instruction:",
            "Here is the rewritten instruction:",
            "Rewritten instruction:",
            "Sure!",
            "Sure! ",
            "Sure, ",
            "Here's ",
            "Here is ",
            "I'll rewrite ",
            "The rewritten ",
            "I've rewritten ",
            "The navigation instruction in",
        ]

        cleaned_response = full_response
        for prefix in prefixes:
            if cleaned_response.lower().startswith(prefix.lower()):
                cleaned_response = cleaned_response[len(prefix):].strip()

        # Remove quotes if present
        if (cleaned_response.startswith('"') and cleaned_response.endswith('"')) or \
           (cleaned_response.startswith("'") and cleaned_response.endswith("'")):
            cleaned_response = cleaned_response[1:-1].strip()

        # Look for a strong pattern of instruction starts in outputs
        style_indicators = (
            "walk ", "go ", "exit ", "enter ", "turn ", "move ", "proceed ", "head ", "continue ",
            "leave ", "make ", "pass ", "navigate ", "follow ", "take ", "open ", "step "
        )

        # Start at the FIRST line beginning with a navigation verb and stop there.
        # Previously a match on the very first line was skipped (idx == 0), so a
        # later matching line could truncate the instruction. Offsets are tracked
        # on the original string so the slice position is exact.
        offset = 0
        for line in cleaned_response.split("\n"):
            content = line.strip()
            if content and content.lower().startswith(style_indicators):
                start = offset + len(line) - len(line.lstrip())
                cleaned_response = cleaned_response[start:].strip()
                break
            offset += len(line) + 1  # +1 for the "\n" removed by split

        return cleaned_response

def compare_models(instruction: str, styles: Optional[List[str]] = None):
    """
    Compare the output of different instruction rewriter models on the same input.

    Both rewriters (Llama-3-8B, then Mixtral-8x7B) are constructed up front and
    held in memory together for the whole comparison. Results are printed
    grouped by style and also returned.

    Args:
        instruction (str): The instruction to rewrite
        styles (Optional[List[str]]): The styles to compare. ``None`` (the
            default) means ["novice", "expert", "formal", "friendly", "cot"].

    Returns:
        Dict: ``{model_name: {style: rewritten_instruction}}``
    """
    # None sentinel instead of a list literal default: a default list is created
    # once at definition time and shared across every call.
    if styles is None:
        styles = ["novice", "expert", "formal", "friendly", "cot"]

    results = {}

    # Initialize models
    logger.info("Initializing models for comparison...")
    models = {
        "Llama-3-8B": Llama3InstructionRewriter(),
        "Mixtral-8x7B": MixtralInstructionRewriter()
    }

    # Generate outputs for each model and style
    for model_name, model in models.items():
        results[model_name] = {}
        for style in styles:
            logger.info(f"Generating {style} instruction with {model_name}...")
            results[model_name][style] = model.rewrite(instruction, style)

    # Print comparison
    print(f"\nOriginal: {instruction}")
    for style in styles:
        print(f"\n== {style.upper()} ==")
        for model_name in results:
            print(f"{model_name}: {results[model_name][style]}")

    return results

def extract_style_from_preformatted(text, style):
    """
    Pull the body of the first ``<style>...</style>`` pair out of *text*.

    Args:
        text (str): Text that may contain style-tagged sections
        style (str): Tag name to look for (novice, expert, formal, friendly, ...)

    Returns:
        str: The whitespace-stripped content between the first matching tags
        (may span lines), or "" when no such pair exists
    """
    found = re.search(f"<{style}>(.*?)</{style}>", text, re.DOTALL)
    return found.group(1).strip() if found else ""

def extract_all_styles(text):
    """
    Collect every known style section present in *text*.

    Args:
        text (str): Text that may contain ``<style>...</style>`` sections

    Returns:
        dict: Maps each of novice/expert/formal/friendly/cot (in that order) to
        its extracted text; styles that are missing or empty are omitted
    """
    candidates = (
        (name, extract_style_from_preformatted(text, name))
        for name in ("novice", "expert", "formal", "friendly", "cot")
    )
    return {name: body for name, body in candidates if body}

def is_preformatted(text):
    """
    Decide whether *text* already carries style tags.

    A style counts only when both its opening ``<style>`` and closing
    ``</style>`` tag occur somewhere in *text* (their order is not checked).

    Args:
        text (str): The text to inspect

    Returns:
        bool: True when at least two distinct styles are tagged
    """
    known_styles = ("novice", "expert", "formal", "friendly", "cot")
    tagged = sum(
        1 for name in known_styles
        if f"<{name}>" in text and f"</{name}>" in text
    )
    return tagged >= 2

if __name__ == "__main__":
    # Input is a gzipped JSON object with an "episodes" list; each episode is
    # expected to provide episode["instruction"]["instruction_text"].
    input_path = "./val_unseen_min_100.json.gz"
    output_path = "./val_unseen_varied_min_100_cot.json"
    output_readable_path = "./val_unseen_readable_cot.json"  # Compact side-by-side view of the rewrites

    # Styles generated for every episode
    styles = ["novice", "expert", "formal", "friendly", "cot"]

    # Only the first `max_episodes` episodes are processed (quick test run)
    max_episodes = 10

    try:
        # Read the gzipped JSON file
        with gzip.open(input_path, 'rt', encoding='utf-8') as f:
            data = json.load(f)

        # Initialize model once - using the Llama3 rewriter
        rewriter = Llama3InstructionRewriter()

        readable_data = []  # A more readable format
        limited_episodes = data['episodes'][:max_episodes]

        logger.info("Processing instructions with Llama 3...")
        for i, episode in enumerate(tqdm(limited_episodes)):
            instruction = episode['instruction']['instruction_text']

            # "episode_id" is the position within the processed slice, not an id
            # field read from the episode itself.
            episode_data = {
                "episode_id": i,
                "original": instruction,
                "styles": {}
            }

            # Create variations for each style, stored next to the original text
            # under keys such as "instruction_text_cot".
            for style in styles:
                rewritten = rewriter.rewrite(instruction, style)
                episode['instruction'][f'instruction_text_{style}'] = rewritten
                episode_data["styles"][style] = rewritten

            readable_data.append(episode_data)

        # Save the processed data (only the processed episodes)
        output_data = {
            'episodes': limited_episodes
        }

        # Create output directories. os.path.dirname() returns "" for a bare
        # filename and os.makedirs("") raises, so only create non-empty dirs.
        for path in (output_path, output_readable_path):
            out_dir = os.path.dirname(path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)

        # Save in original format
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)

        # Save in more readable format
        with open(output_readable_path, 'w', encoding='utf-8') as f:
            json.dump(readable_data, f, indent=2, ensure_ascii=False)

        logger.info(f"Successfully processed {len(limited_episodes)} episodes and saved to {output_path}")
        logger.info(f"Also saved readable format to {output_readable_path}")

        # Print example of CoT instruction to verify
        print("\nExample of CoT instruction:")
        if readable_data and "cot" in readable_data[0]["styles"]:
            print(readable_data[0]["styles"]["cot"])

    except Exception as e:
        logger.error(f"Error during processing: {str(e)}")
        raise