from pathlib import Path
from typing import TypeGuard, cast, overload

import torch
from peft import get_peft_model
from peft.tuners.lora import LoraLayer
from transformers import AutoModelForCausalLM, SchedulerType

from mow.common import defaults
from mow.common.data import prepare_batch_data
from mow.common.trainer import CustomTrainer
from mow.dataset import AutoChatDatasetBuilder
from mow.dataset.history import ChatHistoryMixin
from mow.modules.mow import MoW
from mow.modules.utils import iterate_lora_layers
from mow.scripts.train_mow import TrainMoWConfig
from mow.utils.types import instanceof


@overload
def __lincomb(
    target: dict[str, torch.Tensor], scores: dict[str, float]
) -> torch.Tensor: ...
@overload
def __lincomb(
    target: dict[str, torch.nn.Module] | torch.nn.ModuleDict,
    scores: dict[str, float],
) -> torch.nn.Module: ...
@overload
def __lincomb(
    target: dict[str, int] | dict[str, float], scores: dict[str, float]
) -> int | float: ...
def __lincomb(
    target: (
        dict[str, torch.Tensor]
        | dict[str, torch.nn.Module]
        | torch.nn.ModuleDict
        | dict[str, int]
        | dict[str, float]
    ),
    scores: dict[str, float],
) -> torch.Tensor | torch.nn.Module | int | float:
    def is_tensor_dict(d) -> TypeGuard[dict[str, torch.Tensor]]:
        return all(isinstance(v, torch.Tensor) for v in d.values())

    def is_linear_dict(d) -> TypeGuard[dict[str, torch.nn.Linear]]:
        return all(isinstance(v, torch.nn.Linear) for _, v in d.items())

    def is_scalar_dict(d) -> TypeGuard[dict[str, int | float]]:
        return all(isinstance(v, (int, float)) for v in d.values())

    if is_tensor_dict(target):
        first = target[next(iter(scores))]
        return sum(
            (v * scores[k] for k, v in target.items() if k in scores),
            torch.zeros_like(first),
        )

    if is_linear_dict(target):
        _, first = next(iter(target.items()))
        new_linear = torch.nn.Linear(
            first.in_features, first.out_features, bias=first.bias is not None
        )
        new_linear.weight = torch.nn.Parameter(
            __lincomb({k: v.weight for k, v in target.items()}, scores),
            requires_grad=True,
        )
        if new_linear.bias is not None:
            new_linear.bias = torch.nn.Parameter(
                __lincomb(
                    {
                        k: v.bias
                        for k, v in target.items()
                        if v.bias is not None
                    },
                    scores,
                ),
                requires_grad=True,
            )
        return new_linear

    if is_scalar_dict(target):
        return sum((v * scores[k] for k, v in target.items() if k in scores), 0)

    raise ValueError(f"Unsupported type: {type(next(iter(target.items()))[1])}")


def __copy_adapter(
    src_layer: LoraLayer,
    dst_layer: LoraLayer,
    *,
    src_adapter_name: str,
    dst_adapter_name: str,
):
    """Register an adapter of ``src_layer`` on ``dst_layer`` and activate it.

    The adapter's hyper-parameters and sub-modules are assigned by
    reference, not cloned: after the call both layers hold the *same*
    ``lora_A`` / ``lora_B`` / dropout (and DoRA magnitude) objects, so a
    later update through one layer is visible through the other.

    Args:
        src_layer: Layer that currently owns the adapter.
        dst_layer: Layer that receives the adapter.
        src_adapter_name: Name of the adapter on ``src_layer``.
        dst_adapter_name: Name under which it is stored on ``dst_layer``.

    Raises:
        KeyError: If ``src_adapter_name`` is missing from one of the
            per-adapter mappings of ``src_layer``.
    """
    src_name = src_adapter_name
    dst_name = dst_adapter_name

    # Hyper-parameters and dropout.
    dst_layer.r[dst_name] = src_layer.r[src_name]
    dst_layer.lora_alpha[dst_name] = src_layer.lora_alpha[src_name]
    dropout = src_layer.lora_dropout[src_name]
    dst_layer.lora_dropout.update({dst_name: dropout})

    # Actual trainable parameters.
    if src_layer.use_dora[src_name]:
        dst_layer.lora_magnitude_vector[dst_name] = (
            src_layer.lora_magnitude_vector[src_name]
        )
    dst_layer.lora_A[dst_name] = src_layer.lora_A[src_name]
    dst_layer.lora_B[dst_name] = src_layer.lora_B[src_name]
    dst_layer.lora_bias[dst_name] = src_layer.lora_bias[src_name]
    dst_layer.use_dora[dst_name] = src_layer.use_dora[src_name]

    dst_layer.scaling[dst_name] = src_layer.scaling[src_name]

    dst_layer._move_adapter_to_device_of_base_layer(dst_name)

    # Activate the adapter alongside the ones that are already active.
    # Appending unconditionally would list ``dst_name`` twice when it is
    # already active; presumably the layer would then apply it twice in its
    # forward pass (the forward iterates the active list) -- TODO confirm
    # against the installed PEFT version.
    active = list(dst_layer.active_adapters)
    if dst_name not in active:
        active.append(dst_name)
    dst_layer.set_adapter(active)


