import torch
from types import SimpleNamespace

from .lora import (
    extract_lora_ups_down,
    inject_trainable_lora_extended,
    monkeypatch_or_replace_lora_extended,
)

CLONE_OF_SIMO_KEYS = ["model", "loras", "target_replace_module", "r"]

lora_versions = dict(stable_lora="stable_lora", cloneofsimo="cloneofsimo")

lora_func_types = dict(loader="loader", injector="injector")

lora_args = dict(
    model=None,
    loras=None,
    target_replace_module=[],
    target_module=[],
    r=4,
    search_class=[torch.nn.Linear],
    dropout=0,
    lora_bias="none",
)

LoraVersions = SimpleNamespace(**lora_versions)
LoraFuncTypes = SimpleNamespace(**lora_func_types)

LORA_VERSIONS = [LoraVersions.stable_lora, LoraVersions.cloneofsimo]
LORA_FUNC_TYPES = [LoraFuncTypes.loader, LoraFuncTypes.injector]


def filter_dict(_dict, keys=[]):
    if len(keys) == 0:
        assert "Keys cannot empty for filtering return dict."

    for k in keys:
        if k not in lora_args.keys():
            assert f"{k} does not exist in available LoRA arguments"

    return {k: v for k, v in _dict.items() if k in keys}


class LoraHandler(object):
    def __init__(
        self,
        version: str = LoraVersions.cloneofsimo,
        use_unet_lora: bool = False,
        use_text_lora: bool = False,
        save_for_webui: bool = False,
        only_for_webui: bool = False,
        lora_bias: str = "none",
        unet_replace_modules: list = ["UNet3DConditionModel"],
    ):
        self.version = version
        assert self.is_cloneofsimo_lora()

        self.lora_loader = self.get_lora_func(func_type=LoraFuncTypes.loader)
        self.lora_injector = self.get_lora_func(func_type=LoraFuncTypes.injector)
        self.lora_bias = lora_bias
        self.use_unet_lora = use_unet_lora
        self.use_text_lora = use_text_lora
        self.save_for_webui = save_for_webui
        self.only_for_webui = only_for_webui
        self.unet_replace_modules = unet_replace_modules
        self.use_lora = any([use_text_lora, use_unet_lora])

        if self.use_lora:
            print(f"Using LoRA Version: {self.version}")

    def is_cloneofsimo_lora(self):
        return self.version == LoraVersions.cloneofsimo

    def get_lora_func(self, func_type: str = LoraFuncTypes.loader):
        if func_type == LoraFuncTypes.loader:
            return monkeypatch_or_replace_lora_extended

        if func_type == LoraFuncTypes.injector:
            return inject_trainable_lora_extended

        assert "LoRA Version does not exist."

    def get_lora_func_args(
        self, lora_path, use_lora, model, replace_modules, r, dropout, lora_bias
    ):
        return_dict = lora_args.copy()

        return_dict = filter_dict(return_dict, keys=CLONE_OF_SIMO_KEYS)
        return_dict.update(
            {
                "model": model,
                "loras": lora_path,
                "target_replace_module": replace_modules,
                "r": r,
            }
        )

        return return_dict

    def do_lora_injection(
        self,
        model,
        replace_modules,
        bias="none",
        dropout=0,
        r=4,
        lora_loader_args=None,
    ):
        REPLACE_MODULES = replace_modules

        params = None
        negation = None

        injector_args = lora_loader_args
        # import pdb;pdb.set_trace()
        params, negation = self.lora_injector(**injector_args)
        for _up, _down in extract_lora_ups_down(
            model, target_replace_module=REPLACE_MODULES
        ):

            if all(x is not None for x in [_up, _down]):
                print(
                    f"Lora successfully injected into {model.__class__.__name__}."
                )
            else:
                print("False")

            break

        return params, negation

    def add_lora_to_model(
        self, use_lora, model, replace_modules, dropout=0.0, lora_path=None, r=16
    ):

        params = None
        negation = None

        lora_loader_args = self.get_lora_func_args(
            lora_path, use_lora, model, replace_modules, r, dropout, self.lora_bias
        )

        if use_lora:
            params, negation = self.do_lora_injection(
                model,
                replace_modules,
                bias=self.lora_bias,
                lora_loader_args=lora_loader_args,
                dropout=dropout,
                r=r,
            )
        # import pdb;pdb.set_trace()
        params = model if params is None else params
        return params, negation
