"""
Model-Fit metric implementation for evaluating discovered objectives.

This module implements the Model-Fit metric from Methods.md:
Model-Fit(R̂, t) = E[R*(x,y) | π_θ̂_t|R̂] / E[R*(x,y) | π_θ_T|R*]

The metric compares the expected ground-truth rewards between:
1. A policy model at time t trained with estimated rewards (π_θ̂_t|R̂)
2. The final policy model at time T trained with ground-truth rewards (π_θ_T|R*)
"""

import torch
from typing import List, Optional, Tuple
import numpy as np
from tqdm import tqdm
import gc
import re
import os
import json
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

try:
    from .model_generation import generate_responses_batched
    # from .setup_datasets import load_evaluation_dataset, convert_sample_to_prompt
    from .setup_datasets_clean import load_evaluation_dataset, convert_sample_to_prompt
except ImportError:
    from model_generation import generate_responses_batched
    # from setup_datasets import load_evaluation_dataset, convert_sample_to_prompt
    from setup_datasets_clean import load_evaluation_dataset, convert_sample_to_prompt

def supports_flash_attention(device_id):
    """
    Report whether a CUDA device passes this module's FlashAttention check.

    Args:
        device_id: CUDA device index handed to ``torch.cuda.get_device_capability``.

    Returns:
        True when CUDA is available and the device's compute capability is
        SM 8.x (any minor revision) or exactly SM 9.0; False otherwise,
        including when CUDA is not available at all.
    """
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability(device_id)
        # Accepted architectures: every SM 8.x part, plus SM 9.0 specifically.
        return major == 8 or (major, minor) == (9, 0)
    return False

# Import-time diagnostic: prints whether CUDA device 0 passes the
# supports_flash_attention() check. Runs once, whenever this module is imported.
print('Flash Attention Supported:', supports_flash_attention(0))

