from transformers import AutoConfig, LlamaConfig, Qwen2Config
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config

# TODO: remove fix_memory & memory_update_steps


class MemLlamaConfig(LlamaConfig):
    """Llama configuration extended with memory-related options.

    On top of the regular ``LlamaConfig`` fields, this config stores a set of
    memory options (see ``__init__``). Only ``memory_insert_layers`` is
    interpreted here; every other memory option is stored unchanged and its
    meaning is defined by the code that consumes the config.
    """

    model_type = 'mem_llama'

    def __init__(
        self,
        *args,
        memory_insert_layers=None,
        update_memory=True,
        memory_size=0,
        use_gpu_to_search=True,
        mem_hidden_path=None,
        fix_memory=False,
        memory_update_steps=1,
        fusion_func=None,
        use_last_prompt_token_as_key=False,
        update_while_predicting=False,
        update_strategy="fifo",
        **kwargs,
    ):
        """Build the config.

        Args:
            *args: Forwarded to ``LlamaConfig.__init__``.
            memory_insert_layers: ``None`` (stored as ``None``), the string
                ``'all'`` (every index in ``range(num_hidden_layers)``), the
                string ``'alternate'`` (indices ``1, 3, 5, ...`` below
                ``num_hidden_layers``), a comma-separated string of integers,
                or an iterable of integers such as the list this config
                stores on itself.
            update_memory: Stored as-is.
            memory_size: Stored as-is.
            use_gpu_to_search: Stored as-is.
            mem_hidden_path: Stored as-is.
            fix_memory: Stored as-is.
            memory_update_steps: Stored as-is.
            fusion_func: Stored as-is.
            use_last_prompt_token_as_key: Stored as-is.
            update_while_predicting: Stored as-is.
            update_strategy: Stored as-is.
            **kwargs: Forwarded to ``LlamaConfig.__init__``.

        Raises:
            ValueError: If ``memory_insert_layers`` cannot be interpreted.
        """
        super().__init__(*args, **kwargs)

        self.architectures = ['MemLlamaForCausalLM']

        self.memory_insert_layers = self._resolve_memory_insert_layers(memory_insert_layers)

        self.update_memory = update_memory
        self.memory_size = memory_size
        self.use_gpu_to_search = use_gpu_to_search
        self.mem_hidden_path = mem_hidden_path
        self.fix_memory = fix_memory
        self.memory_update_steps = memory_update_steps
        self.fusion_func = fusion_func
        self.use_last_prompt_token_as_key = use_last_prompt_token_as_key
        self.update_while_predicting = update_while_predicting
        self.update_strategy = update_strategy

    def _resolve_memory_insert_layers(self, memory_insert_layers):
        """Normalise ``memory_insert_layers`` to ``None`` or a list of ints."""
        if memory_insert_layers is None:
            return None
        try:
            if isinstance(memory_insert_layers, str):
                if memory_insert_layers == 'all':
                    return list(range(self.num_hidden_layers))
                if memory_insert_layers == 'alternate':
                    return list(range(1, self.num_hidden_layers, 2))
                return [int(idx.strip()) for idx in memory_insert_layers.split(',')]
            # Already a sequence of indices, e.g. the list form this config
            # stores, handed back in when a serialized config is reloaded.
            return [int(idx) for idx in memory_insert_layers]
        except (AttributeError, TypeError, ValueError) as err:
            raise ValueError(
                f'memory_insert_layers must be either "all", "alternate" or a comma-separated string of '
                f'integers. Got {memory_insert_layers} instead.'
            ) from err

    def print_configs(self):
        """Print the memory-related options of this config to stdout."""
        field_names = (
            'memory_insert_layers',
            'update_memory',
            'memory_size',
            'use_gpu_to_search',
            'mem_hidden_path',
            'fix_memory',
            'memory_update_steps',
            'fusion_func',
            'use_last_prompt_token_as_key',
            'update_while_predicting',
            'update_strategy',
        )
        print()
        print(' MeM Configuration '.center(50, '—'))
        for field_name in field_names:
            print(f'{field_name}: {getattr(self, field_name)}')
        print()


# Register MemLlamaConfig with AutoConfig under its 'mem_llama' model type.
AutoConfig.register('mem_llama', MemLlamaConfig)


