import os
import re
from typing import Dict, Optional, Tuple

import torch

from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference, dtype_from_name
from mergekit.graph import Task
from mergekit.io.lazy_tensor_loader import LazyTensorLoader
from mergekit.io.tensor_writer import TensorWriter
from mergekit.options import MergeOptions


class LoaderCache:
    """Singleton cache holding one lazy tensor loader per model reference.

    Constructing ``LoaderCache()`` always yields the same shared object, so
    options stored by :meth:`setup` are seen by every later :meth:`get` call.
    """

    # Loaders built so far, keyed by the model reference they belong to.
    loaders: Dict[ModelReference, LazyTensorLoader] = {}
    # Settings forwarded when a loader is created; overwritten by ``setup``.
    lora_cache_dir: Optional[str] = None
    hf_cache_dir: Optional[str] = None
    lazy_unpickle: bool = False
    trust_remote_code: bool = False

    # The single shared instance (``None`` until first construction).
    _instance: Optional["LoaderCache"] = None

    def __new__(cls) -> "LoaderCache":
        """Return the shared instance, allocating it on first use."""
        instance = cls._instance
        if instance is None:
            instance = super().__new__(cls)
            cls._instance = instance
        return instance

    def get(self, model: ModelReference) -> LazyTensorLoader:
        """Return the loader for ``model``, creating and caching it if needed.

        A new loader is obtained by calling ``model.merged(...)`` and then
        ``lazy_loader(...)`` on the result, using the currently stored options.
        """
        if model in self.loaders:
            return self.loaders[model]

        merged_ref = model.merged(
            cache_dir=self.lora_cache_dir,
            trust_remote_code=self.trust_remote_code,
        )
        new_loader = merged_ref.lazy_loader(
            cache_dir=self.hf_cache_dir,
            lazy_unpickle=self.lazy_unpickle,
        )
        self.loaders[model] = new_loader
        return new_loader

    def flush_all(self):
        """Call ``flush()`` on every cached loader."""
        for cached_loader in self.loaders.values():
            cached_loader.flush()

    def setup(self, options: MergeOptions):
        """Copy the loader-related settings out of ``options``."""
        self.lora_cache_dir = options.lora_merge_cache
        self.hf_cache_dir = options.transformers_cache
        self.lazy_unpickle = options.lazy_unpickle
        self.trust_remote_code = options.trust_remote_code


shard_name_re = re.compile(r"model\-([0-9]+)-of-([0-9]+)")


def _normalized_shard_name(path: str) -> int:
    name, _ext = os.path.splitext(os.path.basename(path))
    name = name.lower().replace("pytorch_model", "model")
    if m := shard_name_re.search(name):
        frac = int(m.group(1)) / int(m.group(2))
        name = f"model-{int(frac*100):03d}pct"
    return name


class LoadTensor(Task[Optional[torch.Tensor]]):
    """Task that loads one named tensor from a model via the shared loader cache.

    The tensor is looked up under ``tensor`` first and then under each entry of
    ``aliases``; the first name present in the loader's index is used.
    """

    model: ModelReference
    tensor: str
    dtype: Optional[str] = None
    device: Optional[str] = None
    optional: bool = False
    aliases: Optional[Tuple[str, ...]] = None

    def arguments(self) -> Dict[str, Task]:
        """Return the upstream tasks of this task (there are none)."""
        return {}

    def _resolve_name(self, loader: LazyTensorLoader) -> Optional[str]:
        """Return the first of ``tensor``/``aliases`` found in the loader index.

        Returns ``None`` when none of the candidate names is present.
        """
        candidates = (self.tensor, *(self.aliases or ()))
        known_paths = loader.index.tensor_paths
        return next((cand for cand in candidates if cand in known_paths), None)

    def execute(self) -> Optional[torch.Tensor]:
        """Load the tensor, converting it to ``dtype`` when one is configured.

        Returns:
            The loaded tensor, or ``None`` if it is absent and ``optional``.

        Raises:
            RuntimeError: If the tensor is absent and not ``optional``.
        """
        loader = LoaderCache().get(self.model)
        resolved = self._resolve_name(loader)
        if resolved:
            # Fall back to the CPU when no device was requested.
            loaded = loader.get_tensor(resolved, device=self.device or "cpu")
            if self.dtype:
                target_dtype = dtype_from_name(self.dtype)
                if target_dtype != loaded.dtype:
                    loaded = loaded.to(dtype=target_dtype)
            return loaded

        if self.optional:
            return None
        raise RuntimeError(
            f"Tensor {self.tensor} required but not present in model {self.model}"
        )

    def priority(self) -> int:
        """Return this task's priority value."""
        return -1000

    def group_label(self) -> Optional[str]:
        """Return the resolved tensor name (or ``None``) as the group label."""
        # NOTE: labelling by shard instead, i.e.
        # _normalized_shard_name(loader.index.tensor_paths[name]), is currently
        # disabled; the resolved tensor name itself is used.
        return self._resolve_name(LoaderCache().get(self.model))


