import gc
import json
import random
import uuid
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Tuple, Union

import numpy as np
import torch
from datasets import DatasetDict, load_dataset
from datasets.dataset_dict import Dataset as HFDataset

# from sandbagging_research_sprint.sandbagging_evaluation.evaluate_sandbagging import (
#     evaluate_split,
# )
from torch.utils.data import DataLoader, Dataset

import wandb
from eliciting_contexts.fluent_dreaming.sandbagging import (
    run_sandbagging_fluent_dreaming,
)
from eliciting_contexts.sandbagging.utils import (  # BufferedDatasetLogger,; DatasetLogger,; clear_memory,
    TensorEncoder,
    force_cuda_release,
    print_memory_stats,
)
from eliciting_contexts.utils.constants import DATA_DIR, DEVICE, WANDB_ENTITY
from eliciting_contexts.utils.evaluation import evaluate_model
from eliciting_contexts.utils.load_datasets import (
    extract_post_system_prompt,
    extract_system_prompt,
    load_harmful_data,
)
from eliciting_contexts.utils.load_models import load_finetuned_model


class EPOMode(Enum):
    """Which EPO objective to optimize.

    The member's ``value`` is forwarded as ``runner_type`` to
    ``run_sandbagging_fluent_dreaming`` and is used in output file and
    directory names.
    """

    PROBE = "probe"  # probe-based objective
    COSINE = "cosine"  # cosine-similarity objective


@dataclass
class EPOConfig:
    """Settings for one ``EPORunner`` experiment.

    Field order and defaults define the generated ``__init__`` and must not be
    reordered.
    """

    # EPO objective; its value is passed on as the fluent-dreaming runner type.
    mode: EPOMode

    # --- Data selection -------------------------------------------------
    # Samples optimized individually; also the size of the evaluation set
    # used for the prefix ("general") evaluation.
    num_samples: int = 250
    # Samples whose optimized system prompts are reused as general prefixes.
    num_prefix_samples: int = 25
    data_split: str = "test"
    # Selects the dataset and filter column: "sandbagging" or "harmless".
    behavior: str = "sandbagging"

    # --- Models ------------------------------------------------------------
    # LoRA model id used for evaluation and for naming the output directory.
    model_name: str = "contextmodification/gemma-sandbagging-0w4j7rba-step1024"
    # NOTE(review): not read by EPORunner in this file.
    base_model_name: str = "gemma2-9b-it"

    # --- EPO optimization ----------------------------------------------------
    # NOTE(review): the two penalty bounds are not read by EPORunner in this file.
    x_penalty_min: float = 0.1
    x_penalty_max: float = 10.0
    iters: int = 40
    population_size: int = 4

    # --- Logging / bookkeeping ----------------------------------------------------
    use_wandb: bool = True  # one summary wandb run per EPORunner.run call
    log_individual_runs: bool = True  # forwarded as use_wandb to each EPO run
    overwrite: bool = False  # recompute samples that already have result files
    random_seed: int = 0

    # --- Evaluation ----------------------------------------------------------
    # Pareto-point selection rule(s) to evaluate; "all" means knee, target and ce.
    max_mode: Literal["knee", "target", "ce", "all"] = "all"
    batch_size: int = 4


