import argparse
import json
import os

import torch
from decorator import contextmanager

from diffusers import DDIMScheduler
from peft import LoraConfig, PeftModel, get_peft_model
import copy


def freeze(module):
    """Disable gradient tracking for every parameter yielded by ``module.parameters()``."""
    for weight in module.parameters():
        weight.requires_grad_(False)


def unfreeze(module):
    """Enable gradient tracking for every parameter yielded by ``module.parameters()``."""
    for weight in module.parameters():
        weight.requires_grad_(True)


class AdaptedModel(torch.nn.Module):
    """Wrap a model and keep a swappable, trainable copy of selected sub-modules.

    The wrapped model is frozen. For every sub-module selected by
    ``get_modules_to_tune`` a deep copy is stored in ``ft_modules`` while the
    module found in the model is stored in ``orig_modules``. Activating the
    adapted weights plugs the copies into the model, deactivating plugs the
    originals back in. ``parameters``, ``state_dict`` and ``load_state_dict``
    only ever refer to the copies in ``ft_modules``.
    """

    # Defaults merged into every parsed config; subclasses may override this dict.
    default_config = {}

    def __init__(self, pipeline, original: bool = False, config=None):
        """Freeze ``pipeline`` and prepare the fine-tunable module copies.

        Args:
            pipeline: The model to wrap. It must offer ``parameters()`` and
                ``named_modules()``; ``save_checkpoint(save_full=True)``
                additionally accesses ``pipeline.pipeline``.
            original: If True, no modules are selected for tuning and the
                wrapper simply forwards to the unchanged model.
            config: Optional namespace-like object whose attributes override
                ``default_config``.
        """
        super().__init__()

        self.model = pipeline
        self.ft_modules = {}
        self.orig_modules = {}

        # Freeze the whole model
        freeze(self.model)

        self.config = self.parse_config(config)

        # Identify the modules to fine-tune (might include patched LoRA adapters)
        self.orig_modules = {} if original else self.get_modules_to_tune(self.config)
        self.ft_modules = {name: copy.deepcopy(module) for name, module in self.orig_modules.items()}

        # Register both sets as sub-modules of this wrapper
        self.ft_modules_list = torch.nn.ModuleList(self.ft_modules.values())
        self.orig_modules_list = torch.nn.ModuleList(self.orig_modules.values())

        # PEFT adapters already handle appropriate unfreezing, so we must not unfreeze the non-adapter layers!
        if not isinstance(self.model, PeftModel):
            unfreeze(self.ft_modules_list)

        # Count the trainable parameters of the fine-tunable modules
        with self.adapted_weights_active():
            self.adapted_params_count = sum(
                param.numel() for param in self.parameters() if param.requires_grad
            )
        print("Number of trainable parameters in ft_modules:", self.adapted_params_count)

        # Default: adapted modules are deactivated
        self.deactivate_adapted_weights()

    @classmethod
    def parse_config(cls, config):
        """Merge ``config`` over the class defaults and return a new namespace.

        Args:
            config: Object accepted by ``vars()`` (e.g. ``argparse.Namespace``),
                or a falsy value to use the defaults only.

        Returns:
            argparse.Namespace holding the defaults updated with ``config``.
        """
        # Work on a copy: updating the class attribute in place would leak one
        # instance's settings into every later instance and subclass.
        parsed_config = dict(cls.default_config)
        if config:
            parsed_config.update(vars(config))
        return argparse.Namespace(**parsed_config)

    def finetune(self, use_wandb: bool = False):
        """Training entry point; must be provided by subclasses."""
        raise NotImplementedError

    @contextmanager
    def adapted_weights_active(self):
        """Context manager that activates the adapted weights for its body.

        On exit the adapted weights are always deactivated, even if they were
        already active before the context was entered.
        """
        # Enable adapted weights
        self.activate_adapted_weights()
        try:
            yield self  # Provide access to the model if needed
        finally:
            # Switch back to the original weights
            self.deactivate_adapted_weights()

    def get_modules_to_tune(self, config):
        """Select the sub-modules to fine-tune according to ``config.train_method``.

        Only modules whose name contains ``'unet'`` and whose class name is one
        of the supported linear/convolution layer types are considered; the
        train method then filters them by substrings of the module name. For
        ``lora_*`` methods ``self.model`` is replaced by a PEFT-patched model
        before the modules are looked up.

        Args:
            config: Namespace providing ``train_method`` and, depending on the
                method, ``lora_enabled``, ``lora_alpha``, ``lora_dropout`` and
                ``lora_r``.

        Returns:
            dict mapping the selected module names to the modules found under
            these names in ``self.model``.

        Raises:
            AssertionError: If the train method is unsupported, or if
                ``lora_enabled`` is set while the train method is not a LoRA
                variant.
        """
        # Validate the train_method
        SUPPORTED_TRAIN_METHODS = [
            'full', 'xattn', 'xattn-strict', 'noxattn', 'selfattn',
            'lora_full', 'lora_xattn', 'lora_xattn-strict', 'lora_noxattn', 'lora_selfattn'
        ]
        assert config.train_method in SUPPORTED_TRAIN_METHODS, (
            f"The train_method {config.train_method} is not supported! Supported "
            f"methods: {SUPPORTED_TRAIN_METHODS}")

        # This check fails exactly when `lora_enabled` is set although the train method
        # is not a LoRA variant, so the message describes that situation.
        assert 'lora' in config.train_method or not config.lora_enabled, (
            "LoRA is enabled, but train_method does not contain 'lora'")

        # Now we can filter the (patched-with-adapters) modules to fine-tune
        # These modules will be the ones that are considered during training and are stored in the checkpoint
        filters = {
            'xattn': {
                'target_modules': ['attn2'],
                'exclude_modules': []
            },
            'xattn-strict': {
                'target_modules': ['attn2.to_k', 'attn2.to_q'],
                'exclude_modules': []
            },
            'noxattn': {
                'target_modules': [],
                'exclude_modules': ['attn2', 'time', 'out']
            },
            'selfattn': {
                'target_modules': ['attn1'],
                'exclude_modules': []
            },
            'full': {
                'target_modules': [],
                'exclude_modules': []
            },
        }

        # LoRA variants share the name filters of their non-LoRA counterpart
        active_filter = filters[config.train_method.replace('lora_', '')]
        target_modules = active_filter['target_modules']
        excluded_modules = active_filter['exclude_modules']
        tunable_layer_types = ("Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv")

        module_names_to_tune = []
        for module_name, module in self.model.named_modules():
            # Only consider layers of the 'unet' part
            if 'unet' not in module_name:
                continue

            # Only plain linear / convolution layers are candidates
            if module.__class__.__name__ not in tunable_layer_types:
                continue

            # Exclude modules explicitly
            if any(exclude in module_name for exclude in excluded_modules):
                continue

            # Keep modules matching a target; without targets keep all remaining ones.
            # `any` ensures a name is added once even if several targets match.
            if not target_modules or any(target in module_name for target in target_modules):
                module_names_to_tune.append(module_name)

        # If we use LoRA, we need to add additional parameters to the original model first
        # Not a perfect elegant solution, but it does the job. Later one could switch to using PEFT directly
        if 'lora' in config.train_method:
            lora_config = LoraConfig(
                lora_alpha=config.lora_alpha,
                lora_dropout=config.lora_dropout,
                r=config.lora_r,
                bias="none",
                target_modules=module_names_to_tune
            )

            # Patch the model with LoRA layers
            self.model = get_peft_model(self.model, lora_config)

            print("Patched the model with LoRA layers. The LoRAConfig is: ", lora_config.__dict__)

        # Resolve the names against the (possibly patched) model
        return {module_name: get_module(self.model, module_name) for module_name in module_names_to_tune}

    @classmethod
    def from_checkpoint(cls, pipeline, model_path, adapted_model_cls=None):
        """Build an adapted model for ``pipeline`` and load its weights from disk.

        Args:
            pipeline: The model to wrap.
            model_path: Directory containing ``checkpoint.pt`` and either
                ``adapted_model_kwargs.json`` (legacy) or ``config.json``.
                If None, ``pipeline`` is wrapped without any adapted modules.
            adapted_model_cls: Class to instantiate; defaults to ``cls``.

        Returns:
            The adapted model with the checkpoint loaded into its fine-tuned
            modules.

        Raises:
            FileNotFoundError: If neither config file exists, or if
                ``checkpoint.pt`` is missing.
        """
        if model_path is None:
            # Nothing to load: wrap the given pipeline unchanged
            return cls.wrap(pipeline)

        # Prefer the legacy file name and fall back to the current one
        for config_name in ("adapted_model_kwargs.json", "config.json"):
            try:
                with open(os.path.join(model_path, config_name), "r") as f:
                    config = json.load(f)
            except FileNotFoundError:
                continue
            break
        else:
            raise FileNotFoundError(
                f"No adapted_model_kwargs.json (legacy) or config.json file found in {model_path}.")

        if not adapted_model_cls:
            adapted_model_cls = cls

        adapted_model = adapted_model_cls(pipeline=pipeline, config=argparse.Namespace(**config))
        print(f"Created Adapted Model (subclass: {adapted_model_cls.__name__})")

        checkpoint_path = os.path.join(model_path, "checkpoint.pt")
        # NOTE: torch.load unpickles the file, so only load checkpoints from trusted sources.
        weights = torch.load(checkpoint_path, map_location=next(pipeline.parameters()).device)
        adapted_model.load_state_dict(weights)
        print(f"Loaded pretrained checkpoint from {checkpoint_path}")
        return adapted_model

    def save_checkpoint(self, exp_name, model_id, config, save_full=False, step=None):
        """Write the fine-tuned weights and ``config`` below ``models/<exp_name>/<model_id>``.

        Args:
            exp_name: Experiment name, used as a sub-folder of ``models/``.
            model_id: Model identifier, used as a sub-folder of the experiment.
            config: Object whose ``__dict__`` is dumped to ``config.json``.
            save_full: If True, additionally save the complete pipeline with
                the adapted weights written into the original modules (and
                LoRA layers merged) using ``save_pretrained``.
            step: Optional step; if given, files go into a ``<step>`` sub-folder.

        Note:
            With ``save_full=True`` the adapted weights of ``self`` are
            deactivated once saving is done, whatever their state was before.
        """
        print("Saved Model with ID:", model_id)
        save_folder = os.path.join('models/', exp_name, model_id)

        if step is not None:
            save_folder = save_folder + f"/{step}"

        os.makedirs(save_folder, exist_ok=True)

        if save_full:
            # Work on a copy so that merging the weights does not alter this instance
            copy_of_me = copy.deepcopy(self)
            with self.adapted_weights_active():
                copy_of_me.update_original_model()

                if isinstance(copy_of_me.model, PeftModel):
                    copy_of_me.model = copy_of_me.model.merge_and_unload()

                # avoid issues with DDIM scheduler when loading it again, just use the original one later
                copy_of_me.model.pipeline.scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4",
                                                                              subfolder="scheduler")
                copy_of_me.model.pipeline.save_pretrained(save_folder)

            torch.save(copy_of_me.state_dict(), f'{save_folder}/checkpoint.pt')
            del copy_of_me
        else:
            # A copy would hold exactly the same fine-tuned weights, so skip the
            # expensive deep copy of the whole model.
            torch.save(self.state_dict(), f'{save_folder}/checkpoint.pt')

        with open(f'{save_folder}/config.json', 'w') as fp:
            json.dump(config.__dict__, fp)

    def activate_adapted_weights(self):
        """Plug the fine-tuned modules into the model and enable PEFT adapters."""
        # Load the adapted weights into the model
        for key, ft_module in self.ft_modules.items():
            set_module(self.model, key, ft_module)

        # If the model has adapter layers, enable them
        if isinstance(self.model, PeftModel):
            print("PEFT: Enabling adapter layers")
            self.model.enable_adapter_layers()

    def deactivate_adapted_weights(self):
        """Plug the original modules back into the model and disable PEFT adapters."""
        # Load the original weights into the model
        for key, module in self.orig_modules.items():
            set_module(self.model, key, module)

        # If the model has adapter layers, disable them
        if isinstance(self.model, PeftModel):
            print("PEFT: Disabling adapter layers")
            self.model.disable_adapter_layers()

    def parameters(self, *args, **kwargs):
        """Return a list with the parameters of the fine-tuned modules only.

        The arguments are accepted for signature compatibility and ignored.
        """
        parameters = []
        for ft_module in self.ft_modules.values():
            parameters.extend(ft_module.parameters())
        return parameters

    def state_dict(self, *args, **kwargs):
        """Return ``{module name: state dict}`` for the fine-tuned modules only.

        The arguments are accepted for signature compatibility and ignored.
        """
        return {key: module.state_dict() for key, module in self.ft_modules.items()}

    def load_state_dict(self, state_dict, *args, **kwargs):
        """Load a mapping produced by ``state_dict`` into the fine-tuned modules.

        Raises:
            KeyError: If ``state_dict`` names a module that is not fine-tuned.
        """
        for key, sd in state_dict.items():
            self.ft_modules[key].load_state_dict(sd)

    def update_original_model(self):
        """Copy the fine-tuned weights into the corresponding original modules."""
        for module_name, module in self.ft_modules.items():
            self.orig_modules[module_name].load_state_dict(module.state_dict())
        print("Updated the original model base with the adapted weights.")

    @torch.no_grad()
    def __call__(self, *args, **kwargs):
        """Forward all arguments to the wrapped model with gradients disabled."""
        return self.model(*args, **kwargs)

    @classmethod
    def wrap(cls, pipeline):
        """Wrap ``pipeline`` without selecting any modules to fine-tune."""
        return cls(pipeline, original=True)


def get_module(module, module_name):
    """Resolve a dotted attribute path (or a sequence of attribute names) on ``module``.

    An empty sequence returns ``module`` itself.
    """
    path = module_name.split('.') if isinstance(module_name, str) else module_name

    current = module
    for attribute in path:
        current = getattr(current, attribute)
    return current


def set_module(module, module_name, new_module):
    """Assign ``new_module`` at a dotted attribute path (or sequence of names) below ``module``.

    All but the last path element are resolved with ``getattr``; the last one
    is the attribute that gets replaced. Returns None.
    """
    path = module_name.split('.') if isinstance(module_name, str) else module_name

    # Walk down to the object that owns the attribute to replace
    owner = module
    for attribute in path[:-1]:
        owner = getattr(owner, attribute)

    return setattr(owner, path[-1], new_module)
