import logging
from dataclasses import dataclass, replace
from typing import Iterable

import numpy as np
import pandas as pd
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer, TrainerCallback

from data.synthetic_strings.random import (
    RandomStringConfig,
    RandomStringData,
    generate_random_strings,
)
from defs import SEED_OFFSET, LLMExperimentConfig
from lib_llm.eval.memorization.dynamics import memorization_dynamics_metrics
from lib_llm.eval.memorization.prefix_mappings import (
    PrefixEvalConfig,
    PrefixEvalTask,
)
from lib_llm.models import ModelConfig, load_model_tokenizer
from lib_llm.models.utils import get_tokenizer_type
from lib_llm.training import TrainingConfig, train
from lib_project.experiment import ExperimentID, iterative_experiment
from utils.prefix_mappings import setup_replacements


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Whether torch reports a usable CUDA device at import time.
HAS_CUDA = torch.cuda.is_available()

# Name handed to the ``iterative_experiment`` decorator below.
EXP_NAME = "repeated_training"
# Short form of the experiment name.
EXP_ABBREVIATION = "rt"


@dataclass
class ExperimentConfig(LLMExperimentConfig):
    """Configuration of the repeated-training experiment.

    Extends ``LLMExperimentConfig`` with the fields below.
    """

    # NOTE(review): not read by the functions in this module; presumably the
    # experiment-level random seed -- confirm where it is consumed.
    seed: int
    # One random-string dataset config per training iteration, in the order
    # in which the datasets are trained on.
    data: list[RandomStringConfig]
    # Passed to ``load_model_tokenizer`` to obtain the model and tokenizer.
    model: ModelConfig
    # Passed to ``train`` for every iteration.
    training: TrainingConfig
    # Iteration indices for which ``save_final_checkpoint=True`` is set on the
    # training config; ``None`` leaves the training config untouched.
    save_after_iterations: list[int] | None = None
    # Prefix evaluation settings. The prefix evaluation is only set up on the
    # first and last iteration, and only if ``prefix_eval_epochs`` is also set.
    prefix_eval: PrefixEvalConfig | None = None
    # Epochs (compared against ``int(state.epoch)``) at which the prefix
    # evaluation runs.
    prefix_eval_epochs: list[int] | None = None


@dataclass
class IterationResult:
    """Outputs collected from a single training iteration."""

    # Training log returned by ``train`` for this iteration.
    training_log: pd.DataFrame
    # Result of the memorization dynamics callback for this iteration.
    memorization_log: pd.DataFrame
    # Result of the prefix evaluation task; stays ``None`` for iterations on
    # which no prefix evaluation was set up.
    prefix_mappings: pd.DataFrame | None = None


@dataclass
class ExperimentResult:
    """Experiment state reported after each training iteration."""

    # Evaluation data accumulated over all datasets trained on so far.
    combined_data: RandomStringData
    # Results of the iterations completed so far, in training order.
    iteration_results: list[IterationResult]


@iterative_experiment(EXP_NAME)
def rt_experiment(
    config: ExperimentConfig,
    experiment_id: ExperimentID,
) -> Iterable[ExperimentResult]:
    """Train one model on a sequence of datasets, reporting after each one.

    The model and tokenizer are loaded once from ``config.model``; the same
    model object is then handed to ``training_iteration`` for every dataset
    produced by ``data_iter``.

    Args:
        config: Experiment configuration; ``config.data`` determines the
            number and content of the training iterations.
        experiment_id: Identifier forwarded to the training routine.

    Yields:
        After every iteration, an ``ExperimentResult`` holding the evaluation
        data accumulated so far and the results of all iterations completed
        so far.
    """
    model, tokenizer = load_model_tokenizer(config.model)

    iteration_results: list[IterationResult] = []
    for i, (training_data, eval_data) in enumerate(
        data_iter(config.data, tokenizer)
    ):
        print("Training on data", i)
        res = training_iteration(
            config,
            experiment_id,
            model,
            tokenizer,
            training_data,
            eval_data,
            i,
        )
        iteration_results.append(res)
        # Hand out a snapshot rather than the accumulator itself: the
        # accumulator keeps growing, and a shared reference would silently
        # change results that were already yielded.
        yield ExperimentResult(
            combined_data=eval_data,
            iteration_results=list(iteration_results),
        )


