import json
import logging
import os
from dataclasses import dataclass, replace
from pathlib import Path
from typing import cast

import numpy as np
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from torchmetrics import Metric
from transformers import (
    BatchEncoding,
    PreTrainedModel,
    PreTrainedTokenizer,
    TrainerCallback,
)

from data.synthetic_strings.random import RandomStringConfig
from data.synthetic_strings.utils import uses_bos_token
from defs import LLMExperimentConfig
from lib_llm.eval.memorization.dynamics import memorization_dynamics_metrics
from lib_llm.eval.metrics import (
    CorrectnessMetric,
    LossMetric,
    MetricArg,
    SequenceEvaluationTask,
    SequenceMetric,
    TokenEvaluationTask,
    TokenMetric,
)
from lib_llm.models import ModelConfig, load_model_tokenizer
from lib_llm.ops.batch_size import get_batch_size
from lib_llm.training.train import TrainingConfig, train
from lib_project.experiment import ExperimentID, NoSave, NoSaveValue, experiment
from utils.memorization.memorization import MemorizationTrainingResult


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Whether CUDA is available, and the number of CUDA devices (1 when CUDA
# is unavailable).
HAS_CUDA = torch.cuda.is_available()
NUM_DEVICES = torch.cuda.device_count() if HAS_CUDA else 1
# Value of the LOCAL_RANK environment variable; -1 when it is not set.
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "-1"))

# Data files are looked up relative to the directory containing this file.
SCRIPT_PATH = Path(__file__).parent.resolve()
# JSON file with the candidate first names.
NAMES_FILE = SCRIPT_PATH / "data" / "first_names.json"
# JSON file with the groups of sentence templates.
SENTENCE_TEMPLATE_PATH = SCRIPT_PATH / "data" / "sentence_templates.json"

# Experiment name (passed to the ``@experiment`` decorator) and its
# abbreviation.
EXP_NAME = "random_facts"
EXP_ABBREVIATION = "rf"


@dataclass
class RandomFactsConfig(RandomStringConfig):
    """Configuration of the random-facts dataset.

    Extends ``RandomStringConfig``. This module also reads ``seed``,
    ``alphabet_size``, ``num_tokens`` and ``tokenizer_type`` from the
    config; those are not declared here, so they are expected to come from
    the base class.
    """

    # Number of generated sequences that form the test split.
    num_test_samples: int = 1
    # NOTE: ``num_tokens`` is interpreted here as the number of sentences
    # per input sequence, not as a number of tokens.

    @property
    def name(self) -> str:
        """Short identifier built from the alphabet size and ``num_tokens``."""
        return f"a-{self.alphabet_size}_l-{self.num_tokens}"


@dataclass
class ExperimentConfig(LLMExperimentConfig):
    """Top-level configuration of the random-facts experiment."""

    # Experiment-level seed. NOTE(review): not read by the code in this
    # file (dataset generation uses ``data.seed``) — presumably consumed by
    # the experiment framework; confirm.
    seed: int
    # Dataset generation settings.
    data: RandomFactsConfig
    # Model settings, passed to ``load_model_tokenizer``.
    model: ModelConfig
    # Training settings; ``train_model`` overrides some of its arguments.
    training: TrainingConfig


@dataclass
class ExperimentResult:
    """Outputs returned by ``rf_experiment``."""

    # Name alphabet that was sampled for the dataset.
    names: list[str]
    # Indices of the sentence-template groups, in the order they are used
    # within each sequence.
    sentence_groups: list[int]
    # Training log reported by the training run.
    training_log: pd.DataFrame
    # Result of the memorization evaluation task that ran during training.
    memorization_log: pd.DataFrame
    # Trained model and its tokenizer. NOTE(review): the ``NoSave`` wrapper
    # presumably keeps them out of the persisted results — confirm.
    model: NoSave[PreTrainedModel]
    tokenizer: NoSave[PreTrainedTokenizer]


