import logging
import os
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path

import pandas as pd
from transformers import PreTrainedModel, PreTrainedTokenizer, TrainerCallback

from data.synthetic_strings import load_data
from data.synthetic_strings.conditional_random import (
    ConditionalRandomStringConfig,
    create_ngram_conditional_dataset,
)
from data.synthetic_strings.deterministic_rules import (
    DeterministicRuleStringConfig,
    DeterministicRuleStringData,
)
from data.synthetic_strings.random import RandomStringConfig, RandomStringData
from defs import EXPERIMENT_SEEDS
from lib_dl_base.defs.task_id import TaskID
from lib_llm.eval.memorization.dynamics import memorization_dynamics_metrics
from lib_llm.eval.metrics import TokenEvaluationTask
from lib_llm.models import ModelConfig, load_model_tokenizer
from lib_llm.training import TrainingArguments, TrainingConfig, train


logger = logging.getLogger(__name__)
FILE_DIR = Path(__file__).parent
ROOT_PATH = FILE_DIR.parent.parent.parent

# DeepSpeed config selection: setting the environment variable CPU_OFFLOAD=1
# selects the config file with CPU offloading; any other value (or leaving it
# unset) selects the default config. Both paths resolve relative to the
# directory containing this module.
USE_CPU_OFFLOAD = os.environ.get("CPU_OFFLOAD", "0").lower() == "1"
if USE_CPU_OFFLOAD:
    logger.info("Using deepspeed offloading for finetuning")
MEMORIZATION_DEEPSPEED_CONFIG = str(
    FILE_DIR
    / (
        "ds_memorization_offload_config.json"
        if USE_CPU_OFFLOAD
        else "ds_memorization_config.json"
    )
)


@dataclass
class FreezeConfig:
    """Choose one group of model parameters to keep trainable; freeze the rest.

    Exactly one selector is meant to be set. If several are set, the first in
    field order wins (``bias_only`` first, ``layer_x_only`` last). Both
    ``id_postfix`` and ``freeze_model`` resolve the selector through the same
    table (``_selection``), so they always agree.

    Parameters are selected by substring match on the names yielded by
    ``model.named_parameters()``.
    NOTE(review): the substrings (``embed_in``, ``embed_out``, ``layers.{x}.``)
    look like GPT-NeoX/Pythia parameter naming; confirm before using this with
    other architectures.
    """

    # Keep parameters whose name contains "bias" trainable.
    bias_only: bool = False
    # Keep parameters whose name contains ".attention." trainable.
    attention_only: bool = False
    # Keep parameters whose name contains ".mlp." trainable.
    mlp_only: bool = False
    # Keep parameters whose name contains "layernorm" trainable.
    layernorm_only: bool = False
    # Keep parameters whose name contains "embed_in" trainable.
    embed_in_only: bool = False
    # Keep parameters whose name contains "embed_out" trainable.
    embed_out_only: bool = False
    # Keep parameters of layer `layer_x_only` (names containing "layers.{x}.")
    # trainable; must be a non-empty string when set.
    layer_x_only: str | None = None

    def _selection(self) -> tuple[str, str]:
        """Return ``(id_postfix, parameter_name_substring)`` for the active selector.

        Raises:
            ValueError: If no selector is set, or ``layer_x_only`` is ``""``.
        """
        # Ordered by precedence: the first enabled flag wins.
        flag_selectors = (
            (self.bias_only, "bias_only", "bias"),
            (self.attention_only, "attention_only", ".attention."),
            (self.mlp_only, "mlp_only", ".mlp."),
            (self.layernorm_only, "layernorm_only", "layernorm"),
            (self.embed_in_only, "embed_in_only", "embed_in"),
            (self.embed_out_only, "embed_out_only", "embed_out"),
        )
        for enabled, label, substring in flag_selectors:
            if enabled:
                return f"_{label}", substring
        if self.layer_x_only is not None:
            if self.layer_x_only == "":
                raise ValueError("layer_x_only must not be empty")
            # The trailing "." keeps e.g. layer "1" from matching "layers.10.".
            return (
                f"_layer_{self.layer_x_only}_only",
                f"layers.{self.layer_x_only}.",
            )
        raise ValueError("No freeze config specified")

    @property
    def id_postfix(self) -> str:
        """Suffix appended to the model name to identify this freeze setup.

        Raises:
            ValueError: If no selector is set, or ``layer_x_only`` is ``""``.
        """
        return self._selection()[0]

    def freeze_model(self, model: PreTrainedModel) -> PreTrainedModel:
        """Set ``requires_grad`` so that only the selected parameters are trained.

        The model is modified in place and also returned for convenience.

        Raises:
            ValueError: If no selector is set, ``layer_x_only`` is ``""``, or no
                parameter name matches the selector (training would then have
                nothing to update). In the last case the model is left untouched.
        """
        # Resolve the selector once instead of once per parameter.
        keep_substring = self._get_non_freeze_id()
        named_params = list(model.named_parameters())
        trained_params = [
            name for name, _ in named_params if keep_substring in name
        ]
        # Fail before mutating anything: a fully frozen model would only error
        # later, and less clearly, inside the training loop.
        if not trained_params:
            raise ValueError(
                f"No model parameter name contains {keep_substring!r}; "
                "freezing would leave nothing to train"
            )
        for name, param in named_params:
            param.requires_grad = keep_substring in name
        # Log all parameters that are being trained.
        logger.info(f"Training parameters: {trained_params}")
        return model

    def _get_non_freeze_id(self) -> str:
        """Return the name substring identifying parameters to keep trainable.

        Raises:
            ValueError: If no selector is set, or ``layer_x_only`` is ``""``.
        """
        return self._selection()[1]


