import abc
from typing import List, Any, Optional

import transformers
import torch

from patching_gemma import logger

class Llama3Model:
    """Wrapper that loads a Hugging Face causal LM and its tokenizer by short alias.

    Subclasses provide ``run``, ``break_into`` and ``break_out``; ``create_model``
    calls ``break_into`` once the model and tokenizer are loaded.

    NOTE(review): the ``@abc.abstractmethod`` markers below are NOT enforced,
    because this class does not derive from ``abc.ABC`` (or use ``abc.ABCMeta``).
    The base class can still be instantiated, and its default implementations are
    no-ops. Switching to ``abc.ABC`` would make direct instantiation raise
    ``TypeError``; confirm no caller relies on that before changing it.
    """

    AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM

    # Short aliases accepted by ``__init__``, mapped to Hugging Face Hub repo ids.
    MODEL_NAMES = {
        "llama3": "meta-llama/Llama-3.2-3B",
        "smollm": "HuggingFaceTB/SmolLM2-1.7B",
    }

    def __init__(self, model_name) -> None:
        """Resolve ``model_name`` (e.g. ``"llama3"`` or ``"smollm"``) to a Hub repo id.

        Nothing is downloaded here; call ``create_model`` to load weights.

        Raises:
            KeyError: if ``model_name`` is not a key of ``MODEL_NAMES``.
        """
        self.suitable_batch_size = None
        try:
            self.model_name = self.MODEL_NAMES[model_name]
        except KeyError:
            raise KeyError(
                f"Unknown model_name {model_name!r}; "
                f"expected one of {sorted(self.MODEL_NAMES)}"
            ) from None

    def create_model(self) -> 'Llama3Model':
        """Load the model onto CUDA, load its tokenizer, then call ``break_into``.

        Returns:
            ``self``, so construction and loading can be chained.
        """
        pretrained = self.model_name
        hf_kwargs = {
            "offload_folder": "./offload",
            # Eager (non-fused) attention; presumably chosen so attention
            # internals can be inspected/patched — confirm with break_into().
            "attn_implementation": "eager",
        }

        # NOTE: ``use_auth_token`` is deprecated in recent transformers releases
        # in favour of ``token``; kept here for compatibility with older versions.
        self.model = self.AUTO_MODEL_CLASS.from_pretrained(
            pretrained,
            use_auth_token=True,
            **hf_kwargs,
        ).to("cuda")

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
        self._ensure_pad_token()

        self.model_logs = {}
        self.break_into()
        return self

    def _ensure_pad_token(self) -> None:
        """Make sure the tokenizer has a pad token so padded batches can be encoded.

        Reuses the EOS token when one exists. Otherwise it adds a new ``[PAD]``
        special token and resizes the model's token embeddings to match.
        """
        if self.tokenizer.pad_token is not None:
            return
        if self.tokenizer.eos_token is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            return
        num_added = self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        if num_added:
            # The new token id is past the end of the existing embedding matrix;
            # without a resize, feeding it through the model indexes out of range.
            self.model.resize_token_embeddings(len(self.tokenizer))

    def _log_memory_usage(self) -> None:
        """Log total/reserved/allocated/free memory (in GiB) for each visible CUDA device.

        "free" here is ``total - reserved`` for this process's caching allocator;
        it does not account for memory used by other processes.
        """
        gib = 1024 ** 3  # torch reports these quantities in bytes
        for d in range(torch.cuda.device_count()):
            total = torch.cuda.get_device_properties(d).total_memory
            reserved = torch.cuda.memory_reserved(d)
            allocated = torch.cuda.memory_allocated(d)
            logger.debug(
                f"Device {d}, total_memory: {total / gib:.4}GiB, "
                f"reserved: {reserved / gib:.4}GiB, "
                f"allocated: {allocated / gib:.4}GiB, "
                f"free: {(total - reserved) / gib:.4}GiB"
            )

    @abc.abstractmethod
    def run(self, *args, **kwargs) -> None:
        """Run the subclass-specific experiment/inference (no-op in the base class)."""
        pass

    @abc.abstractmethod
    def break_into(self, **additional_kwargs) -> None:
        """Subclass hook called at the end of ``create_model`` (no-op in the base class)."""
        pass

    @abc.abstractmethod
    def break_out(self) -> None:
        """Subclass hook counterpart to ``break_into`` (no-op in the base class)."""
        pass