def data_iter(
    data_configs: list[RandomStringConfig],
    tokenizer: PreTrainedTokenizer,
) -> Iterable[tuple[RandomStringData, RandomStringData]]:
    """Yield each training dataset together with the cumulative eval data.

    All datasets are generated up front, before the first pair is yielded.
    The evaluation data starts as a copy of the first dataset with its raw
    sequences cleared and its alphabet replaced by the union of all
    alphabets. It then grows by one dataset per step, with every sequence
    right-padded to the longest sequence length found in any dataset.

    Args:
        data_configs: One config per dataset to generate.
        tokenizer: Tokenizer used for generation; must define a pad token.

    Yields:
        ``(training_data, eval_data)`` where ``eval_data`` covers all datasets
        up to and including ``training_data``.
    """
    datasets = [
        generate_random_strings(cfg, tokenizer) for cfg in data_configs
    ]

    # Union of the alphabets. A token that occurs in several datasets keeps
    # the position of its first occurrence and the id of its last one.
    token_to_id = {}
    target_len = 0
    for dataset in datasets:
        token_to_id.update(
            zip(dataset.alphabet_tokens, dataset.alphabet_token_ids)
        )
        width = dataset.raw_token_ids.shape[1]
        if width > target_len:
            target_len = width

    eval_data = replace(
        datasets[0],
        raw_token_ids=None,
        raw_tokens=[],
        raw_attention_mask=None,
        alphabet_token_ids=np.array(list(token_to_id.values())),
        alphabet_tokens=list(token_to_id),
    )
    pad_id = tokenizer.pad_token_id
    assert pad_id is not None

    for dataset in datasets:
        new_ids = _pad_to_length(dataset.raw_token_ids, target_len, pad_id)
        # Mask is one over the dataset's own columns and zero over padding.
        # NOTE(review): this treats every original position as a real token
        # and does not consult the dataset's own attention mask -- confirm
        # that sequences within a dataset are never padded.
        new_mask = _pad_to_length(
            np.ones_like(dataset.raw_token_ids), target_len, 0
        )

        if eval_data.raw_token_ids is None:
            # First dataset: nothing accumulated yet.
            merged_ids, merged_mask = new_ids, new_mask
        else:
            merged_ids = np.concatenate(
                (
                    _pad_to_length(
                        eval_data.raw_token_ids, target_len, pad_id
                    ),
                    new_ids,
                ),
                axis=0,
            )
            assert eval_data.raw_attention_mask is not None
            merged_mask = np.concatenate(
                (eval_data.raw_attention_mask, new_mask),
                axis=0,
            )

        eval_data = replace(
            eval_data,
            raw_token_ids=merged_ids,
            raw_tokens=eval_data.raw_tokens + dataset.raw_tokens,
            raw_attention_mask=merged_mask,
        )
        yield dataset, eval_data


def _pad_to_length(
    sequence: np.ndarray,
    length: int,
    padding_token_id: int,
) -> np.ndarray:
    assert sequence.shape[1] <= length
    padded = np.full((sequence.shape[0], length), padding_token_id)
    padded[:, : sequence.shape[1]] = sequence
    return padded


def training_iteration(
    config: ExperimentConfig,
    experiment_id: ExperimentID,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    training_data: RandomStringData,
    eval_data: RandomStringData,
    iteration: int,
) -> IterationResult:
    """Train the model on one dataset and collect the evaluation outputs.

    A memorization dynamics callback over the ``"test"`` split of
    ``eval_data`` is always attached. A prefix evaluation callback over
    ``training_data`` is added only on the first and last iteration, and only
    when both ``config.prefix_eval`` and ``config.prefix_eval_epochs`` are
    set.

    Args:
        config: Experiment configuration.
        experiment_id: Identifier forwarded to ``train``.
        model: Model to train.
        tokenizer: Tokenizer belonging to ``model``.
        training_data: Dataset to train on in this iteration.
        eval_data: Data the memorization callback is evaluated on.
        iteration: Zero-based index of this iteration.

    Returns:
        The training log, the memorization log and, if the prefix evaluation
        ran, its result.
    """
    train_split = training_data.dataset()
    eval_split = eval_data.dataset()

    memorization_task = memorization_dynamics_metrics(
        eval_data.alphabet_tokens,
        eval_data.alphabet_token_ids,
        eval_split["test"],
    )
    callbacks: list[TrainerCallback] = [memorization_task]

    on_first_or_last = iteration in (0, len(config.data) - 1)
    prefix_eval_task = None
    if (
        config.prefix_eval is not None
        and config.prefix_eval_epochs is not None
        and on_first_or_last
    ):
        # Offset the seed per iteration so that the two prefix evaluations
        # do not share a seed.
        prefix_eval_config = replace(
            config.prefix_eval,
            seed=config.prefix_eval.seed + iteration * SEED_OFFSET,
        )
        tokenizer_type = get_tokenizer_type(config.model.model_id_not_none)
        get_replacements = setup_replacements(
            prefix_eval_config,
            replacement_strategy="rand_id",
            tokenizer=tokenizer,
            tokenizer_type=tokenizer_type,
            replacement_length=len(training_data.token_ids[0]),
        )

        def _is_prefix_eval_epoch(state) -> bool:
            # Evaluate only while the integer part of the current epoch is
            # one of the configured epochs.
            if state.epoch is None:
                return False
            return int(state.epoch) in config.prefix_eval_epochs

        prefix_eval_task = PrefixEvalTask(
            prefix_eval_config,
            data=(training_data.tokens, training_data.batch_encoding()),
            tokenizer=tokenizer,
            get_replacements=get_replacements,
            eval_condition=_is_prefix_eval_epoch,
        )
        callbacks.append(prefix_eval_task)

    training_config = config.training
    if (
        config.save_after_iterations is not None
        and iteration in config.save_after_iterations
    ):
        training_config = replace(
            config.training,
            save_final_checkpoint=True,
        )

    training_res = train(
        experiment_id,
        (config.model.model_id_not_none, model),
        (f"data_{iteration}", train_split),
        config=training_config,
        tokenizer=tokenizer,
        callbacks=callbacks,
        data_already_preprocessed=True,
    )
    # NOTE(review): this only rebinds the local name and is not returned, so
    # the caller keeps using the object it passed in -- confirm that train()
    # updates that object in place.
    model = training_res.model

    memorization_log = memorization_task.result()
    training_log = training_res.training_log
    assert training_log is not None

    prefix_mappings = None
    if prefix_eval_task is not None:
        prefix_mappings = prefix_eval_task.result()
    return IterationResult(
        training_log=training_log,
        memorization_log=memorization_log,
        prefix_mappings=prefix_mappings,
    )
