import abc
from typing import List, Any, Optional

import transformers
import torch

from patching_gemma import logger

class Gemma2Model:
    """Wrapper around the ``google/gemma-2-2b`` causal LM and its tokenizer.

    ``create_model`` loads the model and tokenizer and then calls the
    ``break_into`` hook. ``run``, ``break_into`` and ``break_out`` are
    overridable hooks whose base implementations do nothing.

    NOTE(review): the hooks are decorated with ``abc.abstractmethod`` but this
    class does not derive from ``abc.ABC``, so Python does not enforce that
    subclasses override them and the class itself stays instantiable.
    """

    AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM

    # Divisor used to report byte counts in binary gigabytes.
    _BYTES_PER_GIB = 1024 ** 3

    def __init__(self) -> None:
        # Starts unset; presumably filled in later by batch-size probing
        # code outside this class -- TODO confirm against callers.
        self.suitable_batch_size = None

    def create_model(self) -> 'Gemma2Model':
        """Load the model and tokenizer, then call ``break_into``.

        Sets ``self.model``, ``self.tokenizer`` and an empty
        ``self.model_logs`` dict. The model is moved to the ``"cuda"`` device.

        Returns:
            ``self``, so the call can be chained after construction.
        """
        pretrained = "google/gemma-2-2b"
        hf_kwargs = {
            "offload_folder": "./offload",
            # There is a bug in sdpa attention with padding
            "attn_implementation": "eager",
        }

        # ``token=True`` replaces the deprecated ``use_auth_token=True``; both
        # tell the hub client to use the locally stored login token.
        self.model = self.AUTO_MODEL_CLASS.from_pretrained(
            pretrained,
            token=True,
            **hf_kwargs,
        ).to("cuda")

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                # Reuse EOS as padding; no new vocabulary entry is created.
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                # A brand-new token needs an embedding row, otherwise its id
                # can index past the end of the embedding matrix. Only grow
                # the matrix, never shrink a deliberately padded vocabulary.
                vocab_rows = self.model.get_input_embeddings().weight.shape[0]
                if len(self.tokenizer) > vocab_rows:
                    self.model.resize_token_embeddings(len(self.tokenizer))

        self.model_logs = {}
        self.break_into()
        return self

    def _log_memory_usage(self) -> None:
        """Log total, reserved, allocated and free memory per CUDA device.

        Values are emitted at debug level in binary gigabytes (bytes divided
        by 1024**3). "free" is total memory minus memory reserved by the
        PyTorch allocator.
        """
        gib = self._BYTES_PER_GIB
        for d in range(torch.cuda.device_count()):
            total = torch.cuda.get_device_properties(d).total_memory
            reserved = torch.cuda.memory_reserved(d)
            allocated = torch.cuda.memory_allocated(d)
            # The previous divisor (8 * 1024 * 1024) did not correspond to
            # any gigabyte unit; the message text itself is unchanged.
            logger.debug(
                f"Device {d}, total_memory: {total / gib:.4}Gb, "
                f"reserved: {reserved / gib:.4}Gb, "
                f"allocated: {allocated / gib:.4}Gb, "
                f"free: {(total - reserved) / gib:.4}Gb"
            )

    @abc.abstractmethod
    def run(self, *args, **kwargs) -> None:
        """Hook for subclasses; the base implementation does nothing."""

    @abc.abstractmethod
    def break_into(self, **additional_kwargs) -> None:
        """Hook called at the end of ``create_model``; a no-op here.

        Presumably subclasses use it to instrument or patch the loaded
        model -- TODO confirm against subclasses.
        """

    @abc.abstractmethod
    def break_out(self) -> None:
        """Hook for subclasses; the base implementation does nothing.

        Presumably the counterpart that undoes ``break_into`` -- TODO confirm.
        """