import torch
from transformers import (
    AutoModelForCausalLM, DynamicCache, DynamicLayer, Qwen3Config, LlamaConfig
)

from PyTorch.models.modeling_qwen3 import Qwen3ForCausalLM
from PyTorch.models.modeling_llama import LlamaForCausalLM

# Route AutoModelForCausalLM to the local PyTorch implementations for these
# config classes. exist_ok=True lets this registration replace the mapping
# transformers already provides for Qwen3Config / LlamaConfig.
for _config_cls, _model_cls in (
    (Qwen3Config, Qwen3ForCausalLM),
    (LlamaConfig, LlamaForCausalLM),
):
    AutoModelForCausalLM.register(_config_cls, _model_cls, exist_ok=True)
del _config_cls, _model_cls


class PyTorchBackend:
    """Eager inference backend around a ``transformers`` causal LM plus a ``DynamicCache``.

    The backend exposes incremental decoding (``decode``) and direct KV-cache
    manipulation (``insert_kv`` / ``delete_kv`` / ``clear_kv``). ``cachelens``
    tracks the number of cached positions.

    Call order implied by the attributes each method reads:
    ``load_model`` -> ``setup_caches`` -> ``decode`` / ``insert_kv`` / ``delete_kv`` / ``clear_kv``.
    """

    def __init__(self, dtype = torch.float16, device: str = "cuda:0") -> None:
        """Store the dtype used for the model weights and synthetic KV entries, and the target device."""
        self.dtype = dtype
        self.device = device

    def load_model(self, model_name, checkpoint_path, **kwargs):
        """Load a causal LM into ``self.model``.

        Args:
            model_name: model identifier used when ``checkpoint_path`` is None.
            checkpoint_path: optional checkpoint location; takes precedence over ``model_name``.
            **kwargs:
                attn_impl: ``attn_implementation`` passed to ``from_pretrained`` (default ``"eager"``).
                tp_size: tensor-parallel degree; values > 1 go through ``apply_tp``.
                fuse_weights: if truthy, call ``self.model.fuse_weights()`` after loading.
        """
        pretrained_id = model_name if checkpoint_path is None else str(checkpoint_path)
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_id,
            dtype=self.dtype,
            attn_implementation=kwargs.get("attn_impl", "eager"),
        )

        tp_size = kwargs.get("tp_size", 1)
        if tp_size > 1:
            from PyTorch.tensor_parallel import apply_tp
            print("Applying tensor parallel to model ...")
            self.model = apply_tp(model)
            # Presumably apply_tp handles placement; adopt the device the embeddings ended up on.
            self.device = self.model.model.embed_tokens.weight.device
        else:
            # from_pretrained without a device_map loads the weights on the CPU. Move
            # them to self.device so they sit with the KV tensors insert_kv creates there.
            self.model = model.to(self.device).eval()

        if kwargs.get("fuse_weights", False):
            self.model.fuse_weights()

    @staticmethod
    def _resolve_head_dim(config) -> int:
        """Return ``config.head_dim``, or ``hidden_size // num_attention_heads`` if it is missing or None."""
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = config.hidden_size // config.num_attention_heads
        return head_dim

    def setup_caches(self, max_batch_size=1, max_seq_length=2048):
        """Record cache limits and model geometry, and reset to an empty ``DynamicCache``.

        Args:
            max_batch_size: batch size used for synthetic KV entries when the cache is empty.
            max_seq_length: upper bound on cached positions enforced by ``insert_kv``.
        """
        self.batch_size = max_batch_size
        self.max_seq_length = max_seq_length
        self.cachelens = 0

        config = self.model.config
        self.n_layers = config.num_hidden_layers
        self.n_kv_head = config.num_key_value_heads
        self.head_dim = self._resolve_head_dim(config)
        self.kv_cache = DynamicCache()

    def compile(self):
        """No-op: this backend performs no compilation step."""
        pass

    @torch.inference_mode()
    def decode(self, input_ids: torch.LongTensor):
        """Run one forward pass on ``input_ids`` using and extending the KV cache.

        ``self.kv_cache`` and ``self.cachelens`` are refreshed from the model output.
        ``max_seq_length`` is not checked here.

        Returns:
            The ``logits`` field of the model output.
        """
        decode_output = self.model(
            input_ids=input_ids,
            past_key_values=self.kv_cache,
            use_cache=True,
        )
        self.kv_cache = decode_output.past_key_values
        self.cachelens = self.kv_cache.get_seq_length()
        return decode_output.logits

    def insert_kv(self, new_kv_len: int):
        """Append ``new_kv_len`` positions of random key/value states to every layer.

        The states are drawn with ``torch.randn`` using ``self.dtype`` on
        ``self.device``. The same random tensors are appended to every layer.
        Presumably this emulates an already-filled context without a real
        prefill; confirm with callers.

        Args:
            new_kv_len: number of positions to append; 0 is a no-op.

        Raises:
            AssertionError: if ``new_kv_len`` is negative or the cache is not a ``DynamicCache``.
            RuntimeError: if the result would exceed ``self.max_seq_length``.
        """
        assert new_kv_len >= 0, "new_kv_len must be greater than or equal to 0"
        assert isinstance(self.kv_cache, DynamicCache), "kv_cache must be a DynamicCache"

        if new_kv_len == 0:
            return

        if self.cachelens + new_kv_len > self.max_seq_length:
            raise RuntimeError(f"cachelens ({self.cachelens}) + new_kv_len ({new_kv_len}) exceeds max_seq_length ({self.max_seq_length})")

        # The cache is not modified until the loop below finishes, so check once.
        cache_is_empty = self.kv_cache.get_seq_length() == 0
        # Match the batch dimension already in the cache (it may be smaller than
        # max_batch_size) so torch.cat lines up; use the configured size when empty.
        batch_size = self.batch_size if cache_is_empty else self.kv_cache[0][0].shape[0]
        shape = (batch_size, self.n_kv_head, new_kv_len, self.head_dim)
        random_key_states = torch.randn(shape, dtype=self.dtype, device=self.device)
        random_value_states = torch.randn(shape, dtype=self.dtype, device=self.device)

        new_cache = []
        for layer_idx in range(self.n_layers):
            if cache_is_empty:
                new_cache.append((random_key_states, random_value_states))
                continue
            layer_kv = self.kv_cache[layer_idx]
            # dim=2 is the position axis of the (batch, n_kv_head, len, head_dim) tensors built above.
            new_cache.append(
                (
                    torch.cat([layer_kv[0], random_key_states], dim=2),
                    torch.cat([layer_kv[1], random_value_states], dim=2),
                )
            )
        self.kv_cache = DynamicCache.from_legacy_cache(tuple(new_cache))
        self.cachelens += new_kv_len

    def delete_kv(self, del_kv_len: int):
        """Remove the last ``del_kv_len`` positions from every layer of the cache.

        Args:
            del_kv_len: number of trailing positions to drop; 0 is a no-op.

        Raises:
            AssertionError: if ``del_kv_len`` is negative or exceeds ``cachelens``, or the
                cache is not a ``DynamicCache``.
            RuntimeError: if ``del_kv_len`` exceeds ``cachelens`` while assertions are
                disabled (``python -O``).
        """
        assert self.cachelens >= del_kv_len, "cachelens must be greater than or equal to del_kv_len"
        assert del_kv_len >= 0, "del_kv_len must be greater than or equal to 0"
        assert isinstance(self.kv_cache, DynamicCache), "kv_cache must be a DynamicCache"

        if del_kv_len == 0:
            return

        # Only reachable when assertions are stripped; keeps the bound enforced under -O.
        if self.cachelens - del_kv_len < 0:
            raise RuntimeError(f"cachelens ({self.cachelens}) is less than del_kv_len ({del_kv_len})")

        # Slice off the trailing positions (second-to-last axis) of each layer's keys and values.
        trimmed = tuple(
            (layer_kv[0][..., :-del_kv_len, :], layer_kv[1][..., :-del_kv_len, :])
            for layer_kv in self.kv_cache
        )
        self.kv_cache = DynamicCache.from_legacy_cache(trimmed)
        self.cachelens -= del_kv_len

    def clear_kv(self):
        """Discard all cached positions and start over with an empty ``DynamicCache``."""
        self.cachelens = 0
        self.kv_cache = DynamicCache()