@dataclass
class MemorizationConfig:
    """Configuration for a memorization finetuning run."""

    # Training setup; passed to `train(...)` as `config=` by
    # get_memorized_model_tokenizer.
    training: TrainingConfig
    # NOTE(review): not read by any code visible in this file; confirm where
    # (or whether) it is consumed.
    is_untrained: bool = False
    # Optional parameter freezing. When set, only the selected parameters stay
    # trainable and the model name gets the freeze config's id postfix.
    freeze: FreezeConfig | None = None


@dataclass
class MemorizationTrainingResult:
    """Outputs of a memorization finetuning run (get_memorized_model_tokenizer)."""

    # Model returned by `train(...)`; its `name_or_path` is set to the run's
    # model name (including any freeze postfix).
    model: PreTrainedModel
    tokenizer: PreTrainedTokenizer
    # Synthetic string data whose dataset was used for training.
    # NOTE(review): the object may also come from
    # create_ngram_conditional_dataset, whose type is not listed in this
    # annotation; confirm the annotation is complete.
    data: RandomStringData | DeterministicRuleStringData
    # Training log returned by `train(...)`.
    training_log: pd.DataFrame
    # Output of the evaluation callback's `result()`.
    memorization_log: pd.DataFrame


# Per-model learning rates, tested in the MHR experiment with a linear
# learning-rate schedule. Adjacent entries sharing a rate are grouped with
# dict.fromkeys.
LEARNING_RATE_OVERRIDES = {
    "pythia-70m": 5e-5,
    **dict.fromkeys(("gpt2-124m", "gpt2"), 5e-4),
    **dict.fromkeys(("gpt2-1.5b", "gpt2-xl"), 1e-4),
    "phi-1": 1e-4,
    **dict.fromkeys(("phi-1_5", "phi-2"), 5e-5),
    "opt-350m": 1e-4,
}
# Fallback for models without an override; works well for Pythia 1B and
# larger models.
DEFAULT_LEARNING_RATE = 1e-5


def get_memorization_training_config(
    seed_id: int,
    model_id: str,
    num_epochs: int,
    wandb_project_name: str,
    batch_size: int = 1,
    save_final_checkpoint: bool = False,
    is_pretrained: bool = True,
    use_wandb: bool = True,
) -> TrainingConfig:
    """Build the TrainingConfig used for memorization finetuning runs.

    Args:
        seed_id: Index into ``EXPERIMENT_SEEDS`` selecting the run seed.
        model_id: Model identifier, used to look up a learning rate in
            ``LEARNING_RATE_OVERRIDES`` (falls back to ``DEFAULT_LEARNING_RATE``).
        num_epochs: Number of training epochs.
        wandb_project_name: W&B project name stored on the config.
        batch_size: Per-device batch size for both training and evaluation.
        save_final_checkpoint: Forwarded to ``TrainingConfig``.
        is_pretrained: When False, the learning rate is multiplied by 10.
        use_wandb: Report to W&B when True; otherwise reporting is disabled.

    Returns:
        A TrainingConfig using AdamW with a linear LR schedule, evaluation once
        per epoch, no intermediate checkpoints, and the module-level DeepSpeed
        config.
    """
    base_lr = LEARNING_RATE_OVERRIDES.get(model_id, DEFAULT_LEARNING_RATE)
    # Models that are not pretrained get a 10x larger learning rate.
    learning_rate = base_lr if is_pretrained else base_lr * 10
    reporters = ["wandb"] if use_wandb else ["none"]

    training_args = TrainingArguments(
        learning_rate=learning_rate,
        optim="adamw_torch",
        lr_scheduler_type="linear",
        num_train_epochs=num_epochs,
        deepspeed=MEMORIZATION_DEEPSPEED_CONFIG,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        evaluation_strategy="epoch",
        eval_steps=1,
        logging_steps=10,
        save_strategy="no",
        report_to=reporters,
    )
    return TrainingConfig(
        seed=EXPERIMENT_SEEDS[seed_id],
        train=True,
        args=training_args,
        save_final_checkpoint=save_final_checkpoint,
        wandb_project_name=wandb_project_name,
    )


