import os
import json

import safetensors
import torch

from utils import logger


class TensorWriter:
    """Buffer named tensors and write them to disk as size-limited shards.

    Tensors handed to :meth:`save_tensor` are kept in an in-memory shard.
    When adding a tensor would push that shard past ``max_shard_size`` bytes,
    the shard is written out first. :meth:`finalize` writes whatever is still
    buffered, renames the shard files to the
    ``<prefix>-00001-of-0000N.<extension>`` pattern and writes a JSON index
    mapping every tensor name to the shard file that holds it.
    """

    def __init__(self, out_path: str, max_shard_size: int = 1000 * 1000 * 1000 * 5,
                 safe_serialization: bool = True) -> None:
        """Create the output directory and reset all bookkeeping.

        Args:
            out_path: Directory that receives the shards and the index file;
                created if it does not exist.
            max_shard_size: Soft upper bound, in bytes, for one shard. A
                single tensor larger than this still gets a shard of its own.
            safe_serialization: Write ``.safetensors`` shards when true,
                ``torch.save`` ``.bin`` shards otherwise.
        """
        os.makedirs(out_path, exist_ok=True)
        self.out_path = out_path
        self.max_shard_size = max_shard_size
        self.safe_serialization = safe_serialization
        self.shards_written = 0
        # tensor name -> file name of the shard containing it
        self.weight_map = {}
        # tensors buffered for the shard that has not been written yet
        self.current_shard = {}
        self.current_shard_size = 0
        # running byte count over every tensor accepted so far
        self.total_size = 0

    def save_tensor(self, name: str, tensor: torch.Tensor, clone: bool = False) -> None:
        """Queue ``tensor`` under ``name``, flushing the current shard if full.

        Args:
            name: Key the tensor is stored under.
            tensor: Tensor to store; made contiguous if it is not already.
            clone: When true, store a copy that does not share memory with
                the tensor that was passed in.
        """
        # ``contiguous()`` already returns freshly allocated memory for a
        # non-contiguous tensor, so a second copy is only needed when the
        # input is contiguous to begin with.
        needs_clone = clone and tensor.is_contiguous()
        if not tensor.is_contiguous():
            tensor = tensor.contiguous()

        # Saving a name that is still buffered replaces the earlier tensor;
        # remove its bytes from the counters so sizes are not counted twice.
        replaced = self.current_shard.pop(name, None)
        if replaced is not None:
            replaced_size = replaced.numel() * replaced.element_size()
            self.current_shard_size -= replaced_size
            self.total_size -= replaced_size

        tensor_size = tensor.numel() * tensor.element_size()
        if (
            self.current_shard
            and self.current_shard_size + tensor_size > self.max_shard_size
        ):
            self.flush_current_shard()

        # Clone only after a possible flush so the copy is not made while the
        # previous shard is still held in memory.
        if needs_clone:
            tensor = tensor.clone()

        self.current_shard[name] = tensor
        self.total_size += tensor_size
        self.current_shard_size += tensor_size

    def flush_current_shard(self) -> None:
        """Write the buffered tensors to a new shard file and reset the buffer.

        Does nothing when no tensors are buffered. The file gets a temporary
        ``<prefix>-<n>.<extension>`` name; :meth:`finalize` renames it later.
        """
        if not self.current_shard:
            return

        shard_number = self.shards_written + 1
        logger.info("Writing shard #%d to disk", shard_number)
        prefix, extension = self._get_name_components()
        shard_name = f"{prefix}-{shard_number}.{extension}"
        shard_path = os.path.join(self.out_path, shard_name)

        if self.safe_serialization:
            self._save_st(shard_path)
        else:
            torch.save(self.current_shard, shard_path)

        # Record the mapping only once the shard has actually been written.
        self.weight_map.update(dict.fromkeys(self.current_shard, shard_name))

        self.current_shard = {}
        self.current_shard_size = 0
        self.shards_written = shard_number

    def finalize(self) -> None:
        """Flush remaining tensors, rename the shards and write the index.

        Intended to be called once, after the last :meth:`save_tensor`; a
        second call would try to rename shard files that were already renamed.
        """
        self.flush_current_shard()
        logger.info("Finalizing shard names")
        prefix, extension = self._get_name_components()

        # standardize shard names to hf format
        total_shards = self.shards_written
        name_remap = {
            f"{prefix}-{number}.{extension}":
                f"{prefix}-{number:05d}-of-{total_shards:05d}.{extension}"
            for number in range(1, total_shards + 1)
        }

        for old_name, new_name in name_remap.items():
            os.rename(
                os.path.join(self.out_path, old_name),
                os.path.join(self.out_path, new_name),
            )

        # Point every tensor at the final name of its shard.
        for key, shard_name in self.weight_map.items():
            self.weight_map[key] = name_remap[shard_name]

        index_path = os.path.join(
            self.out_path, f"{prefix}.{extension}.index.json"
        )
        with open(index_path, "w", encoding="utf-8") as file:
            json.dump(
                {
                    "metadata": {
                        "total_size": self.total_size,
                    },
                    "weight_map": self.weight_map,
                },
                file,
            )

    def _get_name_components(self):
        """Return the ``(prefix, extension)`` pair used for shard file names."""
        if self.safe_serialization:
            return "model", "safetensors"
        return "pytorch_model", "bin"

    def _save_st(self, shard_path: str) -> None:
        """Write the buffered shard to ``shard_path`` in safetensors format.

        If saving fails with a ``RuntimeError`` whose message contains
        ``"share memory"``, every buffered tensor is cloned and the save is
        retried once. Any other ``RuntimeError`` is re-raised.
        """
        # ``import safetensors`` alone does not load the ``torch`` submodule,
        # so import the function explicitly instead of relying on some other
        # module having imported ``safetensors.torch`` already.
        from safetensors.torch import save_file

        def _do_save():
            # Reads ``self.current_shard`` at call time so the retry below
            # picks up the cloned tensors.
            save_file(
                self.current_shard,
                shard_path,
                metadata={"format": "pt"},
            )

        try:
            _do_save()
        except RuntimeError as e:
            if (
                len(e.args) > 0
                and isinstance(e.args[0], str)
                and "share memory" in e.args[0]
            ):
                logger.warning(
                    "Your model has duplicated tensors but the --clone-tensors "
                    "flag is not set."
                )
                self.current_shard = {
                    key: value.clone() for key, value in self.current_shard.items()
                }
                _do_save()
            else:
                raise