
from typing import List, Any, Tuple
from src.scoring.classifier import ClassifierScoringMethod
import numpy as np
import random
import math

class FuzzingScoringMethod(ClassifierScoringMethod):
    """Scoring method that asks a model to spot mislabeled token highlights.

    Each prepared example is a decoded text in which some tokens are wrapped
    in ``<<`` and ``>>``.  Three kinds of examples are mixed together:

    * positives with all of their activating tokens highlighted (label True),
    * positives with mostly non-activating tokens highlighted (label False),
    * negatives with randomly chosen tokens highlighted (label False).

    Examples are expected to expose ``is_positive``, ``context``,
    ``activation_positions`` and ``activation_values`` attributes.
    ``self.context_tokenizer`` and ``self.max_example_length`` are read by
    this class and are presumably set by the base-class ``__init__`` --
    TODO confirm against ``ClassifierScoringMethod``.
    """

    def __init__(
        self,
        context_tokenizer: Any,
        max_example_length: int = 50,
        correct_pos_ratio: float = 0.5,
        incorrect_pos_ratio: float = 0.25,
        incorrect_neg_ratio: float = 0.25
    ):
        """Store the mixing ratios for the three example categories.

        Args:
            context_tokenizer: Forwarded to the base class; must provide
                ``decode(list_of_token_ids)``.
            max_example_length: Forwarded to the base class; contexts are
                truncated to this many tokens before decoding.
            correct_pos_ratio: Share of correctly highlighted positives.
            incorrect_pos_ratio: Share of incorrectly highlighted positives.
            incorrect_neg_ratio: Share of randomly highlighted negatives.

        Raises:
            AssertionError: If the three ratios do not sum to 1.
        """
        super().__init__(context_tokenizer, max_example_length)

        # Compare with a tolerance: exact float equality rejects valid splits
        # such as 0.7 + 0.2 + 0.1, which evaluates to 0.9999999999999999.
        # Kept as an assert so the exception type seen by callers is unchanged.
        ratio_sum = correct_pos_ratio + incorrect_pos_ratio + incorrect_neg_ratio
        assert math.isclose(ratio_sum, 1.0), "Sum of ratios should be 1"
        self.correct_pos_ratio = correct_pos_ratio  # Share of correctly highlighted positives
        self.incorrect_pos_ratio = incorrect_pos_ratio  # Share of mis-highlighted positives
        self.incorrect_neg_ratio = incorrect_neg_ratio  # Share of randomly highlighted negatives

    @property
    def system_prompt(self) -> str:
        """Return the instruction text describing the labeling task."""
        return """You are an intelligent and meticulous linguistics researcher.

You will be given a certain latent of text, such as "male pronouns" or "text with negative sentiment". You will be given a few examples of text that contain this latent. Portions of the sentence which strongly represent this latent are between tokens << and >>.

Some examples might be mislabeled. Your task is to determine if every single token within << and >> is correctly labeled. Consider that all provided examples could be correct, none of the examples could be correct, or a mix. An example is only correct if every marked token is representative of the latent

For each example in turn, return 1 if the sentence is correctly labeled or 0 if the tokens are mislabeled. You must return your response in a valid Python list of the same length as the number of presented examples. Do not return anything else besides a Python list.
"""

    def prepare_examples(
        self, examples: List[Any]
    ) -> Tuple[List[str], List[bool]]:
        """Build a shuffled mix of highlighted texts and their labels.

        Args:
            examples: Objects with an ``is_positive`` flag plus the fields
                read by ``_process_example``.

        Returns:
            A pair ``(texts, labels)`` of equal length.  ``labels[i]`` is True
            only for correctly highlighted positives.  Both lists are empty
            when no example can be drawn.
        """
        positives = [ex for ex in examples if ex.is_positive]
        negatives = [ex for ex in examples if not ex.is_positive]

        R_cp = self.correct_pos_ratio
        R_ip = self.incorrect_pos_ratio
        R_ing = self.incorrect_neg_ratio

        # Largest total N each category could support on its own; a category
        # with a zero ratio puts no limit on N.
        max_n_cp = len(positives) / R_cp if R_cp > 0 else float('inf')
        max_n_ip = len(positives) / R_ip if R_ip > 0 else float('inf')
        max_n_ing = len(negatives) / R_ing if R_ing > 0 else float('inf')
        n = math.floor(min(max_n_cp, max_n_ip, max_n_ing))

        # Per-category counts, clamped to what is actually available.
        # Correct and incorrect positives are drawn independently from the
        # same pool, so one positive may appear in both categories.
        count_cp = max(min(int(round(R_cp * n)), len(positives)), 0)
        count_ip = max(min(int(round(R_ip * n)), len(positives)), 0)
        count_ing = max(min(int(round(R_ing * n)), len(negatives)), 0)

        # Sample first, then highlight only the chosen examples, so examples
        # that will be discarded are never decoded.
        correct_sampled = [
            self._highlight_correct(ex)[0]
            for ex in random.sample(positives, count_cp)
        ]
        incorrect_pos_sampled = [
            self._highlight_incorrect(ex)
            for ex in random.sample(positives, count_ip)
        ]
        incorrect_neg_sampled = [
            self._highlight_random(ex)
            for ex in random.sample(negatives, count_ing)
        ]

        prepared = correct_sampled + incorrect_pos_sampled + incorrect_neg_sampled
        new_labels = (
            [True] * len(correct_sampled) +
            [False] * len(incorrect_pos_sampled) +
            [False] * len(incorrect_neg_sampled)
        )

        # Shuffle texts and labels together so they stay aligned.
        combined = list(zip(prepared, new_labels))
        if not combined:
            return ([], [])
        random.shuffle(combined)
        prepared_shuffled, labels_shuffled = zip(*combined)
        return list(prepared_shuffled), list(labels_shuffled)

    def _process_example(self, ex: Any) -> Tuple[List[str], List[int], List[float]]:
        """Truncate, decode, and filter valid activation positions/values.

        Returns:
            ``(decoded_tokens, positions, values)`` where ``decoded_tokens``
            holds one decoded string per token of the truncated context, and
            ``positions``/``values`` keep only the activations whose position
            indexes into the truncated context.
        """
        tokens = ex.context[:self.max_example_length]
        decoded_tokens = [self.context_tokenizer.decode([t]) for t in tokens]

        # Filter positions and values in one pass so they stay paired.
        # Negative positions are dropped too: they can never match a token
        # index when highlights are applied.
        limit = len(tokens)
        kept = [
            (pos, ex.activation_values[i])
            for i, pos in enumerate(ex.activation_positions)
            if 0 <= pos < limit
        ]
        valid_act_pos = [pos for pos, _ in kept]
        valid_act_vals = [val for _, val in kept]
        return decoded_tokens, valid_act_pos, valid_act_vals

    def _highlight_correct(self, ex: Any) -> Tuple[str, int]:
        """Highlight every valid activation position.

        Returns:
            The highlighted text and the number of highlighted positions
            (``0`` and an unhighlighted text when there are none).
        """
        decoded, valid_pos, _ = self._process_example(ex)
        return self._apply_highlights(decoded, valid_pos), len(valid_pos)

    def _highlight_incorrect(self, ex: Any) -> str:
        """Highlight a mix dominated by non-activating tokens."""
        decoded, valid_pos, valid_vals = self._process_example(ex)
        token_count = len(decoded)
        if not valid_pos:
            # No activation inside the window: highlight about a fifth of the
            # tokens at random (at least one), never more than there are.
            num_random = min(token_count, max(1, token_count // 5))
            return self._apply_highlights(
                decoded, random.sample(range(token_count), num_random)
            )

        # Only activations above the mean count as "correct" candidates.
        threshold = np.mean(valid_vals)
        correct = [pos for pos, val in zip(valid_pos, valid_vals) if val > threshold]

        total_highlights = random.randint(1, 5)
        num_correct = int(total_highlights * self.incorrect_pos_ratio)
        num_incorrect = max(1, total_highlights - num_correct)

        selected_correct = correct[:num_correct]

        # Wrong highlights come from tokens that are not activation positions.
        # NOTE(review): if every token is an activation position this pool is
        # empty and the text carries no wrong highlight, yet the caller still
        # labels it incorrect.
        activating = set(valid_pos)
        incorrect_pool = [i for i in range(token_count) if i not in activating]
        selected_incorrect = random.sample(
            incorrect_pool, min(num_incorrect, len(incorrect_pool))
        )

        return self._apply_highlights(decoded, selected_correct + selected_incorrect)

    def _highlight_random(self, ex: Any) -> str:
        """Highlight 1-5 random tokens (fewer if the example is shorter)."""
        decoded, _, _ = self._process_example(ex)
        # Clamp to the token count: random.sample raises ValueError when asked
        # for more items than the population holds.
        num = min(random.randint(1, 5), len(decoded))
        return self._apply_highlights(decoded, random.sample(range(len(decoded)), num))

    def _apply_highlights(self, tokens: List[str], positions: List[int]) -> str:
        """Join the tokens, wrapping those at ``positions`` in ``<<`` ``>>``."""
        marked = set(positions)  # constant-time membership per token
        return "".join(f"<<{t}>>" if i in marked else t for i, t in enumerate(tokens))