class MemQwen2Config(Qwen2Config):
    """Qwen2 configuration extended with memory-related options.

    On top of the regular ``Qwen2Config`` fields, this config stores a set of
    memory options (see ``__init__``). Only ``memory_insert_layers`` is
    interpreted here; every other memory option is stored unchanged and its
    meaning is defined by the code that consumes the config.
    """

    model_type = 'mem_qwen2'

    def __init__(
        self,
        *args,
        memory_insert_layers=None,
        update_memory=True,
        memory_size=0,
        use_gpu_to_search=True,
        mem_hidden_path=None,
        fix_memory=False,
        memory_update_steps=1,
        fusion_func=None,
        use_last_prompt_token_as_key=False,
        update_while_predicting=False,
        update_strategy="fifo",
        **kwargs,
    ):
        """Build the config.

        Args:
            *args: Forwarded to ``Qwen2Config.__init__``.
            memory_insert_layers: ``None`` (stored as ``None``), the string
                ``'all'`` (every index in ``range(num_hidden_layers)``), the
                string ``'alternate'`` (indices ``1, 3, 5, ...`` below
                ``num_hidden_layers``), a comma-separated string of integers,
                or an iterable of integers such as the list this config
                stores on itself.
            update_memory: Stored as-is.
            memory_size: Stored as-is.
            use_gpu_to_search: Stored as-is.
            mem_hidden_path: Stored as-is.
            fix_memory: Stored as-is.
            memory_update_steps: Stored as-is.
            fusion_func: Stored as-is.
            use_last_prompt_token_as_key: Stored as-is.
            update_while_predicting: Stored as-is.
            update_strategy: Stored as-is.
            **kwargs: Forwarded to ``Qwen2Config.__init__``.

        Raises:
            ValueError: If ``memory_insert_layers`` cannot be interpreted.
        """
        super().__init__(*args, **kwargs)

        self.architectures = ['MemQwen2ForCausalLM']

        self.memory_insert_layers = self._resolve_memory_insert_layers(memory_insert_layers)

        self.update_memory = update_memory
        self.memory_size = memory_size
        self.use_gpu_to_search = use_gpu_to_search
        self.mem_hidden_path = mem_hidden_path
        self.fix_memory = fix_memory
        self.memory_update_steps = memory_update_steps
        self.fusion_func = fusion_func
        self.use_last_prompt_token_as_key = use_last_prompt_token_as_key
        self.update_while_predicting = update_while_predicting
        self.update_strategy = update_strategy

    def _resolve_memory_insert_layers(self, memory_insert_layers):
        """Normalise ``memory_insert_layers`` to ``None`` or a list of ints."""
        if memory_insert_layers is None:
            return None
        try:
            if isinstance(memory_insert_layers, str):
                if memory_insert_layers == 'all':
                    return list(range(self.num_hidden_layers))
                if memory_insert_layers == 'alternate':
                    return list(range(1, self.num_hidden_layers, 2))
                return [int(idx.strip()) for idx in memory_insert_layers.split(',')]
            # Already a sequence of indices, e.g. the list form this config
            # stores, handed back in when a serialized config is reloaded.
            return [int(idx) for idx in memory_insert_layers]
        except (AttributeError, TypeError, ValueError) as err:
            raise ValueError(
                f'memory_insert_layers must be either "all", "alternate" or a comma-separated string of '
                f'integers. Got {memory_insert_layers} instead.'
            ) from err

    def print_configs(self):
        """Print the memory-related options of this config to stdout."""
        field_names = (
            'memory_insert_layers',
            'update_memory',
            'memory_size',
            'use_gpu_to_search',
            'mem_hidden_path',
            'fix_memory',
            'memory_update_steps',
            'fusion_func',
            'use_last_prompt_token_as_key',
            'update_while_predicting',
            'update_strategy',
        )
        print()
        print(' MeM Configuration '.center(50, '—'))
        for field_name in field_names:
            print(f'{field_name}: {getattr(self, field_name)}')
        print()


# Register MemQwen2Config with AutoConfig under its 'mem_qwen2' model type.
AutoConfig.register('mem_qwen2', MemQwen2Config)


