from typing import Dict, Optional

import pytest
from common import make_picollama, run_and_check_merge
from transformers import AutoConfig

from mergekit.config import (
    InputModelDefinition,
    InputSliceDefinition,
    MergeConfiguration,
    OutputSliceDefinition,
    ParameterSetting,
)
from mergekit.io import LazyTensorLoader


@pytest.fixture(scope="session")
def model_a(tmp_path_factory):
    """Session-scoped fixture: the result of ``make_picollama`` for "model_a".

    The model is written into a fresh temporary directory created with the
    base name "model_a"; whatever ``make_picollama`` returns is handed to the
    tests (presumably a path/identifier usable as a merge input -- confirm
    against ``common.make_picollama``).
    """
    target_dir = tmp_path_factory.mktemp("model_a")
    return make_picollama(target_dir)


@pytest.fixture(scope="session")
def model_b(tmp_path_factory):
    """Session-scoped fixture: the result of ``make_picollama`` for "model_b".

    The model is written into a fresh temporary directory created with the
    base name "model_b"; whatever ``make_picollama`` returns is handed to the
    tests (presumably a path/identifier usable as a merge input -- confirm
    against ``common.make_picollama``).
    """
    target_dir = tmp_path_factory.mktemp("model_b")
    return make_picollama(target_dir)


@pytest.fixture(scope="session")
def model_c(tmp_path_factory):
    """Session-scoped fixture: the result of ``make_picollama`` for "model_c".

    The model is written into a fresh temporary directory created with the
    base name "model_c"; whatever ``make_picollama`` returns is handed to the
    tests (presumably a path/identifier usable as a merge input -- confirm
    against ``common.make_picollama``).
    """
    target_dir = tmp_path_factory.mktemp("model_c")
    return make_picollama(target_dir)