def __merge_adapters(
    src_layer: LoraLayer,
    dst_layer: LoraLayer,
    *,
    routing_scores: dict[str, float],
    dst_adapter_name: str,
):
    """Blend several adapters of ``src_layer`` into one adapter of ``dst_layer``.

    ``r``, ``lora_alpha``, ``scaling``, ``lora_A`` and ``lora_B`` of the new
    adapter are the ``routing_scores``-weighted sums of the corresponding
    source entries (see ``__lincomb``). ``lora_bias`` and ``use_dora`` are
    taken verbatim from the first adapter listed in ``routing_scores``.
    The new adapter then becomes the only active adapter of ``dst_layer``.

    All merged values are computed before ``dst_layer`` is touched, so a
    failure (e.g. an unknown adapter name) leaves ``dst_layer`` unchanged.

    Args:
        src_layer: Layer holding the adapters to blend.
        dst_layer: Layer that receives the blended adapter.
        routing_scores: Weight of each source adapter, keyed by adapter
            name. Must not be empty.
        dst_adapter_name: Name of the blended adapter on ``dst_layer``.

    Raises:
        ValueError: If ``routing_scores`` is empty.
        KeyError: If the first adapter in ``routing_scores`` is unknown to
            ``src_layer``.
    """
    if not routing_scores:
        raise ValueError("routing_scores must contain at least one adapter")

    # These per-adapter settings are not interpolated; they are inherited
    # from the first scored adapter.
    reference = next(iter(routing_scores))
    lora_bias = src_layer.lora_bias[reference]
    use_dora = src_layer.use_dora[reference]

    # NOTE(review): the weighted sum makes ``r`` (and ``lora_alpha``) a
    # float in general; confirm nothing downstream requires an integer rank.
    rank = __lincomb(src_layer.r, routing_scores)
    alpha = __lincomb(src_layer.lora_alpha, routing_scores)
    merged_a = __lincomb(src_layer.lora_A, routing_scores)
    merged_b = __lincomb(src_layer.lora_B, routing_scores)
    scaling = __lincomb(src_layer.scaling, routing_scores)

    name = dst_adapter_name
    dst_layer.r[name] = rank
    dst_layer.lora_alpha[name] = alpha
    dst_layer.lora_A[name] = merged_a
    dst_layer.lora_B[name] = merged_b
    dst_layer.lora_bias[name] = lora_bias
    dst_layer.use_dora[name] = use_dora
    dst_layer.scaling[name] = scaling

    # NOTE(review): ``lora_dropout`` (and a DoRA magnitude vector) are not
    # set here, so ``dst_layer`` presumably must already own an adapter
    # called ``dst_adapter_name`` -- verify against callers.
    dst_layer._move_adapter_to_device_of_base_layer(name)
    dst_layer.set_adapter(name)