class MemQwen3Config(Qwen3Config):
    """Qwen3 configuration extended with memory-related options.

    On top of the regular ``Qwen3Config`` fields, this config stores a set of
    memory options (see ``__init__``). Only ``memory_insert_layers`` is
    interpreted here; every other memory option is stored unchanged and its
    meaning is defined by the code that consumes the config.
    """

    model_type = 'mem_qwen3'

    def __init__(
        self,
        *args,
        memory_insert_layers=None,
        update_memory=True,
        memory_size=0,
        use_gpu_to_search=True,
        mem_hidden_path=None,
        fix_memory=False,
        memory_update_steps=1,
        fusion_func=None,
        use_last_prompt_token_as_key=False,
        update_while_predicting=False,
        update_strategy="fifo",
        **kwargs,
    ):
        """Build the config.

        Args:
            *args: Forwarded to ``Qwen3Config.__init__``.
            memory_insert_layers: ``None`` (stored as ``None``), the string
                ``'all'`` (every index in ``range(num_hidden_layers)``), the
                string ``'alternate'`` (indices ``1, 3, 5, ...`` below
                ``num_hidden_layers``), a comma-separated string of integers,
                or an iterable of integers such as the list this config
                stores on itself.
            update_memory: Stored as-is.
            memory_size: Stored as-is.
            use_gpu_to_search: Stored as-is.
            mem_hidden_path: Stored as-is.
            fix_memory: Stored as-is.
            memory_update_steps: Stored as-is.
            fusion_func: Stored as-is.
            use_last_prompt_token_as_key: Stored as-is.
            update_while_predicting: Stored as-is.
            update_strategy: Stored as-is.
            **kwargs: Forwarded to ``Qwen3Config.__init__``.

        Raises:
            ValueError: If ``memory_insert_layers`` cannot be interpreted.
        """
        super().__init__(*args, **kwargs)

        self.architectures = ['MemQwen3ForCausalLM']

        self.memory_insert_layers = self._resolve_memory_insert_layers(memory_insert_layers)

        self.update_memory = update_memory
        self.memory_size = memory_size
        self.use_gpu_to_search = use_gpu_to_search
        self.mem_hidden_path = mem_hidden_path
        self.fix_memory = fix_memory
        self.memory_update_steps = memory_update_steps
        self.fusion_func = fusion_func
        self.use_last_prompt_token_as_key = use_last_prompt_token_as_key
        self.update_while_predicting = update_while_predicting
        self.update_strategy = update_strategy

    def _resolve_memory_insert_layers(self, memory_insert_layers):
        """Normalise ``memory_insert_layers`` to ``None`` or a list of ints."""
        if memory_insert_layers is None:
            return None
        try:
            if isinstance(memory_insert_layers, str):
                if memory_insert_layers == 'all':
                    return list(range(self.num_hidden_layers))
                if memory_insert_layers == 'alternate':
                    return list(range(1, self.num_hidden_layers, 2))
                return [int(idx.strip()) for idx in memory_insert_layers.split(',')]
            # Already a sequence of indices, e.g. the list form this config
            # stores, handed back in when a serialized config is reloaded.
            return [int(idx) for idx in memory_insert_layers]
        except (AttributeError, TypeError, ValueError) as err:
            raise ValueError(
                f'memory_insert_layers must be either "all", "alternate" or a comma-separated string of '
                f'integers. Got {memory_insert_layers} instead.'
            ) from err

    def print_configs(self):
        """Print the memory-related options of this config to stdout."""
        field_names = (
            'memory_insert_layers',
            'update_memory',
            'memory_size',
            'use_gpu_to_search',
            'mem_hidden_path',
            'fix_memory',
            'memory_update_steps',
            'fusion_func',
            'use_last_prompt_token_as_key',
            'update_while_predicting',
            'update_strategy',
        )
        print()
        print(' MeM Qwen3 Configuration '.center(50, '—'))
        for field_name in field_names:
            print(f'{field_name}: {getattr(self, field_name)}')
        print()


# Register MemQwen3Config with AutoConfig under its 'mem_qwen3' model type.
AutoConfig.register('mem_qwen3', MemQwen3Config)
