import logging
from dataclasses import dataclass
from typing import Callable

import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, GPTNeoXForCausalLM, PreTrainedTokenizer

from data.random_strings import (
    RandomStringConfig,
    RandomStringData,
    generate_random_strings,
)
from defs import LLMExperimentConfig
from lib_llm.eval.memorization.dynamics import memorization_dynamics_metrics
from lib_project.experiment import ExperimentID, experiment


# Experiment identifiers. EXP_NAME is the name passed to ``@experiment``;
# EXP_ABBREVIATION is not referenced in this module (presumably used by
# other modules — verify before changing it).
EXP_NAME: str = "string_rules"
EXP_ABBREVIATION: str = "st"

# Module logger and a CUDA availability flag, both evaluated once at import
# time. HAS_CUDA is not read in this module.
logger: logging.Logger = logging.getLogger(__name__)
HAS_CUDA: bool = torch.cuda.is_available()


@dataclass
class ExperimentConfig(LLMExperimentConfig):
    """Configuration for the ``string_rules`` experiment (``st_experiment``).

    Extends ``LLMExperimentConfig`` with the fields below.
    """

    # Seed value. Not read by ``st_experiment`` in this module; presumably
    # consumed by the experiment framework or by callers — verify.
    seed: int
    # Random-string dataset settings, passed to ``generate_random_strings``
    # by ``st_experiment``.
    data: RandomStringConfig
    # NOTE(review): a ``model: ModelConfig`` field is disabled below; the
    # model is currently fixed by the module-level ``MODEL_ID`` constant.
    # model: ModelConfig


@dataclass
class ExperimentResult:
    """Output of ``st_experiment``."""

    # Dataset produced by ``generate_random_strings`` for this run.
    data: RandomStringData
    # Per-checkpoint results of the memorization-dynamics evaluation,
    # concatenated along axis 0 with an outer index level named "step"
    # holding the Pythia training step of each evaluated checkpoint.
    in_context_performance: pd.DataFrame


# Spacing of the regular checkpoint grid, in thousands of training steps.
STEP_SIZE = 20
# Pythia training steps whose checkpoints are evaluated, in chronological
# order: every STEP_SIZE * 1000 steps starting at 0, plus step 10000 and
# step 143000 (the end of the range). A set removes duplicates, because for
# some STEP_SIZE values the extra steps fall on the regular grid. Duplicates
# would produce repeated "step" keys when the results are concatenated.
PYTHIA_STEPS = sorted({10000, 143000, *range(0, 143000, STEP_SIZE * 1000)})
# Hugging Face Hub model whose checkpoints (revisions "step<N>") are loaded.
MODEL_ID = "EleutherAI/pythia-1b-deduped"


@experiment(EXP_NAME)
def st_experiment(
    config: ExperimentConfig,
    description: ExperimentID,
) -> ExperimentResult:
    """Evaluate random-string performance across Pythia training checkpoints.

    Generates a random-string dataset from ``config.data`` using the
    ``MODEL_ID`` tokenizer and builds a memorization-dynamics evaluation task
    on the ``"test"`` entry of the generated dataset. It then loads and
    evaluates the ``MODEL_ID`` checkpoint revision ``step<N>`` for every N in
    ``PYTHIA_STEPS``.

    Args:
        config: Experiment configuration; ``config.data`` controls dataset
            generation.
        description: Experiment identifier. Not read in this function;
            presumably required by the ``@experiment`` calling convention
            (verify).

    Returns:
        The generated dataset and the per-checkpoint evaluation results,
        concatenated along axis 0 with an outer index level named ``"step"``.
    """
    # NOTE(review): the tokenizer is loaded once from revision "step0" and
    # reused for all checkpoints. This assumes it is identical across
    # revisions — confirm.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        revision="step0",
    )
    data = generate_random_strings(config.data, tokenizer)
    eval_task = memorization_dynamics_metrics(
        data.alphabet_tokens,
        data.alphabet_token_ids,
        data.dataset()["test"],
    )

    results = []
    for step in PYTHIA_STEPS:
        model = GPTNeoXForCausalLM.from_pretrained(
            MODEL_ID,
            revision=f"step{step}",
        )
        results.append(eval_task.evaluate(model))
        # Release this checkpoint before the next from_pretrained call.
        # Otherwise `model` keeps it alive while the next one loads, doubling
        # peak memory (unless eval_task itself retains a reference).
        del model
        # Return cached CUDA blocks in case evaluation placed the model on a
        # GPU; this is a no-op when CUDA was never initialized.
        torch.cuda.empty_cache()

    in_context_performance = pd.concat(
        results,
        axis=0,
        keys=PYTHIA_STEPS,
        names=["step"],
    )
    return ExperimentResult(
        data=data,
        in_context_performance=in_context_performance,
    )