@experiment(EXP_NAME)
def rf_experiment(
    config: ExperimentConfig,
    experiment_id: ExperimentID,
) -> ExperimentResult:
    """Run the random-facts experiment.

    Loads the model and tokenizer, generates the dataset (one training
    sequence per configured epoch), trains the model and bundles the logs
    together with the trained model and tokenizer.
    """
    model, tokenizer = load_model_tokenizer(config.model)
    # Upper bound used by the dataset builder to reject overly long
    # sequences.
    tokenizer.model_max_length = 2048

    facts = create_random_facts_dataset(
        config.data,
        int(config.training.args.num_train_epochs),
        tokenizer,
    )
    outcome = train_model(
        experiment_id,
        config.training,
        (config.model.model_id_not_none, model),
        tokenizer,
        data_name=config.data.name,
        data=facts.data,
    )

    # Both logs must be present; the asserts also narrow the types.
    training_log = outcome.training_log
    assert training_log is not None
    memorization_log = outcome.memorization_log
    assert memorization_log is not None

    return ExperimentResult(
        names=facts.names,
        sentence_groups=facts.sentence_groups,
        training_log=training_log,
        memorization_log=memorization_log,
        model=NoSaveValue(outcome.model),
        tokenizer=NoSaveValue(outcome.tokenizer),
    )


@dataclass
class RandomFactsData:
    """Dataset produced by ``create_random_facts_dataset``."""

    # Name alphabet the facts were sampled from.
    names: list[str]
    # Indices of the sentence-template groups, in the order they are used
    # within each sequence.
    sentence_groups: list[int]
    # Splits keyed "train", "validation" and "test".
    data: DatasetDict