def few_shot_expansion(
    config: TrainMoWConfig,
    datasets: list[str] | list[Path],
    num_samples: int,
    output_path: str | Path,
    *,
    max_steps: int = 100,
    learning_rate: float = 1e-5,
):
    """Add one new expert per dataset to a trained MoW model.

    The MoW checkpoint is loaded from ``<train output_dir>/best``. For the
    ``i``-th dataset the function:

    1. draws ``num_samples`` shuffled examples (independently for the
       training and the evaluation split) and prepares them as chat data;
    2. computes routing scores over the existing experts from the first
       training example;
    3. builds a fresh PEFT model on the base model whose active adapter is
       the routing-score-weighted blend of the existing experts, plus the
       shared experts of the MoW;
    4. trains only the modules whose name contains that adapter's name for
       ``max_steps`` steps at a constant ``learning_rate``;
    5. registers the trained adapter on the MoW as ``expanded_{i}``,
       updates every router's embedding set for it and saves the trained
       model to ``<output_path>/expansion_{i}``.

    Finally the expanded MoW, its config and its tokenizer are saved to
    ``<output_path>/mow``.

    Args:
        config: Training configuration of the MoW model to expand.
        datasets: Datasets to create new experts for, one expert each.
        num_samples: Number of examples drawn from each dataset.
        output_path: Directory that receives all outputs.
        max_steps: Number of training steps per new expert.
        learning_rate: Constant learning rate used for the training.

    Raises:
        AssertionError: If ``config.train_config.output_dir`` is ``None``.
    """
    assert config.train_config.output_dir is not None
    model_str = str(config.train_config.output_dir / "best")
    base_model_str = config.mow.base_model
    lora_config = config.mow.lora_config
    output_root = Path(output_path)

    mow = cast(MoW, MoW.from_pretrained(model_str))
    experts = list(mow.config.expert_models.keys())

    print(f"Datasets for few-shot expansion: {datasets}")

    def prepare_dataset(dataset: str | Path):
        # Every call reshuffles, so two calls on the same dataset may
        # return different samples.
        return (
            AutoChatDatasetBuilder.load(dataset)
            .shuffle()
            .take(num_samples)
            .prepare_graph_representation(
                sentence_transformer=mow.sentence_transformer,
                batched=True,
                batch_size=num_samples,
                desc="Preparing graph representation for few-shot expansion",
            )
            .doif(
                lambda builder: instanceof(builder, ChatHistoryMixin),
                lambda builder: builder.expand(),
            )
            .as_chat(
                tokenizer=mow.tokenizer,
                batched=False,
                desc="Preparing few-shot expansion data",
                action_only=True,
            )
            .unwrap(type="pt", output_all_columns=True)
        )

    expert_models: dict[str, str] = {}
    for i, dataset in enumerate(datasets):
        expert_name = f"expanded_{i}"
        expansion_dir = output_root / f"expansion_{i}"

        train_dataset = prepare_dataset(dataset)
        eval_dataset = prepare_dataset(dataset)

        # The routing scores that seed the new expert are derived from the
        # first training example only.
        first = train_dataset[0]
        routing_score = mow.get_routing_score(
            hidden_states=first["nodes"].to(mow.device),
            adjacency_matrix=first["adjacency_matrix"].to(mow.device),
            relation_matrix=first["relation_matrix"].to(mow.device),
            context=first["context"].to(mow.device),
            keys=experts,
        )

        new_model = AutoModelForCausalLM.from_pretrained(base_model_str)
        new_model = get_peft_model(
            new_model, lora_config or defaults.default_lora_config
        )
        new_model.to(mow.device)
        adapter_name = new_model.active_adapters[0]

        # Initialise the new adapter as a blend of the existing experts and
        # attach the shared experts next to it. The two iterators are
        # assumed to yield matching layers in the same order -- TODO confirm.
        for (name, layer_idx, base_layer), (_, _, new_layer) in zip(
            mow.iterate_lora_layers(),
            iterate_lora_layers(new_model, mow.used_target_modules),
        ):
            __merge_adapters(
                src_layer=base_layer,
                dst_layer=new_layer,
                routing_scores={
                    n: float(s[layer_idx])
                    for n, s in routing_score[name].items()
                },
                dst_adapter_name=adapter_name,
            )
            for adapter in mow.shared_expert_names:
                __copy_adapter(
                    src_layer=base_layer,
                    dst_layer=new_layer,
                    src_adapter_name=adapter,
                    dst_adapter_name=adapter,
                )

        # Train the new adapter only. ``requires_grad_`` acts recursively
        # and ``named_modules`` yields a module before its children, so an
        # adapter sub-module is re-enabled after its ancestors were frozen.
        for name, module in new_model.named_modules():
            module.requires_grad_(adapter_name in name)

        trainer = CustomTrainer(
            model=new_model,
            tokenizer=mow.tokenizer,
            args=config.train_config.copy_with(
                output_dir=expansion_dir,
                max_steps=max_steps,
                logging_steps=1,
                learning_rate=learning_rate,
                lr_scheduler_type=SchedulerType.CONSTANT,
                warmup_steps=0,
                batch_size=1,
            ),
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )
        trainer.train()

        # Register the trained adapter on the MoW as a new expert.
        for (_, _, base_layer), (_, _, new_layer) in zip(
            mow.iterate_lora_layers(),
            iterate_lora_layers(new_model, mow.used_target_modules),
        ):
            __copy_adapter(
                src_layer=new_layer,
                dst_layer=base_layer,
                src_adapter_name=adapter_name,
                dst_adapter_name=expert_name,
            )

        # Every router receives the same batch, so it is built and moved to
        # the device once instead of once per router.
        routers = list(mow.router.values())
        if routers:
            batch = prepare_batch_data(train_dataset.take(num_samples))
            router_inputs = {
                "hidden_states": batch["hidden_states"].to(mow.device),
                "adjacency_matrix": batch["adjacency_matrix"].to(mow.device),
                "relation_matrix": batch["relation_matrix"].to(mow.device),
                "context": batch["context"].to(mow.device),
            }
            for router in routers:
                router.update_embedding_set(expert_name, **router_inputs)

        trainer.save_model(expansion_dir)
        expert_models[expert_name] = str(expansion_dir)

    mow.config.expert_models.update(expert_models)

    mow_dir = output_root / "mow"
    mow.save_pretrained(mow_dir)
    mow.config.save_pretrained(mow_dir)
    mow.tokenizer.save_pretrained(mow_dir)