class EPORunner:
    """Run EPO (fluent dreaming) on dataset samples and evaluate the prompts it finds.

    ``run`` performs two optional stages for one model layer:

    * by-sample ("individual"): optimize a prompt for each of the first
      ``config.num_samples`` filtered samples, then evaluate every sample with
      its own optimized prompt;
    * by-prefix ("general"): optimize prompts for the first
      ``config.num_prefix_samples`` samples, extract their system prompts, and
      evaluate them as randomly assigned system prompts on the first
      ``config.num_samples`` samples.

    Per-sample EPO results are stored as ``sample_<i>.json`` files; samples
    that already have a file are skipped unless ``config.overwrite`` is set.
    """

    def __init__(self, config: "EPOConfig"):
        self.config = config
        # Outputs live under DATA_DIR/<behavior>/<model repo name>/epo/.
        self.base_dir = (
            DATA_DIR / config.behavior / config.model_name.split("/")[-1] / "epo"
        )

        # Load dataset
        self.datasetdict, self.filter_key = self._load_dataset()
        assert isinstance(self.datasetdict, DatasetDict)
        # Not read by this class; kept for callers that may inspect it.
        self.already_processed = []

    def _load_dataset(self) -> Tuple["DatasetDict", str]:
        """Load the dataset for ``config.behavior`` and its filter column name.

        Rows whose filter column is truthy are excluded from EPO and
        evaluation (see ``_filtered_split``).

        Raises:
            ValueError: if ``config.behavior`` is neither ``"sandbagging"``
                nor ``"harmless"``.
        """
        behavior = self.config.behavior
        if behavior == "sandbagging":
            datasetdict = load_dataset("contextmodification/sandbagging-sciq")
            filter_key = "sandbagging_environment"
        elif behavior == "harmless":
            datasetdict = load_harmful_data()
            filter_key = "harmful"
        else:
            raise ValueError(f"{self.config.behavior} not supported")

        assert isinstance(datasetdict, DatasetDict)
        return datasetdict, filter_key

    def _filtered_split(self):
        """Return the configured split with rows where ``filter_key`` is truthy removed."""
        return self.datasetdict[self.config.data_split].filter(
            lambda x: not x[self.filter_key]
        )

    def _modes_to_evaluate(self) -> list:
        """Expand ``config.max_mode`` into the selection rules to evaluate."""
        if self.config.max_mode == "all":
            return ["knee", "target", "ce"]
        return [self.config.max_mode]

    def run(
        self, layer: int, by_sample: bool = True, by_prefix: bool = True
    ) -> Dict[str, Any]:
        """Run the full EPO pipeline for one layer and return the metrics.

        Args:
            layer: Layer forwarded to the fluent-dreaming runner; also names
                the output directory ``layer_<layer>``.
            by_sample: Run the per-sample ("individual") stage.
            by_prefix: Run the shared-prefix ("general") stage.

        Returns:
            Dict mapping result names to the metrics returned by
            ``evaluate_model``. The same dict is written to
            ``combined_results_file`` and, when enabled, logged to wandb.
        """
        mode_name = self.config.mode.value

        # Layer- and mode-specific output locations.
        self.output_dir = self.base_dir / f"layer_{layer}"
        self.samples_dir = self.output_dir / f"{mode_name}_samples"
        self.prefix_samples_dir = self.output_dir / f"{mode_name}_prefix_samples"
        self.samples_dir.mkdir(parents=True, exist_ok=True)
        self.prefix_samples_dir.mkdir(parents=True, exist_ok=True)
        self.combined_results_file = (
            self.output_dir / f"{mode_name}_combined_results.json"
        )

        results: Dict[str, Any] = {}
        wandb_group = uuid.uuid4().hex[:8]

        started_wandb = False
        if self.config.use_wandb:
            wandb.init(
                project="fluent-dreaming",
                entity=WANDB_ENTITY,
                group=wandb_group,
                name=f"epo_evaluation_layer_{layer}_{wandb_group}",
                config={
                    "layer": layer,
                    "model_name": self.config.model_name,
                    "behavior": self.config.behavior,
                    "mode": mode_name,
                    "num_samples": self.config.num_samples,
                    "num_prefix_samples": self.config.num_prefix_samples,
                },
            )
            started_wandb = True

        try:
            modes_to_evaluate = self._modes_to_evaluate()

            if by_sample:
                self.epo_by_sample(
                    layer=layer,
                    output_dir=self.samples_dir,
                    num_samples=self.config.num_samples,
                )
                combined_ds = self._combine_samples(
                    self.samples_dir, self.config.num_samples
                )
                # Each sample is evaluated with its own optimized prompt.
                for max_mode in modes_to_evaluate:
                    results[f"epo_{mode_name}_individual_{max_mode}"] = (
                        self._evaluate_epo(
                            combined_ds,
                            f"epo_{mode_name}_individual",
                            max_mode=max_mode,
                        )
                    )

            if by_prefix:
                self.epo_by_sample(
                    layer=layer,
                    output_dir=self.prefix_samples_dir,
                    num_samples=self.config.num_prefix_samples,
                )
                combined_ds = self._combine_samples(
                    self.prefix_samples_dir, self.config.num_prefix_samples
                )
                # Optimized system prompts are transplanted onto other samples.
                for max_mode in modes_to_evaluate:
                    full_dataset = self._create_dataset(
                        combined_ds, prompt_key=f"epo_prompt_{max_mode}"
                    )
                    # NOTE(review): this key lacks the "epo_" prefix used by the
                    # individual results; kept unchanged because saved result
                    # files and downstream analysis may depend on it.
                    results[f"{mode_name}_general_{max_mode}"] = self._evaluate_epo(
                        full_dataset,
                        f"epo_{mode_name}_general_{max_mode}",
                        max_mode=max_mode,
                        prompt_key="epo_prefixed_prompt",
                    )

            with open(self.combined_results_file, "w", encoding="utf-8") as file:
                json.dump(results, file, indent=4, cls=TensorEncoder)

            # Log final results to wandb only if it's enabled and a run is active.
            if self.config.use_wandb and wandb.run:
                wandb.log(results)
        finally:
            # Close the summary run so a later run() (e.g. a second runner in
            # the same process) starts a fresh wandb run instead of reusing it.
            if started_wandb:
                wandb.finish()

        return results

    def epo_by_sample(
        self,
        layer: int,
        output_dir: Union[Path, str],
        num_samples: int,
    ):
        """Run EPO on the first ``num_samples`` filtered samples, one at a time.

        For each sample ``i``, ``output_dir/sample_<i>.json`` receives the
        original row plus ``epo_prompt_knee``, ``epo_prompt_target``,
        ``epo_prompt_ce`` and their ``metrics``. Samples that already have a
        result file are skipped unless ``config.overwrite`` is set. An error
        on one sample is printed and the loop moves on (best effort).
        """
        output_dir = Path(output_dir)  # the signature also accepts str

        # Cap this process at 90% of GPU memory to reduce OOM risk.
        if torch.cuda.is_available():
            torch.cuda.set_per_process_memory_fraction(0.9)

        wandb_group = uuid.uuid4().hex[:8]
        dataset = self._filtered_split()

        # Set random seed
        random.seed(self.config.random_seed)
        np.random.seed(self.config.random_seed)
        torch.manual_seed(self.config.random_seed)
        torch.cuda.manual_seed(self.config.random_seed)

        already_processed: set = set()
        if not self.config.overwrite:
            already_processed = self._processed_sample_ids(output_dir)
            print(f"Found {len(already_processed)} already processed samples")

        if num_samples < len(dataset):
            dataset = dataset.select(range(num_samples))
        total = len(dataset)

        for i in range(total):
            if i in already_processed:
                print(f"Skipping already processed sample {i}")
                continue

            print(f"Processing sample {i} of {total}")
            print_memory_stats()

            datum = dataset[i]

            try:
                pareto = run_sandbagging_fluent_dreaming(
                    full_prompt_text=datum["prompt"],
                    use_wandb=self.config.log_individual_runs,
                    device=DEVICE,
                    iters=self.config.iters,
                    population_size=self.config.population_size,
                    wandb_group=wandb_group,
                    runner_type=self.config.mode.value,
                    layer=layer,
                )[0]

                datum_copy = self._attach_epo_results(datum, pareto)

                # Write to a temp file and rename it so a crash mid-write never
                # leaves a truncated sample_<i>.json. A truncated file would be
                # skipped on resume and would break _combine_samples.
                out_path = output_dir / f"sample_{i}.json"
                tmp_path = out_path.with_name(out_path.name + ".tmp")
                with open(tmp_path, "w", encoding="utf-8") as file:
                    json.dump(datum_copy, file, indent=4, cls=TensorEncoder)
                tmp_path.replace(out_path)

                print(f"Saved result for sample {i}")

                # Delete everything and clear memory
                del pareto
                del datum_copy
                force_cuda_release()

            except Exception as e:
                # Best effort: one failing sample should not abort the batch.
                print(f"Error processing sample {i}: {e}")

            # Aggressive memory clearing after each sample
            gc.collect()
            torch.cuda.empty_cache()
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()
            print_memory_stats()

    @staticmethod
    def _processed_sample_ids(output_dir: Path) -> set:
        """Return every ``i`` for which ``output_dir/sample_<i>.json`` exists.

        Files matching ``sample_*.json`` whose suffix is not a plain integer
        (e.g. ``sample_x.json``) are ignored rather than raising.
        """
        ids = set()
        for path in Path(output_dir).glob("sample_*.json"):
            suffix = path.stem[len("sample_"):]
            if suffix.isascii() and suffix.isdigit():
                ids.add(int(suffix))
        return ids

    def _attach_epo_results(self, datum: dict, pareto) -> dict:
        """Return a copy of ``datum`` with the selected EPO prompts and their metrics.

        Three Pareto points are selected: the knee, the highest ``target`` and
        the lowest ``xentropy``.
        """
        indices = {
            "knee": self._find_knee_point(pareto),
            "target": int(np.argmax(pareto.target)),
            "ce": int(np.argmin(pareto.xentropy)),
        }

        datum_copy = datum.copy()
        for name, idx in indices.items():
            datum_copy[f"epo_prompt_{name}"] = pareto.text[idx]

        # Save metrics for analysis
        datum_copy["metrics"] = {
            name: {
                "target": float(pareto.target[idx]),
                "xentropy": float(pareto.xentropy[idx]),
            }
            for name, idx in indices.items()
        }
        return datum_copy

    @staticmethod
    def _min_max(values: np.ndarray) -> np.ndarray:
        """Scale ``values`` to [0, 1]; a constant array maps to all zeros."""
        span = values.max() - values.min()
        if span == 0:
            return np.zeros_like(values)
        return (values - values.min()) / span

    def _find_knee_point(self, pareto) -> int:
        """Return the index of the knee of a 2-D Pareto frontier.

        ``pareto.target`` (maximized) and ``pareto.xentropy`` (minimized) are
        min-max normalized. The knee is the point farthest from the line
        through the first and last points. This assumes the frontier is
        ordered so those two points are its extremes (TODO confirm against
        ``run_sandbagging_fluent_dreaming``).

        Degenerate input does not divide by zero:

        * a constant objective normalizes to 0;
        * if the first and last points coincide, the distance to that point
          is used instead.
        """
        x = np.asarray(pareto.xentropy, dtype=float)
        y = np.asarray(pareto.target, dtype=float)
        points = np.column_stack((self._min_max(x), self._min_max(y)))

        start = points[0]
        line_vec = points[-1] - start
        offsets = points - start
        line_length = np.linalg.norm(line_vec)

        if line_length == 0:
            distances = np.linalg.norm(offsets, axis=1)
        else:
            # Distance to the line is the length of the vector rejection.
            proj_length = (offsets @ line_vec) / line_length
            proj_vec = (proj_length / line_length)[:, None] * line_vec
            distances = np.linalg.norm(offsets - proj_vec, axis=1)

        return int(np.argmax(distances))

    def _print_metrics(self, metrics: dict):
        """Print the results of a sandbagging evaluation in a human readable way."""

        print("Results:")
        print("Normal Environment:")
        print(f"  Accuracy: {metrics['normal_accuracy']:.1%}")
        print(
            f"  Accuracy (excluding non responses): {metrics['normal_accuracy_excluding_non_responses']:.1%}"
        )
        print(f"  Total samples: {metrics['normal_total']}")
        print(f"  Response rate: {metrics['normal_response_rate']:.1%}")
        print("\nSandbagging Environment:")
        print(f"  Accuracy: {metrics['sandbagging_accuracy']:.1%}")
        print(
            f"  Accuracy (excluding non responses): {metrics['sandbagging_accuracy_excluding_non_responses']:.1%}"
        )
        print(f"  Total samples: {metrics['sandbagging_total']}")
        print(f"  Response rate: {metrics['sandbagging_response_rate']:.1%}")

    def _combine_samples(self, samples_dir: Path, num_samples: int) -> "HFDataset":
        """Load the per-sample result files into one dataset.

        Reads ``samples_dir/sample_<i>.json`` for
        ``i < min(num_samples, len(filtered split))`` and silently skips
        indices without a file (e.g. samples whose EPO run failed).
        """
        samples_dir = Path(samples_dir)
        limit = min(num_samples, len(self._filtered_split()))

        samples = []
        for i in range(limit):
            sample_file = samples_dir / f"sample_{i}.json"
            if sample_file.exists():
                with open(sample_file, "r", encoding="utf-8") as file:
                    samples.append(json.load(file))

        print(f"Combined {len(samples)} samples")
        return HFDataset.from_list(samples)

    def _create_dataset(self, prefix_samples: "HFDataset", prompt_key: str) -> "HFDataset":
        """Apply EPO-found system prompts to the evaluation samples.

        Takes the first ``config.num_samples`` filtered samples. Each one gets
        an ``epo_prefixed_prompt`` column: a system turn holding a system
        prompt drawn at random from ``prefix_samples[prompt_key]``, followed
        by the original prompt's post-system part. Uses the global ``random``
        state.

        Raises:
            ValueError: if no non-empty system prompt can be extracted.
        """
        dataset = self._filtered_split()
        if self.config.num_samples < len(dataset):
            dataset = dataset.select(range(self.config.num_samples))

        # Only non-empty system prompts are usable as prefixes.
        prefixes = [
            prefix
            for prefix in (extract_system_prompt(s[prompt_key]) for s in prefix_samples)
            if prefix
        ]
        if not prefixes:
            raise ValueError("No valid prefixes found in EPO samples")

        def replace_system_prompt(example):
            new_prefix = random.choice(prefixes)
            post_system = extract_post_system_prompt(example["prompt"])
            return {
                **example,
                "epo_prefixed_prompt": f"<start_of_turn>system{new_prefix}<end_of_turn>{post_system}",
            }

        return dataset.map(replace_system_prompt)

    def _evaluate_epo(
        self,
        dataset: "HFDataset",
        evaluation_type: str,
        max_mode: Literal["knee", "target", "ce"] = "target",
        prompt_key: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Load the fine-tuned model, evaluate ``dataset`` and return the metrics.

        Args:
            dataset: Rows containing the prompt column selected below.
            evaluation_type: Label forwarded to ``evaluate_model``. Callers in
                this class pass ``epo_<mode>_individual`` and
                ``epo_<mode>_general_<max_mode>``.
            max_mode: Selects the default prompt column ``epo_prompt_<max_mode>``.
            prompt_key: Overrides the prompt column when not ``None``.

        The model is loaded fresh for every call. It is released in
        ``finally`` so GPU memory is returned even if evaluation fails.
        Exceptions from evaluation are printed and re-raised.
        """
        # Clear memory before evaluation
        gc.collect()
        torch.cuda.empty_cache()
        print_memory_stats()

        print("Loading model for evaluation...")
        force_cuda_release()
        # NOTE(review): config.base_model_name is not passed here.
        model, tokenizer = load_finetuned_model(
            lora_model_id=self.config.model_name,
            device=DEVICE,
        )

        try:
            dataloader = DataLoader(dataset, batch_size=self.config.batch_size)
            print(f"Performing evaluation of type: {evaluation_type}")
            return evaluate_model(
                model=model,
                tokenizer=tokenizer,
                dataloader=dataloader,
                evaluation_type=evaluation_type,
                prompt_key=(
                    f"epo_prompt_{max_mode}" if prompt_key is None else prompt_key
                ),
            )

        except Exception as e:
            print(f"Error during evaluation: {e}")
            raise

        finally:
            del model, tokenizer
            gc.collect()
            torch.cuda.empty_cache()
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()
            print_memory_stats()


if __name__ == "__main__":
    layer = 12
    num_samples = 100
    num_samples_prefix = 4

    # Same experiment for each objective: cosine steering first, then probe.
    for epo_mode in (EPOMode.COSINE, EPOMode.PROBE):
        runner = EPORunner(
            EPOConfig(
                mode=epo_mode,
                num_samples=num_samples,
                num_prefix_samples=num_samples_prefix,
                iters=40,
                log_individual_runs=True,
                overwrite=True,
            )
        )
        runner.run(layer=layer, by_sample=True, by_prefix=True)