def create_random_facts_dataset(
    data_config: RandomFactsConfig,
    num_epochs: int,
    tokenizer: PreTrainedTokenizer,
) -> RandomFactsData:
    """Create a dataset with different sentences containing the same
    random facts.

    A chain of names is sampled once. Every generated sequence expresses
    that same chain (sentence ``j`` links name ``j`` to name ``j + 1``) but
    draws a fresh template from each sentence group, so the wording varies
    between sequences while the facts stay fixed.

    Args:
        data_config: Provides the seed, the alphabet size, the number of
            sentences per sequence (``num_tokens``), the tokenizer type
            and the number of test samples.
        num_epochs: Number of training sequences to generate.
        tokenizer: Tokenizer used to encode the sentences.

    Returns:
        The sampled name alphabet, the indices of the chosen sentence
        groups, and a ``DatasetDict`` with the entries "train",
        "validation" (always ``None``) and "test".

    Raises:
        ValueError: If no sentence or no sequence would be generated, if
            the template file contains no groups or invalid templates, or
            if a tokenized sequence exceeds ``tokenizer.model_max_length``.
    """
    num_sentences_per_input = data_config.num_tokens
    num_test_samples = data_config.num_test_samples
    num_total_sentences = num_epochs + num_test_samples
    # Fail early with a clear message instead of an opaque error from
    # torch.cat / max() further down.
    if num_sentences_per_input < 1:
        raise ValueError(
            "At least one sentence per input is required, got "
            f"{num_sentences_per_input}."
        )
    if num_total_sentences < 1:
        raise ValueError(
            "At least one sequence must be generated, got "
            f"{num_total_sentences}."
        )

    rng = np.random.default_rng(data_config.seed)

    # JSON is read as UTF-8 explicitly so that decoding does not depend on
    # the platform's default encoding.
    with open(NAMES_FILE, "r", encoding="utf-8") as f:
        all_names = json.load(f)
    # Create an alphabet of names to choose from
    name_alphabet = rng.choice(
        all_names,
        size=data_config.alphabet_size,
        replace=False,
    )
    # Sample an actual sequence of names. These are the random facts
    # that the model should memorize.
    name_sequence = rng.choice(
        name_alphabet,
        size=num_sentences_per_input + 1,
        replace=True,
    )

    with open(SENTENCE_TEMPLATE_PATH, "r", encoding="utf-8") as f:
        sentence_templates = json.load(f)
    if len(sentence_templates) == 0:
        raise ValueError(
            f"No sentence templates found in {SENTENCE_TEMPLATE_PATH}."
        )
    _validate_templates(sentence_templates)
    # Pick sentence groups in a random order. Groups repeat only when more
    # sentences are requested than there are groups.
    sentence_indices = np.mod(
        rng.permutation(max(len(sentence_templates), num_sentences_per_input)),
        len(sentence_templates),
    )[:num_sentences_per_input]
    logger.debug("sentence_indices %s", sentence_indices)
    sentence_groups = [sentence_templates[i] for i in sentence_indices]

    inputs = []
    input_ids = []
    attention_mask = []
    name_mask = []
    for _ in range(num_total_sentences):
        seq_input = []
        seq_input_ids = []
        seq_attention_mask = []
        seq_name_mask = []

        prev_name = name_sequence[0]
        for j, (sentence_group, cur_name) in enumerate(
            zip(sentence_groups, name_sequence[1:])
        ):
            # A new template is drawn for every sentence of every sequence.
            sentence = rng.choice(sentence_group)
            if j > 0:
                # Separate consecutive sentences by a single space.
                sentence = f" {sentence}"
            # NOTE(review): _insert_names prepends a BOS token to every
            # sentence when the tokenizer type uses one, so a sequence
            # contains one BOS per sentence — confirm this is intended.
            (
                sen_sentence,
                sen_input_ids,
                sen_attention_mask,
                sen_name_mask,
            ) = _insert_names(
                sentence,
                prev_name,
                cur_name,
                tokenizer,
                data_config.tokenizer_type,
            )
            seq_input.append(sen_sentence)
            seq_input_ids.append(sen_input_ids)
            seq_attention_mask.append(sen_attention_mask)
            seq_name_mask.append(sen_name_mask)
            prev_name = cur_name
        inputs.append("".join(seq_input))
        input_ids.append(torch.cat(seq_input_ids))
        attention_mask.append(torch.cat(seq_attention_mask))
        name_mask.append(torch.cat(seq_name_mask))

    max_len = max(len(ids) for ids in input_ids)
    if max_len > tokenizer.model_max_length:
        raise ValueError(
            f"Maximum length of {max_len} exceeds the maximum length of "
            f"the model ({tokenizer.model_max_length})."
        )
    # Right-pad everything to the longest sequence.
    # NOTE(review): input_ids are padded with token id 0, not with
    # tokenizer.pad_token_id; the padded positions have attention_mask 0.
    input_ids = [_pad_to_length(ids, max_len) for ids in input_ids]
    attention_mask = [_pad_to_length(m, max_len) for m in attention_mask]
    name_mask = [_pad_to_length(m, max_len) for m in name_mask]

    columns = {
        "text": inputs,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "name_mask": name_mask,
    }

    def build_split(rows: slice) -> Dataset:
        # Select the same rows from every column and return torch tensors.
        split = Dataset.from_dict(
            {column: values[rows] for column, values in columns.items()}
        )
        split.set_format("torch")
        return split

    # The first num_test_samples sequences are the test split, the
    # remaining ones (one per epoch) the training split.
    train_dataset = build_split(slice(num_test_samples, None))
    test_dataset = build_split(slice(None, num_test_samples))
    data = DatasetDict(
        {
            "train": train_dataset,
            "validation": None,
            "test": test_dataset,
        }
    )
    return RandomFactsData(
        names=name_alphabet.tolist(),
        sentence_groups=sentence_indices.tolist(),
        data=data,
    )


def _pad_to_length(sequence: torch.Tensor, length: int) -> torch.Tensor:
    """Right-pad a 1-D integer tensor with zeros up to ``length``."""
    padding = torch.zeros(length - len(sequence), dtype=torch.long)
    return torch.cat([sequence, padding])


def _validate_templates(
    sentence_templates: list[list[str]],
    group_size: int = 16,
) -> None:
    """Check that the sentence templates have the expected structure.

    Args:
        sentence_templates: Groups of sentence templates.
        group_size: Number of templates every group must contain.

    Raises:
        ValueError: If there are no groups, if a group does not contain
            exactly ``group_size`` templates, or if a template lacks the
            ``<x>`` or ``<y>`` placeholder.
    """
    # An empty list would otherwise pass silently and only fail later,
    # when groups are sampled from it.
    if len(sentence_templates) == 0:
        raise ValueError(
            "At least one group of sentence templates is required."
        )
    for group in sentence_templates:
        if len(group) != group_size:
            raise ValueError(
                "Each group of sentence templates must contain "
                f"{group_size} sentences."
            )
        for sentence in group:
            if "<x>" not in sentence or "<y>" not in sentence:
                raise ValueError(
                    "Each sentence template must contain placeholders "
                    "<x> and <y>."
                )