def get_memorized_model_tokenizer(
    config: MemorizationConfig,
    data_config: (
        RandomStringConfig
        | DeterministicRuleStringConfig
        | ConditionalRandomStringConfig
    ),
    model_config: ModelConfig,
    task_id: TaskID,
    override_eval_task: TokenEvaluationTask | None = None,
    additional_callbacks: list[TrainerCallback] | None = None,
) -> MemorizationTrainingResult:
    """Load a model, build synthetic string data, and finetune on it.

    Steps: load the model/tokenizer from a copy of ``model_config``; build the
    dataset from ``data_config``; set up the evaluation callback; optionally
    freeze parameters per ``config.freeze``; then run ``train``.

    Args:
        config: Training settings and optional freeze configuration.
        data_config: Synthetic data configuration. A
            ``ConditionalRandomStringConfig`` is built with
            ``create_ngram_conditional_dataset``; anything else with
            ``load_data``.
        model_config: Model to load. Not modified; a deep copy is used.
        task_id: Passed through to ``train``.
        override_eval_task: Evaluation callback to use instead of the default
            ``memorization_dynamics_metrics`` over ``dataset["test"]``.
        additional_callbacks: Extra trainer callbacks, run after the
            evaluation callback. ``None`` means no extra callbacks.

    Returns:
        The trained model (with ``name_or_path`` set to the run's model name),
        its tokenizer, the data object, the training log, and the evaluation
        callback's ``result()``.

    Raises:
        ValueError: If ``config.freeze`` is misconfigured or matches no
            parameters.
        RuntimeError: If ``train`` returns no training log.
    """
    # None instead of a mutable [] default, so no list is shared across calls.
    if additional_callbacks is None:
        additional_callbacks = []

    # Work on a copy so nothing downstream can mutate the caller's ModelConfig.
    model_config = deepcopy(model_config)
    model, tokenizer = load_model_tokenizer(model_config)

    if isinstance(data_config, ConditionalRandomStringConfig):
        data = create_ngram_conditional_dataset(data_config, tokenizer)
    else:
        data = load_data(data_config, tokenizer)
    dataset = data.dataset()
    logger.info(f"Generated {len(dataset['test'])} sequences")

    if override_eval_task is not None:
        eval_task = override_eval_task
    else:
        eval_task = memorization_dynamics_metrics(
            data.alphabet_tokens,
            data.alphabet_token_ids,
            dataset["test"],
        )

    model_name = model_config.name
    if config.freeze is not None:
        # Tag the run name with the freeze setup, then freeze the parameters.
        model_name += config.freeze.id_postfix
        model = config.freeze.freeze_model(model)

    training_res = train(
        task_id,
        (model_name, model),
        (data_config.name, dataset),
        config=config.training,
        tokenizer=tokenizer,
        callbacks=[eval_task, *additional_callbacks],
        data_already_preprocessed=True,
        set_subdirs=[],
    )
    model = training_res.model
    model.name_or_path = model_name
    training_log = training_res.training_log
    # Explicit check rather than `assert`, so it still fires under `python -O`.
    if training_log is None:
        raise RuntimeError("train() returned no training log")
    memorization_log = eval_task.result()

    return MemorizationTrainingResult(
        model=model,
        tokenizer=tokenizer,
        data=data,
        training_log=training_log,
        memorization_log=memorization_log,
    )