class ModelRewardEvaluator:
    """
    Evaluates the expected reward of a single model on an evaluation dataset.
    Memory-efficient implementation that loads and evaluates one model at a time.
    """
    
    def __init__(
        self,
        reward_function,
        dataset_name: str = 'Anthropic/hh-rlhf',
        dataset_split: str = 'test',
        num_eval_samples: int = 100,
        batch_size: int = 8,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        load_in_4bit: bool = True,
        torch_dtype: torch.dtype = torch.bfloat16,
        seed: Optional[int] = None,
        multi_turn: bool = False,
        max_prompt_length: int = None,
        base_model_name: str = None,
    ):
        """
        Record the evaluation settings and load the evaluation dataset.

        Args:
            reward_function: RewardFunction instance for computing R*(x,y);
                its ``compute_reward(prompts, responses)`` is called during scoring
            dataset_name: Name of dataset ('Anthropic/hh-rlhf' or 'openai/summarize_from_feedback')
            dataset_split: Dataset split passed to ``load_evaluation_dataset``
            num_eval_samples: Number of evaluation samples to use; capped at the
                size of the loaded dataset
            batch_size: Batch size for generation and reward computation
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature for generation
            top_p: Top-p value for nucleus sampling
            device: Device to run computations on (also used as the model's ``device_map``)
            load_in_4bit: Whether to use 4-bit quantization for memory efficiency
            torch_dtype: Data type for model weights
            seed: Random seed for sampling prompts (None for an unseeded generator)
            multi_turn: Flag forwarded to dataset loading and to prompt conversion
            max_prompt_length: Optional max token length for filtering prompts
            base_model_name: Model name for tokenizer (required if max_prompt_length is set)

        NOTE: Unless pre-sampled prompts are supplied, a different set of
        eval prompts is drawn each time a model is evaluated.
        """
        # What is scored, and on which data.
        self.reward_function = reward_function
        self.dataset_name = dataset_name
        self.dataset_split = dataset_split
        self.multi_turn = multi_turn

        # Generation settings.
        self.batch_size = batch_size
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p

        # Model-loading settings.
        self.device = device
        self.load_in_4bit = load_in_4bit
        self.torch_dtype = torch_dtype

        # One generator per evaluator, so successive prompt samples differ but
        # are reproducible for a fixed seed. RandomState(None) behaves like a
        # bare RandomState(), which covers the unseeded case as well.
        self.seed = seed
        self.rng = np.random.RandomState(seed)

        # Filled in lazily while responses are generated.
        self.prompt_strings = None

        self.eval_dataset = load_evaluation_dataset(
            dataset_name,
            dataset_split,
            max_prompt_length=max_prompt_length,
            base_model_name=base_model_name,
            multi_turn=multi_turn,
        )
        dataset_size = len(self.eval_dataset)
        # Never ask for more samples than the dataset holds.
        self.num_eval_samples = min(num_eval_samples, dataset_size)
    
    def sample_eval_prompts(self) -> List:
        """
        Draw a set of evaluation prompts from the dataset.

        Public entry point for pre-sampling prompts once so the same set can
        be reused across several models.

        Returns:
            List of message lists for chat template formatting
        """
        sampled_prompts = self._get_eval_prompts()
        return sampled_prompts

    def _clear_memory(self):
        """Run the garbage collector, then empty the CUDA cache if CUDA is available."""
        gc.collect()
        if not torch.cuda.is_available():
            return
        torch.cuda.empty_cache()
    
    def _get_eval_prompts(self) -> List:
        """
        Extract evaluation prompts from the dataset based on dataset type.
        Each call returns a different random sample of prompts.
        
        Returns:
            List of message lists for chat template formatting
        """
        # Use the instance's random number generator for different but reproducible samples
        indices = self.rng.choice(
            len(self.eval_dataset), 
            size=self.num_eval_samples, 
            replace=False
        )
        
        prompts = []
        for idx in indices:
            sample = self.eval_dataset[int(idx)]
            converted_prompt = convert_sample_to_prompt(sample, self.dataset_name, self.multi_turn)
            prompts.append(converted_prompt)
        return prompts
            
            # if 'openai/summarize_from_feedback' in self.dataset_name:
            #     # Extract Reddit post information
            #     info = sample.get('info', {})
            #     subreddit = info.get('subreddit', '')
            #     title = info.get('title', '')
            #     post = info.get('post', '')
            #     site = info.get('site', '')
                
            #     # Format as TLDR prompt
            #     query_text = f"SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:"

            #     if (subreddit is None) or (site in ['dailymail', 'cnn']):
            #         site = info.get('site', '')
            #         article = info.get('article', '')
            #         query_text = f"SUBREDDIT: r/{site}\n\nTITLE: {title}\n\nPOST: {article}\n\nTL;DR:"

            #     messages = [{"role": "user", "content": query_text}]
            #     prompts.append(messages)
                
            # elif 'Anthropic/hh-rlhf' in self.dataset_name:
            #     # Parse HH-RLHF format - SINGLE-TURN VERSION
            #     # This matches the parsing logic from sft_preprocess_function in setup_datasets.py
            #     # when multi_turn=False (single-turn training)
                
            #     chosen = sample.get('chosen', '')
                
            #     # Parse the full conversation using the EXACT same logic as sft_preprocess_function
            #     messages = []
            #     parts = chosen.split('\n\n')  # Split by double newlines, same as SFT
                
            #     current_role = None
            #     current_content = []
                
            #     for part in parts:
            #         if part.startswith('Human:'):  # Check prefix, same as SFT
            #             # Save previous message if exists
            #             if current_role and current_content:
            #                 messages.append({
            #                     "role": "user" if current_role == "Human" else "assistant",
            #                     "content": ' '.join(current_content).strip()  # Join with spaces, same as SFT
            #                 })
            #             current_role = "Human"
            #             current_content = [part[6:].strip()]  # Remove "Human:" prefix (6 chars), same as SFT
                        
            #         elif part.startswith('Assistant:'):  # Check prefix, same as SFT
            #             # Save previous message if exists
            #             if current_role and current_content:
            #                 messages.append({
            #                     "role": "user" if current_role == "Human" else "assistant",
            #                     "content": ' '.join(current_content).strip()  # Join with spaces, same as SFT
            #                 })
            #             current_role = "Assistant"
            #             current_content = [part[10:].strip()]  # Remove "Assistant:" prefix (10 chars), same as SFT
                        
            #         elif current_content:
            #             # Continuation of previous message, same as SFT
            #             current_content.append(part.strip())
                
            #     # Save the last message if it exists, same as SFT
            #     if current_role and current_content:
            #         messages.append({
            #             "role": "user" if current_role == "Human" else "assistant",
            #             "content": ' '.join(current_content).strip()  # Join with spaces, same as SFT
            #         })
                
            #     # SINGLE-TURN LOGIC: Extract only the first user message as the prompt
            #     # This matches lines 107-133 in sft_preprocess_function when multi_turn=False
            #     prompt_messages = []
                
            #     # Find first user message (matches SFT lines 112-116)
            #     for msg in messages:
            #         if msg["role"] == "user" and not prompt_messages:
            #             prompt_messages.append(msg)
            #             break
                
            #     # Fallback if structure is unexpected (matches SFT lines 127-129)
            #     if not prompt_messages and messages:
            #         if messages[0]["role"] == "user":
            #             prompt_messages = [messages[0]]
                
            #     # Only add if we found a valid user message
            #     if prompt_messages:
            #         prompts.append(prompt_messages)
            # else:
            #     # Fallback for other datasets
            #     if isinstance(sample, dict):
            #         if 'prompt' in sample:
            #             messages = [{"role": "user", "content": sample['prompt']}]
            #         elif 'input' in sample:
            #             messages = [{"role": "user", "content": sample['input']}]
            #         elif 'text' in sample:
            #             messages = [{"role": "user", "content": sample['text']}]
            #         else:
            #             messages = [{"role": "user", "content": str(sample)}]
            #     else:
            #         messages = [{"role": "user", "content": str(sample)}]
            #     prompts.append(messages)
    
    def _attn_implementation(self) -> str:
        """Pick 'flash_attention_2' or 'sdpa' for the device the model is placed on."""
        try:
            target = torch.device(self.device)
        except (RuntimeError, TypeError, ValueError):
            # device_map values such as "auto" or a dict are not torch devices;
            # keep the historical behaviour of probing GPU 0.
            return 'flash_attention_2' if supports_flash_attention(0) else 'sdpa'

        if target.type != 'cuda':
            # The capability probe only describes CUDA devices, so a model
            # placed elsewhere must not inherit GPU 0's answer.
            return 'sdpa'

        # A bare "cuda" carries no index; fall back to device 0 as before.
        gpu_index = target.index if target.index is not None else 0
        return 'flash_attention_2' if supports_flash_attention(gpu_index) else 'sdpa'

    def _load_causal_lm(self, name_or_path: str):
        """Load a causal LM with this evaluator's device, dtype and quantization settings."""
        load_kwargs = dict(
            device_map=self.device,
            torch_dtype=self.torch_dtype,
            trust_remote_code=True,
            use_safetensors=True,
            use_cache=False,
            attn_implementation=self._attn_implementation(),
        )
        if self.load_in_4bit:
            # NF4 4-bit weights with double quantization for memory efficiency.
            load_kwargs['quantization_config'] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=self.torch_dtype,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4',
            )
        return AutoModelForCausalLM.from_pretrained(name_or_path, **load_kwargs)

    def _load_model(self, model_path: str) -> Tuple:
        """
        Load a model and its tokenizer with memory-efficient settings.

        A directory containing ``adapter_config.json`` is treated as a PEFT
        adapter: the base model named in that file is loaded first and the
        adapter is applied on top. Anything else is loaded as a full model.

        In the adapter case the tokenizer comes from the adapter directory
        when it ships a ``tokenizer_config.json``, otherwise from the base
        model, and a missing pad token is added as a new "[PAD]" token. In the
        full-model case a missing pad token reuses the EOS token. In both
        cases the embeddings are resized when the tokenizer length differs
        from ``config.vocab_size``, and a tokenizer without a chat template
        receives TRL's ``SIMPLE_QUERY_CHAT_TEMPLATE``.

        Args:
            model_path: Path to the model (either adapter or full model)

        Returns:
            Tuple of (model, tokenizer)

        Raises:
            ValueError: If ``adapter_config.json`` exists but does not name a
                base model under ``base_model_name_or_path``.
        """
        adapter_config_path = os.path.join(model_path, 'adapter_config.json')

        if os.path.exists(adapter_config_path):
            print(f"Loading adapter model from: {model_path}")

            # The adapter config records which base model it was trained on.
            with open(adapter_config_path, 'r', encoding='utf-8') as f:
                base_model_name = json.load(f).get('base_model_name_or_path')
            if not base_model_name:
                # Fail here with a clear message instead of handing None to
                # from_pretrained further down.
                raise ValueError(
                    f"{adapter_config_path} does not define 'base_model_name_or_path'; "
                    "cannot locate the base model for this adapter"
                )

            # Prefer tokenizer files saved next to the adapter.
            has_own_tokenizer = os.path.exists(
                os.path.join(model_path, 'tokenizer_config.json')
            )
            tokenizer = AutoTokenizer.from_pretrained(
                model_path if has_own_tokenizer else base_model_name,
                trust_remote_code=True,
            )

            # Ensure tokenizer has pad token
            if tokenizer.pad_token is None:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

            base_model = self._load_causal_lm(base_model_name)

            # Resize before applying the adapter so embedding shapes line up
            # with the tokenizer.
            if len(tokenizer) != base_model.config.vocab_size:
                print(f"Resizing model embeddings from {base_model.config.vocab_size} to {len(tokenizer)}")
                base_model.resize_token_embeddings(len(tokenizer))

            model = PeftModel.from_pretrained(base_model, model_path)

        else:
            print(f"Loading full model from: {model_path}")

            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            model = self._load_causal_lm(model_path)

            # Ensure tokenizer has pad token
            if tokenizer.pad_token is None:
                print("Adding pad token to tokenizer...")
                tokenizer.pad_token = tokenizer.eos_token

            # Resize embeddings if needed
            if len(tokenizer) != model.config.vocab_size:
                model.resize_token_embeddings(len(tokenizer))

        # Set chat template if not present
        if tokenizer.chat_template is None:
            from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
            tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
            print("WARNING: Set chat template to SIMPLE_QUERY_CHAT_TEMPLATE")

        return model, tokenizer
    
    def _generate_responses(
        self,
        model_path,
        prompts: List
    ) -> List[str]:
        """
        Load the model at ``model_path`` and generate one response per prompt.

        Side effect: when ``self.prompt_strings`` is ``None`` it is filled
        with the chat-templated text of every prompt (rendered without a
        generation prompt); ``_compute_rewards`` reads it afterwards. It is
        assigned only once all prompts have been rendered.

        The model reference is dropped and memory caches are cleared before
        this method exits, including when templating or generation raises.

        Args:
            model_path: Path to the model to load (adapter or full model)
            prompts: List of prompts; each is a message list for the chat
                template, or a bare value that is wrapped as a single user message

        Returns:
            List of generated response strings, as returned by
            ``generate_responses_batched``
        """
        model, tokenizer = self._load_model(model_path)

        try:
            # Prepare prompt strings for reward computation if not done yet
            if self.prompt_strings is None:
                self.prompt_strings = [
                    tokenizer.apply_chat_template(
                        # Non-list prompts are treated as the content of one user turn.
                        messages if isinstance(messages, list)
                        else [{"role": "user", "content": messages}],
                        tokenize=False,
                        add_generation_prompt=False,
                    )
                    for messages in prompts
                ]

            # Generation itself is delegated to model_generation.py
            responses = generate_responses_batched(
                model_path=model_path,
                model=model,
                tokenizer=tokenizer,
                prompts=prompts,
                max_new_tokens=self.max_new_tokens,
                batch_size=self.batch_size,
                temperature=self.temperature,
                top_p=self.top_p,
                device=self.device
            )
        finally:
            # Release the model even on failure so the next load starts clean.
            del model
            self._clear_memory()

        return responses
    
    def _compute_rewards(
        self,
        responses: List[str]
    ) -> np.ndarray:
        """
        Compute ground-truth rewards for responses.

        Responses are paired positionally with ``self.prompt_strings`` and
        scored in batches of ``self.batch_size`` through
        ``self.reward_function.compute_reward``. Only the first element of the
        returned triple (the normalized rewards) is kept.

        Args:
            responses: List of generated responses, in the same order as
                ``self.prompt_strings``

        Returns:
            Array of reward values, one entry per response

        Raises:
            RuntimeError: If ``self.prompt_strings`` has not been prepared yet.
            ValueError: If the number of responses differs from the number of
                prepared prompt strings, since positional pairing would then
                score mismatched (prompt, response) pairs.
        """
        if self.prompt_strings is None:
            raise RuntimeError(
                "prompt_strings is not set; generate responses before computing rewards"
            )
        if len(self.prompt_strings) != len(responses):
            raise ValueError(
                f"Got {len(responses)} responses for {len(self.prompt_strings)} prompts; "
                "cannot pair them for reward computation"
            )

        rewards = []

        # Process in batches
        for start in range(0, len(responses), self.batch_size):
            stop = start + self.batch_size

            normed_rewards, _, _ = self.reward_function.compute_reward(
                self.prompt_strings[start:stop],
                responses[start:stop],
            )

            if isinstance(normed_rewards, torch.Tensor):
                # detach() first: numpy() refuses tensors that require grad.
                # float() makes dtypes such as bfloat16 convertible to numpy.
                normed_rewards = normed_rewards.detach().cpu().float().numpy().tolist()

            rewards.extend(normed_rewards)

        return np.array(rewards)
    
    def evaluate_model(self, model_path: str, eval_prompts: Optional[List] = None) -> float:
        """
        Score one model: generate responses, compute rewards, return the mean.

        Args:
            model_path: Path to the model to evaluate
            eval_prompts: Optional pre-sampled prompts to use. If None, a fresh
                random set is sampled for this evaluation.

        Returns:
            Expected reward (average over evaluation samples)
        """
        print(f"\nEvaluating model: {model_path}")

        if eval_prompts is not None:
            print(f"Using pre-sampled {len(eval_prompts)} prompts")
        else:
            # No prompts supplied: draw a fresh random set for this model.
            eval_prompts = self._get_eval_prompts()

        # Discard prompt strings cached by a previous evaluation so they are
        # rebuilt for this model's tokenizer and this prompt set.
        self.prompt_strings = None

        # The model is loaded and released inside _generate_responses.
        print("Generating responses...")
        generated = self._generate_responses(model_path, eval_prompts)

        print("Computing rewards...")
        reward_values = self._compute_rewards(generated)

        mean_reward = float(np.mean(reward_values))
        print(f"Average reward: {mean_reward:.4f}")

        return mean_reward