class TestBasicMerges:
    """Smoke tests for the basic merge methods.

    Every test builds a ``MergeConfiguration`` and passes it to
    ``run_and_check_merge``; some also supply a ``validate`` callback that
    inspects the merged output at the path it is given.
    """

    def test_gpt2_copy(self):
        """Passthrough merge with "gpt2" as the only input model."""
        copy_cfg = MergeConfiguration(
            merge_method="passthrough",
            models=[InputModelDefinition(model="gpt2")],
            dtype="bfloat16",
        )
        run_and_check_merge(copy_cfg)

    def test_gpt2_stack(self):
        """Stack "gpt2" layers [0, 12] twice and expect ``n_layer == 24``."""
        whole_model_slice = OutputSliceDefinition(
            sources=[InputSliceDefinition(model="gpt2", layer_range=[0, 12])]
        )
        stack_cfg = MergeConfiguration(
            merge_method="passthrough",
            # The very same slice object is listed twice (what ``[x] * 2`` yields).
            slices=[whole_model_slice, whole_model_slice],
            dtype="bfloat16",
        )

        def _check_config_layers(p: str):
            # Two copies of a 12-layer range should be reported as 24 layers.
            merged_cfg = AutoConfig.from_pretrained(p)
            assert merged_cfg.n_layer == 24

        run_and_check_merge(stack_cfg, validate=_check_config_layers)

    def test_passthrough_scale(self, model_a):
        """Passthrough with a per-tensor ``scale``: 0 for "o_proj", 1 otherwise."""
        scale_rules = [
            {"filter": "o_proj", "value": 0},
            {"value": 1},
        ]
        scaled_input = InputModelDefinition(
            model=model_a,
            parameters={"scale": scale_rules},
        )
        scale_cfg = MergeConfiguration(
            merge_method="passthrough",
            models=[scaled_input],
        )

        def _check_o_proj(p: str):
            loader = LazyTensorLoader.from_disk(p)
            found_o_proj = False
            for tensor_name in loader.index.tensor_paths:
                is_o_proj = "o_proj" in tensor_name
                # Only o_proj and lm_head tensors are loaded and inspected.
                if not is_o_proj and "lm_head" not in tensor_name:
                    continue

                tensor = loader.get_tensor(tensor_name)
                if is_o_proj:
                    # Scale 0 must have wiped every element.
                    assert (tensor == 0).all()
                    found_o_proj = True
                else:
                    # lm_head falls under the default scale of 1.
                    assert tensor.count_nonzero() > 0

            assert found_o_proj, "No o_proj parameters found"

        run_and_check_merge(scale_cfg, validate=_check_o_proj)

    def test_linear_merge(self, model_a, model_b):
        """Linear merge of model_a and model_b with no base model."""
        linear_cfg = self.two_model_config(model_a, model_b, merge_method="linear")
        run_and_check_merge(linear_cfg)

    def test_slerp_merge(self, model_a, model_b):
        """Slerp merge of model_a and model_b, base model_a, with t = 0.35."""
        slerp_cfg = self.two_model_config(
            model_a, model_b, merge_method="slerp", base_model=model_a
        )
        # The interpolation parameter is set on the already-built config.
        slerp_cfg.parameters = {"t": 0.35}
        run_and_check_merge(slerp_cfg)

    def test_task_arithmetic_merge(self, model_a, model_b, model_c):
        """Task-arithmetic merge of model_a and model_b on base model_c."""
        arithmetic_cfg = self.two_model_config(
            model_a, model_b, merge_method="task_arithmetic", base_model=model_c
        )
        run_and_check_merge(arithmetic_cfg)

    def test_breadcrumbs_merge(self, model_a, model_b, model_c):
        """Breadcrumbs merge of model_a and model_b on base model_c."""
        breadcrumbs_cfg = self.two_model_config(
            model_a, model_b, merge_method="breadcrumbs", base_model=model_c
        )
        run_and_check_merge(breadcrumbs_cfg)

    def test_ties_merge(self, model_a, model_b, model_c):
        """TIES merge of model_a and model_b on base model_c, density 0.3."""
        ties_cfg = self.two_model_config(
            model_a,
            model_b,
            merge_method="ties",
            base_model=model_c,
            params={"density": 0.3},
        )
        run_and_check_merge(ties_cfg)

    def test_dare_ties_merge(self, model_a, model_b, model_c):
        """DARE-TIES merge of model_a and model_b on base model_c, density 0.66."""
        dare_cfg = self.two_model_config(
            model_a,
            model_b,
            merge_method="dare_ties",
            base_model=model_c,
            params={"density": 0.66},
        )
        run_and_check_merge(dare_cfg)

    def test_model_stock_merge(self, model_a, model_b, model_c):
        """Model-stock merge of model_b and model_c on base model_a."""
        stock_cfg = self.two_model_config(
            model_b, model_c, merge_method="model_stock", base_model=model_a
        )
        run_and_check_merge(stock_cfg)

    def test_model_stock_filterwise_merge(self, model_a, model_b, model_c):
        """Model-stock merge as above, with ``filter_wise`` switched on."""
        filterwise_cfg = self.two_model_config(
            model_b,
            model_c,
            merge_method="model_stock",
            base_model=model_a,
            params={"filter_wise": True},
        )
        run_and_check_merge(filterwise_cfg)

    def two_model_config(
        self,
        model_a,
        model_b,
        merge_method: str,
        base_model: Optional[str] = None,
        params: Optional[Dict[str, ParameterSetting]] = None,
    ):
        """Build a bfloat16 two-input merge configuration.

        Args:
            model_a: First input model; given weight 0.6.
            model_b: Second input model; given weight 0.4.
            merge_method: Name of the merge method to configure.
            base_model: Optional base model passed through to the config.
            params: Optional global parameters passed through to the config.

        Returns:
            The constructed ``MergeConfiguration``.
        """
        weighted_inputs = [
            InputModelDefinition(model=source, parameters={"weight": weight})
            for source, weight in ((model_a, 0.6), (model_b, 0.4))
        ]
        return MergeConfiguration(
            merge_method=merge_method,
            base_model=base_model,
            models=weighted_inputs,
            dtype="bfloat16",
            parameters=params,
        )