class GatherTensors(Task[Dict[ModelReference, torch.Tensor]]):
    """Task that collects one tensor per model into a single dictionary.

    For every ``(model, weight_info)`` pair a ``LoadTensor`` task is declared as
    a dependency; ``execute`` then maps each model to its loaded tensor.
    """

    weight_info: ImmutableMap[ModelReference, WeightInfo]
    dtype: Optional[str] = None
    device: Optional[str] = None

    def arguments(self) -> Dict[str, Task]:
        """Return one ``LoadTensor`` task per model, keyed ``"<model>:<name>"``.

        A weight's ``force_dtype`` takes precedence over this task's ``dtype``.
        """
        return {
            f"{str(model)}:{wi.name}": LoadTensor(
                model=model,
                tensor=wi.name,
                dtype=wi.force_dtype or self.dtype,
                device=self.device,
                optional=wi.optional,
                aliases=wi.aliases,
            )
            for (model, wi) in self.weight_info.items()
        }

    def group_label(self) -> Optional[str]:
        """Return the greatest group label among the load tasks.

        Missing labels count as ``""``. ``""`` is also returned when there are
        no load tasks at all, instead of ``max()`` raising ``ValueError`` on an
        empty sequence.
        """
        return max(
            (t.group_label() or "" for t in self.arguments().values()),
            default="",
        )

    def priority(self) -> int:
        """Return this task's priority value."""
        return -10

    def execute(self, **kwargs) -> Dict[ModelReference, torch.Tensor]:
        """Map each model to its loaded tensor, skipping tensors that are ``None``.

        Args:
            **kwargs: Loaded tensors (or ``None``) keyed by the same
                ``"<model>:<name>"`` strings that ``arguments`` produces.

        Returns:
            Dictionary from model reference to its tensor; models whose value
            is ``None`` are omitted.
        """
        key2model = {
            f"{str(model)}:{wi.name}": model for (model, wi) in self.weight_info.items()
        }
        return {
            key2model[key]: kwargs[key] for key in key2model if kwargs[key] is not None
        }


class TensorWriterTask(Task[TensorWriter]):
    """Task that constructs the ``TensorWriter`` used to write output tensors."""

    out_path: str
    max_shard_size: int
    safe_serialization: bool = True

    def arguments(self) -> Dict[str, Task]:
        """Return the upstream tasks of this task (there are none)."""
        return {}

    def execute(self, **_kwargs) -> TensorWriter:
        """Build a ``TensorWriter`` for ``out_path`` from the configured settings."""
        writer = TensorWriter(
            self.out_path,
            max_shard_size=self.max_shard_size,
            safe_serialization=self.safe_serialization,
        )
        return writer


class SaveTensor(Task[None]):
    """Task that hands the result of ``tensor_task`` to a ``TensorWriter``."""

    tensor_name: str
    tensor_task: Task
    writer_task: TensorWriterTask
    clone: bool
    optional: bool = False
    dtype: Optional[str] = None

    def arguments(self) -> Dict[str, Task]:
        """Return the writer task and the tensor-producing task."""
        return dict(writer=self.writer_task, tensor=self.tensor_task)

    def priority(self) -> int:
        """Return this task's priority value."""
        return 1000

    def group_label(self) -> Optional[str]:
        """Return the group label of the tensor-producing task."""
        upstream = self.tensor_task
        return upstream.group_label()

    def execute(self, writer: TensorWriter, tensor: Optional[torch.Tensor]) -> None:
        """Save ``tensor`` under ``tensor_name``, casting to ``dtype`` if set.

        Raises:
            RuntimeError: If ``tensor`` is ``None`` and the task is not
                ``optional``.
        """
        if tensor is not None:
            to_write = (
                tensor.to(dtype=dtype_from_name(self.dtype)) if self.dtype else tensor
            )
            writer.save_tensor(name=self.tensor_name, tensor=to_write, clone=self.clone)
            return

        # Nothing to save: acceptable only for optional tensors.
        if not self.optional:
            raise RuntimeError(f"No value for required tensor {self.tensor_name}")


class FinalizeModel(Task[None]):
    """Task that finalizes the writer; it lists every save task as a dependency."""

    tensor_save_tasks: Tuple[Task, ...]
    writer_task: TensorWriterTask

    def arguments(self) -> Dict[str, Task]:
        """Return the writer task plus each save task under ``_unused_<idx>``."""
        deps: Dict[str, Task] = {"writer": self.writer_task}
        for idx, save_task in enumerate(self.tensor_save_tasks):
            # The results of the save tasks are not used by ``execute``.
            deps[f"_unused_{idx}"] = save_task
        return deps

    def execute(self, writer: TensorWriter, **kwargs) -> None:
        """Call ``finalize()`` on the writer; all other arguments are ignored."""
        writer.finalize()


class BuildStateDict(Task[Dict[str, torch.Tensor]]):
    """Task that assembles the results of several tensor tasks into one dict."""

    tensors: ImmutableMap[WeightInfo, Task[torch.Tensor]]

    def arguments(self) -> Dict[str, Task]:
        """Return each tensor task keyed by ``str(weight_info)``."""
        return {str(wi): t for wi, t in self.tensors.items()}

    def execute(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Return the computed tensors keyed by ``str(weight_info)``.

        Args:
            **kwargs: Results of the tasks declared in ``arguments``.
                Assumes the executor passes each result under the same key
                that ``arguments`` used for its task — TODO confirm against
                the task executor.

        Returns:
            Dictionary from ``str(weight_info)`` to the tensor produced for it.
        """
        # Return the executed results from ``kwargs``; the values stored in
        # ``self.tensors`` are the Task objects, not the tensors.
        return {str(wi): kwargs[str(wi)] for wi, _task in self.tensors.items()}


class ReturnTensor(Task[torch.Tensor]):
    """Task that passes the result of ``tensor_task`` through unchanged."""

    weight_info: WeightInfo
    tensor_task: Task[torch.Tensor]

    def arguments(self) -> Dict[str, Task]:
        """Return the wrapped tensor task as the only dependency."""
        return dict(tensor=self.tensor_task)

    def priority(self) -> int:
        """Return this task's priority value."""
        return 10000

    def group_label(self) -> Optional[str]:
        """Return the group label of the wrapped tensor task."""
        wrapped = self.tensor_task
        return wrapped.group_label()

    def execute(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` as received."""
        return tensor
