import json
import os
import random
from datasets import load_dataset, Dataset, DatasetDict, load_from_disk
from transformers import AutoTokenizer
from typing import Union, Dict, Any
import logging
import numpy as np
from collections import defaultdict
import re

logger = logging.getLogger(__name__)


def extract_second_original_implementation(text: str) -> str:
    """
    Return the function source that follows the second 'Original Implementation:' marker.

    Each marker must be followed (after optional whitespace) by a ``def``
    block. A block runs until just before the next line that starts with a
    non-whitespace character, or until the end of ``text``. The selected block
    is returned with surrounding whitespace stripped. An empty string is
    returned when fewer than two such blocks exist.
    """
    implementations = re.findall(
        r'Original Implementation:\s*(def .+?)(?=\n\S|\Z)', text, re.DOTALL
    )
    if len(implementations) < 2:
        return ""
    return implementations[1].strip()

class MutationDataset:
    """Load a mutation dataset and convert it into DPO / PPO / GRPO training formats.

    Rows must carry prompt columns (``mutator_prompt`` and ``solver_prompt``,
    or ``prompt``), a ``response`` column and a JSON-encoded ``solutions``
    column (checked by ``_ensure_score_based_dataset``). The converters also
    read ``mutator_score``, ``solution_scores``, ``task_id`` and, optionally,
    ``consistency_scores``.

    Attributes:
        dataset (DatasetDict): The loaded dataset, always wrapped in a DatasetDict.
        columns (list): Column names of the first split.
        tokenizer (AutoTokenizer): Tokenizer used by ``tokenize``.
        trainset: Reference rows with ``task_id`` and ``canonical_solution``.
            They are used to recover the ground truth when a row has no
            ``mutator_prompt``.
        module_name_lst (list): Split names of ``dataset``.
    """

    def __init__(self, hf_repo_or_local_dir: str, tokenizer: AutoTokenizer, trainset):
        """Load the dataset, normalise it to a DatasetDict and validate its columns.

        Args:
            hf_repo_or_local_dir (str): Local directory readable by ``load_from_disk``,
                or a Hugging Face Hub repository id.
            tokenizer (AutoTokenizer): Tokenizer instance for text processing.
            trainset: Iterable of rows with ``task_id`` and ``canonical_solution``.
                It is only consulted by ``to_dpo_dataset(input_column="solutions")``,
                for rows whose ``mutator_prompt`` is None.

        Raises:
            AssertionError: If the required columns are missing.
        """
        # A path that exists locally wins; otherwise treat it as a Hub repo id.
        if os.path.exists(hf_repo_or_local_dir):
            self.dataset = load_from_disk(hf_repo_or_local_dir)
        else:
            self.dataset = load_dataset(hf_repo_or_local_dir)

        # Normalise to a DatasetDict so every method can iterate over splits.
        if isinstance(self.dataset, DatasetDict):
            self.module_name_lst = list(self.dataset.keys())
        else:
            self.module_name_lst = ['train']
            self.dataset = DatasetDict({'train': self.dataset})

        self.columns = self._dataset_dict_columns()
        self.tokenizer = tokenizer
        self.trainset = trainset
        self._ensure_score_based_dataset()

    def _dataset_dict_columns(self) -> list:
        """Return the feature (column) names of the first split."""
        key = self.module_name_lst[0]
        return list(self.dataset[key].features.keys())

    def _ensure_score_based_dataset(self) -> None:
        """Validate that the dataset has the columns the converters rely on.

        Raises:
            AssertionError: If neither ('mutator_prompt' and 'solver_prompt') nor
                'prompt' is present, or if 'response' or 'solutions' is missing.
        """
        has_prompt_columns = (
            ('mutator_prompt' in self.columns and 'solver_prompt' in self.columns)
            or 'prompt' in self.columns
        )
        assert has_prompt_columns, "The dataset must have both 'mutator_prompt' and 'solver_prompt' columns or a 'prompt' column."
        assert 'response' in self.columns and 'solutions' in self.columns, "The dataset must have 'response' and 'solutions' columns."

    @staticmethod
    def _prompt_key(prompt: Any) -> Any:
        """Return a hashable key that identifies ``prompt``.

        Hashable prompts (e.g. plain strings) are returned unchanged.
        Chat-style prompts (lists of message dicts) cannot be used as set or
        dict keys, so they are serialised to canonical JSON. The result is
        tagged so it can never collide with a plain-string prompt.
        """
        try:
            hash(prompt)
        except TypeError:
            return ("json", json.dumps(prompt, sort_keys=True, ensure_ascii=False))
        return prompt

    def _canonical_solutions_by_task(self) -> Dict[Any, Any]:
        """Map ``task_id`` -> ``canonical_solution`` from ``self.trainset``.

        For duplicated task ids the first entry wins. This matches a linear
        scan that stops at the first match.
        """
        lookup = {}
        for entry in self.trainset:
            lookup.setdefault(entry["task_id"], entry["canonical_solution"])
        return lookup

    def _collect_scored_responses(self, dataset: Dataset, input_column: str,
                                  use_consistency: bool) -> list:
        """Flatten ``dataset`` into ``(prompt, response, score, task_id)`` tuples.

        With ``input_column == "response"``, each row yields its ``response``,
        scored by ``mutator_score``.

        With ``"solutions"``, each row yields every entry of its JSON
        ``solutions`` list, scored by ``solution_scores``. If a non-empty
        ground-truth implementation is available and not already listed, it is
        prepended with score 1.0.

        When ``use_consistency`` is set, the row's JSON ``consistency_scores``
        are added to the scores: their mean for responses, element-wise for
        solutions. Entries whose score is not an int/float are skipped.
        """
        collected = []
        canonical_by_task = None  # built lazily: trainset is only needed for best-of-N rows

        for item in dataset:
            if input_column == "response":
                score = item["mutator_score"]
                if not isinstance(score, (int, float)):
                    continue
                if use_consistency:
                    score += np.mean(json.loads(item["consistency_scores"]))
                collected.append((item["mutator_prompt"], item["response"], float(score), item["task_id"]))
                continue

            prompt = item["solver_prompt"]
            solutions = json.loads(item["solutions"])
            scores = json.loads(item["solution_scores"])
            consistency_scores = json.loads(item["consistency_scores"]) if use_consistency else None

            # Recover the reference implementation. Best-of-N rows have no
            # mutator prompt, so use the trainset; otherwise parse the prompt.
            mutator_prompt = item["mutator_prompt"]
            if mutator_prompt is None:
                if canonical_by_task is None:
                    canonical_by_task = self._canonical_solutions_by_task()
                raw_truth = canonical_by_task.get(item["task_id"], "")
            else:
                raw_truth = extract_second_original_implementation(mutator_prompt)

            # An empty result means no ground truth was found. Never inject an
            # empty code block as a "perfect" solution.
            if raw_truth:
                ground_truth = "```python\n" + raw_truth + "\n```"
                if ground_truth not in solutions:
                    solutions.insert(0, ground_truth)
                    scores.insert(0, 1.0)  # ground truth gets a perfect score
                    # Keep consistency scores aligned with the shifted solutions.
                    # NOTE(review): a 0.0 bonus keeps the ground truth at exactly 1.0,
                    # so it can still pass the ``r_chosen == 1.0`` filter. Confirm intent.
                    if consistency_scores is not None and len(consistency_scores) < len(solutions):
                        consistency_scores.insert(0, 0.0)

            bonuses = consistency_scores if consistency_scores is not None else [0.0] * len(solutions)
            for sol, score, bonus in zip(solutions, scores, bonuses):
                if isinstance(score, (int, float)):
                    collected.append((prompt, sol, float(score + bonus), item["task_id"]))
        return collected

    def _build_preference_pairs(self, all_responses: list, margin_threshold: float,
                                input_column: str):
        """Pair same-prompt responses whose scores differ by at least ``margin_threshold``.

        Returns:
            tuple: ``(data, total_pairs, filtered_pairs)``. ``data`` holds the
            ``prompt``/``chosen``/``rejected``/``task_id`` column lists.
            ``total_pairs`` counts pairs that met the margin. ``filtered_pairs``
            counts those dropped as duplicates or as identical chosen/rejected text.
        """
        data = {"prompt": [], "chosen": [], "rejected": [], "task_id": []}
        seen_tuples = set()  # (prompt key, chosen, rejected) already emitted
        total_pairs = 0
        filtered_pairs = 0

        # Bucket indices by prompt. Pairing i with the later members of its own
        # bucket visits the same (i, j) pairs, in the same order, as a full
        # i < j scan that skips mismatched prompts. It avoids the cross-prompt
        # O(N^2) work.
        keys = [self._prompt_key(entry[0]) for entry in all_responses]
        buckets = defaultdict(list)
        slots = []  # slots[i] = (bucket containing i, position of i in that bucket)
        for idx, key in enumerate(keys):
            bucket = buckets[key]
            slots.append((bucket, len(bucket)))
            bucket.append(idx)

        for i, (prompt_i, response_i, score_i, task_id_i) in enumerate(all_responses):
            bucket, pos = slots[i]
            for j in bucket[pos + 1:]:
                _, response_j, score_j, task_id_j = all_responses[j]
                # Written as a positive comparison on purpose: NaN scores (e.g. the
                # mean of an empty consistency list) never form a pair.
                if not abs(score_i - score_j) >= margin_threshold:
                    continue
                total_pairs += 1
                if score_i > score_j:
                    chosen, rejected, task_id = response_i, response_j, task_id_i
                    r_chosen, r_rejected = score_i, score_j
                else:
                    chosen, rejected, task_id = response_j, response_i, task_id_j
                    r_chosen, r_rejected = score_j, score_i

                # For solutions, only keep pairs where a perfect solution beats an imperfect one.
                if input_column == "solutions" and not (r_chosen == 1.0 and r_rejected < 1.0):
                    continue

                tuple_key = (keys[i], chosen, rejected)
                if tuple_key in seen_tuples or chosen == rejected:
                    filtered_pairs += 1
                    continue
                seen_tuples.add(tuple_key)
                data["prompt"].append(prompt_i)
                data["chosen"].append(chosen)
                data["rejected"].append(rejected)
                data["task_id"].append(task_id)
        return data, total_pairs, filtered_pairs

    def to_dpo_dataset(self, dataset: Dataset, margin_threshold: float = 0.2,
                       input_column: str = "response", **kwargs) -> Dataset:
        """Convert a dataset to DPO preference-pair format.

        Args:
            dataset (Dataset): The input split to convert.
            margin_threshold (float): Minimum score difference for a preference pair.
            input_column (str): ``"response"`` (mutator outputs) or ``"solutions"``
                (solver outputs).
            **kwargs: ``consistency`` (bool, default False). When truthy, adds the
                row's ``consistency_scores`` to the scores.

        Returns:
            Dataset: Columns ``prompt``, ``chosen`` (higher-scored response),
            ``rejected`` (lower-scored response) and ``task_id``. With fewer than
            two comparable responses the result is empty.

        Raises:
            AssertionError: If input_column is not "response" or "solutions".
        """
        assert input_column in ["response", "solutions"], "input_column must be either 'response' or 'solutions'"
        use_consistency = bool(kwargs.get("consistency", False))

        all_responses = self._collect_scored_responses(dataset, input_column, use_consistency)
        data, total_pairs, filtered_pairs = self._build_preference_pairs(
            all_responses, margin_threshold, input_column
        )

        print("DPO Dataset Statistics:")
        print(f"Total possible pairs: {total_pairs}")
        print(f"Filtered pairs (duplicates): {filtered_pairs}")
        print(f"Final pairs: {len(data['prompt'])}")
        return Dataset.from_dict(data)

    def to_ppo_dataset(self, dataset: Dataset, input_column: str = "response",
                       **kwargs) -> Dataset:
        """Convert a dataset to PPO format: one (prompt, response, reward) row per sample.

        Args:
            dataset (Dataset): The input split to convert.
            input_column (str): ``"response"`` uses ``response``/``mutator_score``.
                ``"solutions"`` expands the JSON ``solutions``/``solution_scores`` lists.
            **kwargs: Accepted for interface parity with the other converters; unused.

        Returns:
            Dataset: Columns ``prompt``, ``response``, ``reward``, ``task_id``.
            Only the first occurrence of each (prompt, response) pair is kept, and
            samples with a non-numeric score are skipped.
        """
        assert input_column in ["response", "solutions"], "input_column must be either 'response' or 'solutions'"

        data = {"prompt": [], "response": [], "reward": [], "task_id": []}
        seen_tuples = set()  # (prompt key, response) already emitted
        total_examples = 0
        filtered_examples = 0

        def add_example(prompt, response, score, task_id):
            # Record one sample, skipping non-numeric scores and duplicates.
            nonlocal total_examples, filtered_examples
            if not isinstance(score, (int, float)):
                return
            total_examples += 1
            tuple_key = (self._prompt_key(prompt), response)
            if tuple_key in seen_tuples:
                filtered_examples += 1
                return
            seen_tuples.add(tuple_key)
            data["prompt"].append(prompt)
            data["response"].append(response)
            data["reward"].append(float(score))
            data["task_id"].append(task_id)

        for item in dataset:
            if input_column == "response":
                add_example(item["mutator_prompt"], item["response"], item["mutator_score"], item["task_id"])
            else:  # solutions
                solutions = json.loads(item["solutions"])
                scores = json.loads(item["solution_scores"])
                for sol, score in zip(solutions, scores):
                    add_example(item["solver_prompt"], sol, score, item["task_id"])

        print("PPO Dataset Statistics:")
        print(f"Total examples: {total_examples}")
        print(f"Filtered examples (duplicates): {filtered_examples}")
        print(f"Final examples: {len(data['prompt'])}")
        return Dataset.from_dict(data)

    def to_grpo_dataset(self, dataset: Dataset, input_column: str = "response", **kwargs) -> Dataset:
        """Group completions and rewards per prompt for GRPO training.

        Args:
            dataset (Dataset): The input split to convert.
            input_column (str): ``"response"`` or ``"solutions"``. Same meaning as
                in ``to_ppo_dataset``.
            **kwargs: Accepted so ``to(format="grpo", ...)`` can forward shared
                options; unused.

        Returns:
            Dataset: One row per unique prompt, with list-valued ``completions``
            and ``rewards`` (duplicate completions dropped) and a ``task_id``.
            The ``task_id`` is taken from the last unique completion; all
            completions of a prompt are assumed to share it.
        """
        assert input_column in ["response", "solutions"], "input_column must be either 'response' or 'solutions'"
        # Step 1: collect (completion, reward, task_id) per prompt. Groups are keyed
        # by a hashable form of the prompt, but the original prompt value is kept.
        prompt_map = defaultdict(list)
        prompt_values = {}
        for item in dataset:
            if input_column == "response":
                prompt = item.get("mutator_prompt", item.get("prompt"))
                samples = [(item["response"], item["mutator_score"], item["task_id"])]
            else:
                prompt = item["solver_prompt"]
                solutions = json.loads(item["solutions"])
                scores = json.loads(item["solution_scores"])
                samples = [(sol, score, item["task_id"]) for sol, score in zip(solutions, scores)]
            if not samples:
                continue  # a row with no solutions must not create an empty group
            key = self._prompt_key(prompt)
            prompt_values.setdefault(key, prompt)
            prompt_map[key].extend(samples)

        # Step 2: dedupe completions within each prompt and build grouped lists.
        grouped = {"prompt": [], "completions": [], "rewards": [], "task_id": []}
        for key, examples in prompt_map.items():
            seen = set()
            comps, rewards, tid = [], [], None
            for comp, rew, t in examples:
                if comp not in seen:
                    seen.add(comp)
                    comps.append(comp)
                    rewards.append(rew)
                    tid = t  # assumes all completions of one prompt share a task_id
            grouped["prompt"].append(prompt_values[key])
            grouped["completions"].append(comps)
            grouped["rewards"].append(rewards)
            grouped["task_id"].append(tid)
        print(f"GRPO Grouped Dataset: {len(grouped['prompt'])} unique prompts")
        return Dataset.from_dict(grouped)

    @staticmethod
    def split(dataset: Dataset, eval_ratio: float = 0.1) -> DatasetDict:
        """Split a dataset into train and test sets.

        Args:
            dataset (Dataset): The dataset to split.
            eval_ratio (float): Proportion of data to use for evaluation (0 to 1).

        Returns:
            DatasetDict: ``train``/``test`` from a seeded (42) split when
            ``eval_ratio > 0``. Otherwise the same dataset is used for both splits.
        """
        if eval_ratio > 0:
            return dataset.train_test_split(test_size=eval_ratio, seed=42)
        return DatasetDict({'train': dataset, 'test': dataset})

    def to(self, format: str = "dpo", eval_ratio: float = 0.2, tokenize: bool = False,
            apply_chat_template: bool = False, hf_repo_or_local_dir: str = None,
            num_train_examples: int = None, num_val_examples: int = None, system_prompt: str = None, **kwargs) -> DatasetDict:
        """Convert every split to ``format``, then optionally override the system prompt, split and tokenize.

        Args:
            format (str): ``"dpo"``, ``"ppo"`` or ``"grpo"``. Any other value
                returns the raw dataset.
            eval_ratio (float): Test fraction. It is only applied when the
                converted dataset has exactly one split.
            tokenize (bool): Add ``input_ids_prompt``/``attention_mask_prompt`` columns.
            apply_chat_template (bool): Render list prompts through the chat template
                before tokenizing.
            hf_repo_or_local_dir (str): Currently unused; pushing to the hub is disabled.
            num_train_examples (int): Optional task-diverse subsample size for ``train``.
            num_val_examples (int): Optional task-diverse subsample size for ``test``.
            system_prompt (str): If given, replaces the content of a leading
                ``system`` message in list-style prompts.
            **kwargs: Forwarded to the selected converter.

        Returns:
            DatasetDict: The converted (and possibly split/tokenized) dataset.
        """
        converters = {
            "dpo": self.to_dpo_dataset,
            "ppo": self.to_ppo_dataset,
            "grpo": self.to_grpo_dataset,
        }
        converter = converters.get(format)
        if converter is None:
            converted_dataset = self.dataset
        else:
            converted_dataset = DatasetDict({
                split: converter(self.dataset[split], **kwargs)
                for split in self.module_name_lst
            })

        if system_prompt:
            def override_system_prompt(example):
                prompt = example["prompt"]
                # Guard against empty message lists before peeking at the first role.
                if isinstance(prompt, list) and prompt and prompt[0]["role"] == "system":
                    prompt[0]["content"] = system_prompt
                return example
            converted_dataset = converted_dataset.map(override_system_prompt, batched=False)

        # Split only when there is a single split to split from.
        if len(converted_dataset.keys()) == 1 and eval_ratio > 0:
            split_name = list(converted_dataset.keys())[0]
            train_test = self.split(converted_dataset[split_name], eval_ratio)
            if num_train_examples:
                train_test['train'] = self._ensure_diverse_sampling(train_test['train'], num_train_examples)
            if num_val_examples:
                train_test['test'] = self._ensure_diverse_sampling(train_test['test'], num_val_examples)
            converted_dataset = DatasetDict({
                'train': train_test['train'],
                'test': train_test.get('test')
            })
            print("len(train_test['train'])", len(train_test['train']))
            print("len(train_test['test'])", len(train_test['test']))

        if tokenize:
            converted_dataset = converted_dataset.map(
                lambda example: self._tokenize_example(example, apply_chat_template=apply_chat_template),
                batched=False
            )

        # NOTE: pushing to the hub was disabled because it was flagged as buggy;
        # ``hf_repo_or_local_dir`` is therefore ignored for now.
        return converted_dataset

    def tokenize(self, prompt: Union[str, list], apply_chat_template: bool = False) -> Dict[str, Any]:
        """Tokenize a single prompt, which can be a string or a list of chat messages.

        Args:
            prompt (Union[str, list]): The prompt to tokenize.
            apply_chat_template (bool): Render a list prompt with the tokenizer's
                chat template first.

        Returns:
            Dict[str, Any]: Tokenizer output, including 'input_ids' and 'attention_mask'.

        Raises:
            ValueError: If the prompt type doesn't match the apply_chat_template setting.
        """
        if apply_chat_template:
            if not isinstance(prompt, list):
                raise ValueError("Prompt must be a list when apply_chat_template is True")
            # Render to text. The default tokenize=True would return token ids,
            # which the tokenizer call below cannot accept. Chat templates
            # typically emit their own special tokens (e.g. BOS), so don't add
            # them a second time.
            text = self.tokenizer.apply_chat_template(prompt, tokenize=False)
            add_special_tokens = False
        else:
            if not isinstance(prompt, str):
                raise ValueError("Prompt must be a string when apply_chat_template is False")
            text = prompt
            add_special_tokens = True
        # No padding here to keep memory low; padding is left to batching.
        return self.tokenizer(text, truncation=True, add_special_tokens=add_special_tokens)

    def _tokenize_example(self, example: Dict[str, Any], apply_chat_template: bool = False,
                          fields: Union[list, tuple] = ('prompt',)) -> Dict[str, Any]:
        """Tokenize selected fields of a dataset example.

        Args:
            example (Dict[str, Any]): A single example from the dataset.
            apply_chat_template (bool): Whether to apply the chat template to list fields.
            fields (list | tuple): Field names to tokenize. Missing fields are skipped.

        Returns:
            Dict[str, Any]: The example with ``input_ids_<field>`` and
            ``attention_mask_<field>`` added.
        """
        for field in fields:
            if field in example:
                tokenized = self.tokenize(example[field], apply_chat_template=apply_chat_template)
                example[f'input_ids_{field}'] = tokenized['input_ids']
                example[f'attention_mask_{field}'] = tokenized['attention_mask']
        return example

    def _ensure_diverse_sampling(self, dataset: Dataset, num_examples: int) -> Dataset:
        """Sample up to ``num_examples`` rows, covering as many task_ids as possible.

        One random row is taken per task_id. If that already exceeds
        ``num_examples``, a random subset of those rows is kept. Otherwise the
        remainder is filled with random rows from the unused pool.

        Args:
            dataset (Dataset): The dataset to sample from.
            num_examples (int): Number of examples to select.

        Returns:
            Dataset: The selected rows.
        """
        task_id_to_indices = defaultdict(list)
        for idx, task_id in enumerate(dataset['task_id']):
            task_id_to_indices[task_id].append(idx)

        # First, one random example per task_id; the rest go to a shared pool.
        selected_indices = []
        remaining_indices = []
        for indices in task_id_to_indices.values():
            pick = random.choice(indices)
            selected_indices.append(pick)
            remaining_indices.extend(idx for idx in indices if idx != pick)

        if 0 <= num_examples < len(selected_indices):
            # More distinct task_ids than requested: honour the requested size.
            selected_indices = random.sample(selected_indices, num_examples)
        elif remaining_indices and len(selected_indices) < num_examples:
            additional_needed = num_examples - len(selected_indices)
            selected_indices.extend(
                random.sample(remaining_indices, min(additional_needed, len(remaining_indices)))
            )

        final_dataset = dataset.select(selected_indices)

        print("\nTask ID Distribution in Final Selection:")
        print(f"Total unique task_ids: {len(set(final_dataset['task_id']))}")
        print(f"Total examples: {len(selected_indices)}")
        return final_dataset

