# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

import logging
from functools import lru_cache
from typing import Any, List, Optional, Tuple

from mergekit import merge_methods
from mergekit.architecture import (
    ArchitectureInfo,
    ConfiguredArchitectureInfo,
    WeightInfo,
)
from mergekit.common import ImmutableMap, ModelReference
from mergekit.config import (
    ConfigReader,
    InputSliceDefinition,
    MergeConfiguration,
    OutputSliceDefinition,
)
from mergekit.graph import Task
from mergekit.io.tasks import (
    FinalizeModel,
    GatherTensors,
    LoaderCache,
    ReturnTensor,
    SaveTensor,
    TensorWriterTask,
)
from mergekit.merge_methods import MergeMethod
from mergekit.merge_methods.tokenizer_permute import TokenizerPermutationMerge
from mergekit.options import MergeOptions
from mergekit.tokenizer import BuildTokenizer


class MergePlanner:
    """Turn a merge configuration into a list of tasks that compute the merge.

    The planner walks the output architecture weight by weight: first the
    weights returned by ``arch_info.pre_weights``, then the per-layer weights
    of every output slice, then the weights from ``arch_info.post_weights``.
    For each output weight it records a ``(WeightInfo, Task)`` pair in
    ``self._tensors``. ``plan_to_disk`` and ``plan_in_memory`` wrap those
    pairs in save/return tasks.
    """

    config: MergeConfiguration
    arch_info: ArchitectureInfo
    options: MergeOptions
    out_model_config: Any
    _method: MergeMethod
    _tensors: List[Tuple[WeightInfo, Task]]
    _current_layers: int = 0
    _tokenizer_task: Optional[BuildTokenizer] = None

    def __init__(
        self,
        config: MergeConfiguration,
        arch_info: ArchitectureInfo,
        options: MergeOptions,
        out_model_config: Any,
    ):
        """Store the inputs and resolve the merge method named in ``config``.

        A ``BuildTokenizer`` task is created only when
        ``config.tokenizer_source`` is set; otherwise ``_tokenizer_task``
        stays ``None``.
        """
        self.config = config
        self.arch_info = arch_info
        self.options = options
        self.out_model_config = out_model_config
        self._method = merge_methods.get(config.merge_method)
        # Initialise planning state up front so the plan_* helpers can be
        # called directly without going through _plan() first.
        self._tensors = []
        self._current_layers = 0

        if config.tokenizer_source:
            self._tokenizer_task = BuildTokenizer(
                base_model=config.base_model,
                referenced_models=tuple(config.referenced_models()),
                tokenizer_source=config.tokenizer_source,
                trust_remote_code=options.trust_remote_code,
            )

    # NOTE(review): lru_cache on an instance method keys on (self, model) and
    # keeps the planner alive for as long as the entry stays in the (bounded,
    # class-wide) cache. Kept as-is to preserve the existing interface.
    @lru_cache
    def model_arch_info(self, model: ModelReference):
        """Return the architecture info bound to ``model``'s own config.

        The model config is loaded with the planner's ``trust_remote_code``
        option. Results are memoized per ``(self, model)``.
        """
        return ConfiguredArchitectureInfo(
            info=self.arch_info,
            config=model.config(trust_remote_code=self.options.trust_remote_code),
        )

    def normalize_config(self):
        """Rewrite a ``models``-style config into a single output slice.

        When ``config.models`` is set, every listed model contributes an
        input slice covering all of its layers. If a base model is
        configured but not listed, it is appended as an extra source.
        Afterwards ``config.slices`` holds one output slice and
        ``config.models`` is cleared, so calling this again is a no-op.

        Raises:
            RuntimeError: if both ``config.models`` and ``config.slices``
                are set.
        """
        if not self.config.models:
            return

        if self.config.slices:
            raise RuntimeError(
                "Must specify either models to merge or output slices"
            )

        base_model = self.config.base_model
        slices_in = []
        base_included = False

        for model_in in self.config.models:
            if base_model and model_in.model == base_model:
                base_included = True

            model_info = self.model_arch_info(model_in.model)
            slices_in.append(
                InputSliceDefinition(
                    layer_range=[0, model_info.num_layers()],
                    model=model_in.model,
                    parameters=model_in.parameters,
                )
            )

        if base_model and not base_included:
            logging.info("Base model specified but not in input models - adding")
            base_info = self.model_arch_info(base_model)
            slices_in.append(
                InputSliceDefinition(
                    layer_range=[0, base_info.num_layers()],
                    model=base_model,
                )
            )

        self.config.slices = [OutputSliceDefinition(sources=slices_in)]
        self.config.models = None

    def plan_tensor(
        self,
        weight: WeightInfo,
        weights_in: List[WeightInfo],
        models: List[ModelReference],
        cfg_reader: ConfigReader,
    ):
        """Plan the task producing one output weight and record it.

        Args:
            weight: description of the output weight.
            weights_in: the corresponding input weight for each entry of
                ``models`` (paired positionally).
            models: the models contributing to this weight.
            cfg_reader: reader used to look up merge parameters.

        An optional weight that none of the input models contains is
        skipped: nothing is appended to ``self._tensors``.
        """
        if weight.optional:
            # Only plan an optional weight if at least one input has it.
            any_weight = any(
                w_in.name in LoaderCache().get(model).index.tensor_paths
                for model, w_in in zip(models, weights_in)
            )
            if not any_weight:
                logging.info("Skipping optional weight %s", weight.name)
                return

        # Embedding weights go through the tokenizer permutation merge when
        # a tokenizer is being built; everything else uses the configured
        # merge method.
        tensor_merge_method = self._method
        if self._tokenizer_task and weight.is_embed:
            tensor_merge_method = TokenizerPermutationMerge(
                tokenizer_task=self._tokenizer_task
            )

        cfg_g = cfg_reader.for_tensor(weight.name)
        global_params = {
            p.name: cfg_g.parameter(
                p.name, model=None, required=p.required, default=p.default_value
            )
            for p in tensor_merge_method.parameters()
        }

        base_model = cfg_reader.base_model

        tensor_params = {}
        for model, weight_in in zip(models, weights_in):
            is_base = model == base_model
            cfg_m = cfg_reader.for_tensor(weight_in.name)
            tensor_params[model] = {
                p.name: cfg_m.parameter(
                    p.name,
                    model=model,
                    # per-model parameters are never mandatory for the base model
                    required=p.required and not is_base,
                    default=p.default_value,
                )
                for p in tensor_merge_method.tensor_parameters()
            }

        gather_tensors = GatherTensors(
            weight_info=ImmutableMap(data=dict(zip(models, weights_in))),
            dtype=self.config.dtype,
            device="cuda" if self.options.read_to_gpu else None,
        )

        tensor_task = tensor_merge_method.make_task(
            output_weight=weight,
            tensors=gather_tensors,
            parameters=ImmutableMap(data=global_params),
            tensor_parameters=ImmutableMap(
                data={
                    model: ImmutableMap(data=params)
                    for model, params in tensor_params.items()
                }
            ),
            base_model=base_model,
        )
        self._tensors.append((weight, tensor_task))

    def plan_layer(
        self,
        sources: List[InputSliceDefinition],
        layer_offset: int,
        t: float,
        cfg_reader: ConfigReader,
    ):
        """Plan every weight of the next output layer.

        Args:
            sources: input slices feeding this layer.
            layer_offset: offset of the layer within each source slice; the
                input layer index is ``layer_range[0] + layer_offset``.
            t: interpolation position passed to ``cfg_reader.with_t``.
            cfg_reader: reader for the enclosing output slice.

        Increments ``self._current_layers`` once the layer is planned.
        """
        weights_out: List[WeightInfo] = self.arch_info.layer_weights(
            index=self._current_layers,
            config=self.out_model_config,
        )
        weights_in: List[List[WeightInfo]] = [
            self.model_arch_info(s.model).layer_weights(
                index=s.layer_range[0] + layer_offset
            )
            for s in sources
        ]

        # Both values are the same for every weight in the layer.
        models = [s.model for s in sources]
        layer_cfg = cfg_reader.with_t(t)

        for idx, w_o in enumerate(weights_out):
            self.plan_tensor(
                weight=w_o,
                weights_in=[per_model[idx] for per_model in weights_in],
                models=models,
                cfg_reader=layer_cfg,
            )

        self._current_layers += 1

    def plan_slice(self, definition: OutputSliceDefinition):
        """Plan all layers of one output slice.

        Raises:
            RuntimeError: if the slice has no input sources, or if the
                sources do not all span the same number of layers.
        """
        slice_lengths = [
            s.layer_range[1] - s.layer_range[0] for s in definition.sources
        ]
        if not slice_lengths:
            raise RuntimeError("Output slice must have at least one input source")
        if not all(s == slice_lengths[0] for s in slice_lengths):
            raise RuntimeError(
                "All inputs to a slice must contain the same number of layers"
            )
        num_layers = slice_lengths[0]

        cfg_reader = ConfigReader(config=self.config, slice_out=definition, t=0)
        for idx in range(num_layers):
            # compute t for interpolated gradients: 0 at the first layer of
            # the slice, 1 at the last; a single-layer slice uses t = 1
            if num_layers > 1:
                t = idx / (num_layers - 1)
            else:
                t = 1

            self.plan_layer(
                definition.sources,
                layer_offset=idx,
                t=t,
                cfg_reader=cfg_reader,
            )

    def plan_to_disk(self, out_path: str) -> List[Task]:
        """Plan the merge to be streamed to disk, returning a list of tasks.

        The result contains one ``SaveTensor`` per planned weight, followed
        by a ``FinalizeModel`` task and, if a tokenizer is being built, the
        tokenizer task.
        """
        self._plan()

        writer_task = TensorWriterTask(
            out_path=out_path,
            max_shard_size=self.options.out_shard_size,
            safe_serialization=self.options.safe_serialization,
        )
        save_tasks = [
            SaveTensor(
                tensor_name=weight.name,
                tensor_task=tensor_task,
                writer_task=writer_task,
                clone=self.options.clone_tensors,
                optional=weight.optional,
                dtype=weight.force_dtype,
            )
            for weight, tensor_task in self._tensors
        ]
        finalize = FinalizeModel(
            tensor_save_tasks=tuple(save_tasks), writer_task=writer_task
        )

        res = save_tasks + [finalize]
        if self._tokenizer_task:
            res.append(self._tokenizer_task)
        return res

    def plan_in_memory(self) -> List[ReturnTensor]:
        """Plan the merge to be performed in memory."""
        self._plan()
        return [ReturnTensor(weight_info=w, tensor_task=t) for w, t in self._tensors]

    def _plan_unsliced_weights(self, weights, out_slice: OutputSliceDefinition, t):
        """Plan weights that live outside the layer stack (pre/post weights).

        Each weight is read under the same name from every source model of
        ``out_slice``, with parameters resolved at interpolation position
        ``t``.
        """
        models = [s.model for s in out_slice.sources]
        for weight_info in weights:
            self.plan_tensor(
                weight_info,
                [weight_info] * len(models),
                models,
                ConfigReader(
                    config=self.config,
                    t=t,
                    tensor_name=weight_info.name,
                ).for_out_slice(out_slice),
            )

    def _plan(self):
        """Populate ``self._tensors`` with a task for every output weight.

        Planning state is reset first, so calling this (via ``plan_to_disk``
        or ``plan_in_memory``) more than once yields the same plan each time.

        Raises:
            RuntimeError: if the configuration defines no output slices
                after normalization.
        """
        self.normalize_config()
        if not self.config.slices:
            raise RuntimeError(
                "No output slices to plan - specify either models to merge "
                "or output slices"
            )

        self._tensors = []
        # Output layer indices must restart at zero for every planning run.
        self._current_layers = 0

        # Pre-weights take their parameters from the first slice at t=0.
        self._plan_unsliced_weights(
            self.arch_info.pre_weights(config=self.out_model_config),
            self.config.slices[0],
            t=0,
        )

        for out_slice in self.config.slices:
            self.plan_slice(out_slice)

        # Post-weights take their parameters from the last slice at t=1.
        self._plan_unsliced_weights(
            self.arch_info.post_weights(config=self.out_model_config),
            self.config.slices[-1],
            t=1,
        )