def _insert_names(
    sentence_template: str,
    name_1: str,
    name_2: str,
    tokenizer: PreTrainedTokenizer,
    tokenizer_type: str,
) -> tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fill a template with two names and tokenize it piece by piece.

    The template is split into plain-text pieces and placeholder pieces.
    ``<x>`` is replaced by ``name_1`` and ``<y>`` by ``name_2``. Every piece
    is tokenized on its own, so that the tokens belonging to a name can be
    marked in a separate mask.

    Returns:
        The filled-in sentence, its token ids, its attention mask and a
        mask that is 1 on the tokens of name pieces and 0 elsewhere.
    """
    substitutions = {"<x>": name_1, "<y>": name_2}
    has_bos = uses_bos_token(tokenizer_type)
    # Each piece is tokenized separately; if the tokenizer prepends a BOS
    # token, strip it from every piece.
    skip = 1 if has_bos else 0

    texts = []
    id_chunks = []
    attention_chunks = []
    name_chunks = []
    for piece in _split_on_placeholders(sentence_template):
        is_name = piece.endswith(("<x>", "<y>"))
        if is_name:
            piece = _replace_substr(piece, substitutions)

        encoded = tokenizer(
            [piece],
            return_tensors="pt",
        )
        piece_ids = encoded.input_ids[0][skip:]
        texts.append(piece)
        id_chunks.append(piece_ids)
        attention_chunks.append(encoded.attention_mask[0][skip:])
        make_mask = torch.ones_like if is_name else torch.zeros_like
        name_chunks.append(make_mask(piece_ids))

    if has_bos:
        # Re-add a single BOS token at the start of the sentence; it is
        # attended to but does not count as part of a name.
        id_chunks.insert(0, torch.tensor([tokenizer.bos_token_id]))
        attention_chunks.insert(0, torch.tensor([1]))
        name_chunks.insert(0, torch.tensor([0]))

    return (
        "".join(texts),
        torch.cat(id_chunks),
        torch.cat(attention_chunks),
        torch.cat(name_chunks),
    )


def _split_on_placeholders(
    sentence: str,
) -> list[str]:
    """Split a sentence template into text and placeholder pieces.

    Only the first occurrence of ``<x>`` and of ``<y>`` is split off. When
    a placeholder occurs with a preceding space, that space is kept in the
    placeholder piece rather than in the text before it. A placeholder
    that does not occur in the sentence is ignored.

    Args:
        sentence: Template that may contain ``<x>`` and ``<y>``.

    Returns:
        The non-empty pieces in their original order; joining them gives
        back ``sentence``.
    """
    # (start, length) of every placeholder that is present
    spans = []
    for placeholder in ("<x>", "<y>"):
        if f" {placeholder}" in sentence:
            placeholder = f" {placeholder}"
        start = sentence.find(placeholder)
        if start < 0:
            # find() returns -1 for a missing placeholder; using that as an
            # index would slice the sentence at the wrong places.
            continue
        spans.append((start, len(placeholder)))
    spans.sort()

    pieces = []
    cursor = 0
    for start, length in spans:
        if start > cursor:
            pieces.append(sentence[cursor:start])
        pieces.append(sentence[start : start + length])
        cursor = start + length
    if cursor < len(sentence):
        pieces.append(sentence[cursor:])
    return pieces


def _replace_substr(
    sentence: str,
    mapping: dict[str, str],
) -> str:
    """Replace every key of ``mapping`` in ``sentence`` by its value.

    The replacements are applied one after another in the iteration order
    of ``mapping``, so a later key also matches text that was inserted by
    an earlier replacement.
    """
    result = sentence
    for placeholder in mapping:
        result = result.replace(placeholder, mapping[placeholder])
    return result


def train_model(
    experiment_id: ExperimentID,
    training_config: TrainingConfig,
    model_info: tuple[str, PreTrainedModel],
    tokenizer: PreTrainedTokenizer,
    data_name: str,
    data: DatasetDict,
) -> MemorizationTrainingResult:
    """Train the model on the dataset and record memorization metrics.

    The memorization metrics are computed on the test split by a callback
    that is handed to the training run. The configured training arguments
    are overridden so that evaluation happens after every step and the
    whole test split fits into one evaluation batch.
    """
    test_split = data["test"]
    memorization_metrics = metrics(test_split)

    # The dataset already holds one training sequence per requested epoch,
    # so a single pass over it is enough.
    overridden_args = replace(
        training_config.args,
        num_train_epochs=1,
        per_device_eval_batch_size=len(test_split),
        eval_steps=1,
        evaluation_strategy="steps",
    )
    training_config = replace(training_config, args=overridden_args)

    training_res = train(
        experiment_id,
        model_info,
        (data_name, data),
        config=training_config,
        tokenizer=tokenizer,
        callbacks=[memorization_metrics],
        data_already_preprocessed=True,
    )
    training_log = training_res.training_log
    assert training_log is not None

    return MemorizationTrainingResult(
        data=data,
        model=training_res.model,
        tokenizer=tokenizer,
        training_log=training_log,
        memorization_log=memorization_metrics.result(),
    )


def metrics(
    test_dataset: Dataset,
) -> SequenceEvaluationTask:
    """Build the evaluation task that tracks memorization of the test data.

    Loss and correctness are each reported twice: aggregated over all
    tokens ("loss", "correct") and aggregated over the name tokens only
    ("name_loss", "name_correct").
    """
    encoding = BatchEncoding(
        {
            "input_ids": test_dataset["input_ids"],
            "attention_mask": test_dataset["attention_mask"],
        }
    )
    names_only = test_dataset["name_mask"]

    # Every entry wraps its own metric instance.
    metric_by_name = {
        "loss": MaskedAggregateMetric(LossMetric(encoding)),
        "correct": MaskedAggregateMetric(CorrectnessMetric(encoding)),
        "name_loss": MaskedAggregateMetric(
            LossMetric(encoding),
            mask=names_only,
        ),
        "name_correct": MaskedAggregateMetric(
            CorrectnessMetric(encoding),
            mask=names_only,
        ),
    }
    return SequenceEvaluationTask(
        metric_by_name,
        data=test_dataset,
        index_names=["string"],
    )


class MaskedAggregateMetric(SequenceMetric):
    """Reduce the values of a wrapped token metric to one per sequence.

    Without a mask, the wrapped metric's result is averaged over dim 1.
    With a mask, only the positions where the mask is non-zero contribute:
    the masked values are summed and divided by the number of masked
    positions of that sequence.
    """

    def __init__(
        self,
        wrapped_metric: TokenMetric,
        mask: torch.Tensor | None = None,
    ) -> None:
        """Wrap ``wrapped_metric``.

        Args:
            wrapped_metric: Metric whose result is aggregated.
            mask: Optional 2-D mask selecting the positions to aggregate
                over. Its first column is dropped (see below).
        """
        super().__init__(wrapped_metric.required_args)
        if mask is not None:
            # Drop the first position of every sequence so that the mask
            # has the same shape as the wrapped metric's result.
            # NOTE(review): presumably the wrapped metrics report one value
            # per predicted token, i.e. none for the first token — confirm
            # against the inference code.
            self.mask = mask[:, 1:]
            self.masked_count = torch.sum(self.mask, dim=1)
        else:
            self.mask = None
            self.masked_count = None
        self.wrapped_metric = wrapped_metric

    def update(self, *args, **kwargs) -> None:
        """Forward the update to the wrapped metric."""
        self.wrapped_metric.update(*args, **kwargs)

    def compute(self) -> torch.Tensor:
        """Return the aggregated metric and reset the wrapped metric."""
        wrapped_result = self.wrapped_metric.compute().to(torch.float32)
        self.wrapped_metric.reset()
        if self.mask is None:
            return torch.mean(wrapped_result, dim=1)
        # The mask is a plain attribute and does not follow the metric to
        # another device, so align it with the result explicitly (a no-op
        # when both are already on the same device).
        mask = self.mask.to(wrapped_result.device)
        masked_count = self.masked_count.to(wrapped_result.device)
        masked_value = torch.sum(wrapped_result * mask, dim=1)
        # NOTE: a sequence without any masked position yields NaN (0 / 0).
        return masked_value / masked_count
