from models.base import LoRAParameter, SeparableConv2d
from models.separable_conv import SeparableResNet
from models.base import BasicBlock
from copy import deepcopy
from abc import ABC, abstractmethod
from models.residual_net import Conv3x3
import torch
import torch.nn as nn
import numpy as np
from models.base_vit import DynamicLoRALayer
import helper
from functools import lru_cache
from train.lora_vit_helper import ViTPEARL, split_qkv
import math


class BaseLoRAModelBuilder(ABC):
    """Abstract interface shared by the LoRA model builders in this module.

    Subclasses implement how LoRA components are built for a model and how
    per-task sub-networks are added or removed. The configuration object is
    kept on ``self.args``; ``log_parameters`` reads ``args.logger`` and
    ``args.verbose`` from it.
    """

    def __init__(self, args):
        # Configuration namespace consulted by the concrete builders.
        self.args = args

    @abstractmethod
    def build_lora_model(self, model, *args, **kwargs):
        """Attach LoRA components to ``model``; implemented by subclasses."""

    @abstractmethod
    def add_new_subnet_for_new_task(self, subnet_1, *args, **kwargs):
        """Extend ``subnet_1`` with a sub-network for a new task."""

    @abstractmethod
    def remove_old_subnet_from_prev_task(self, subnet_1, *args, **kwargs):
        """Drop the sub-network that belonged to a previous task."""

    def log_parameters(self, model, message):
        """Log ``message`` followed by the number of parameter elements in ``model``."""
        total_elements = 0
        for parameter in model.parameters():
            total_elements += parameter.numel()
        helper.log_and_print(message + str(total_elements), self.args.logger, self.args.verbose)


class SeparableConvLoRAModelBuilder(BaseLoRAModelBuilder):
    """LoRA builder for networks made of ``SeparableConv2d`` modules.

    The builder reads the ``pointwise`` and ``depthwise`` ``nn.ModuleList``
    containers of every ``SeparableConv2d``: entry 0 is used as the reference
    filter and the last entry as the filter whose delta is factorised.
    """

    @staticmethod
    def _select_rank(singular_values, threshold):
        """Return the smallest rank whose cumulative energy reaches ``threshold``.

        Args:
            singular_values: 1-D NumPy array of singular values, in the
                descending order returned by ``np.linalg.svd``.
            threshold: Fraction of the total squared singular values to keep.
                A value ``>= 1`` selects the full rank.

        Returns:
            The selected rank as a Python ``int``.
        """
        full_rank = singular_values.shape[0]
        if threshold >= 1:
            return full_rank
        energy = singular_values ** 2
        total_energy = np.sum(energy)
        if not total_energy > 0:
            # Zero (or NaN) delta: nothing to explain. Keep a single component,
            # which is what the NaN comparison path produced before, without
            # performing the 0/0 division.
            return 1
        reached = np.flatnonzero(np.cumsum(energy) / total_energy >= threshold)
        # ``np.argmax`` on an all-False mask returns 0, which used to collapse
        # the rank to 1 whenever rounding kept the cumulative ratio just below
        # a threshold close to 1. Fall back to the full rank in that case.
        return int(reached[0]) + 1 if reached.size else full_rank

    def build_lora_model(self, model, *args, **kwargs):
        """Factorise the per-module filter deltas with SVD and store them as LoRA parameters.

        For each ``SeparableConv2d`` in ``model`` the difference between the
        last and the first pointwise filter is decomposed with a truncated SVD
        (rank chosen from ``args.svd_threshold`` when positive, otherwise from
        a threshold derived from the weights), and the depthwise difference is
        decomposed at full rank. The factors are appended to the module's
        ``lora_A_*`` / ``lora_B_*`` lists, which are created on first use.

        Args:
            model: Network to augment. Must expose ``device`` and
                ``reset_last_parameters(args)``.
            *args: Unused.
            **kwargs: Unused.

        Returns:
            ``model`` after being moved to ``model.device`` and after
            ``reset_last_parameters`` has been called on it.
        """
        logger, verbose = self.args.logger, self.args.verbose
        helper.log_and_print("Building LoRA components!", logger, verbose)
        message = "Static SVD threshold in use!" if self.args.svd_threshold > 0 else "Dynamic SVD threshold in use!"
        helper.log_and_print(message, logger, verbose)
        param_difference = 0
        with torch.no_grad():
            for name, module in model.named_modules():
                if not isinstance(module, SeparableConv2d):
                    continue

                if not hasattr(module, 'lora_A_pointwise'):
                    module.lora_A_pointwise = nn.ModuleList()
                    module.lora_B_pointwise = nn.ModuleList()
                    module.lora_A_depthwise = nn.ModuleList()
                    module.lora_B_depthwise = nn.ModuleList()

                # --- Pointwise filter: truncated SVD of (last - first) ---
                w1 = module.pointwise[-1].weight.flatten().detach().cpu().numpy()
                w2 = module.pointwise[0].weight.flatten().detach().cpu().numpy()
                # Relative squared distance between the two filters.
                threshold_dynamic = np.sum((w1 - w2) ** 2) / (np.sum(w1 ** 2) + np.sum(w2 ** 2))
                threshold = self.args.svd_threshold if self.args.svd_threshold > 0 else threshold_dynamic

                w_pt_delta = module.pointwise[-1].weight - module.pointwise[0].weight
                original_params_count = w_pt_delta.numel()
                w_pt_delta_reshaped = w_pt_delta.view(module.out_channels, module.in_channels)
                U, S, Vt = np.linalg.svd(w_pt_delta_reshaped.detach().cpu().numpy(), full_matrices=False)
                k = self._select_rank(S, threshold)

                # NOTE(review): both factors are scaled by S, so B @ A equals
                # U diag(S**2) Vt rather than the delta itself. Left unchanged
                # because reset_last_parameters() runs afterwards and may
                # overwrite these values - confirm the intended scaling.
                B = torch.tensor(U[:, :k] @ np.diag(S[:k]), dtype=torch.float32)
                A = torch.tensor(np.diag(S[:k]) @ Vt[:k, :], dtype=torch.float32)
                reduced_params_count = A.numel() + B.numel()
                module.lora_A_pointwise.append(LoRAParameter(A))
                module.lora_B_pointwise.append(LoRAParameter(B))

                helper.log_and_print(f"{name} point-wise filter shape: {w_pt_delta.shape}, reshaped shape: "
                                     f"{w_pt_delta_reshaped.shape}, Total rank: {S.shape}, ", logger, verbose)
                helper.log_and_print(f"Threshold and rank for point-wise filter in {name}: {threshold}, {k}",
                                     logger, verbose)
                helper.log_and_print(f"A-shape: {A.shape}, B-shape: {B.shape}", logger, verbose)
                helper.log_and_print(f"Parameter difference: {original_params_count - reduced_params_count}", logger, verbose)
                param_difference += original_params_count - reduced_params_count

                # --- Depthwise filter: full-rank SVD of (last - first) ---
                w_dp_delta = module.depthwise[-1].weight - module.depthwise[0].weight
                w_dp_delta_reshaped = w_dp_delta.view(module.out_channels // module.depthwise[0].groups * module.kernel_size,
                                                      module.out_channels * module.kernel_size)
                U, S, Vt = np.linalg.svd(w_dp_delta_reshaped.detach().cpu().numpy(), full_matrices=False)

                B = torch.tensor(U @ np.diag(S), dtype=torch.float32)
                A = torch.tensor(np.diag(S) @ Vt, dtype=torch.float32)
                module.lora_A_depthwise.append(LoRAParameter(A))
                module.lora_B_depthwise.append(LoRAParameter(B))

            helper.log_and_print(f"Parameter difference for pointwise fitlers: {param_difference}", logger, verbose)
            model = model.to(model.device)
            model.reset_last_parameters(self.args)
            self.log_parameters(model, "Number of total parameters after LoRA addition: ")
            return model

    def remove_old_subnet_from_prev_task(self, subnet_1, *args, **kwargs):
        """Delete the last entry of every non-LoRA pointwise/depthwise list holding more than one entry.

        Args:
            subnet_1: Network to prune in place. Must expose ``device``.
            *args: Unused.
            **kwargs: Unused.

        Returns:
            ``subnet_1`` after being moved to ``subnet_1.device``.
        """
        helper.log_and_print("Removing previous task subnetwork!", self.args.logger, self.args.verbose)
        layers_to_remove = {
            "pointwise",
            "depthwise"
        }

        with torch.no_grad():
            for name_1, module_1 in subnet_1.named_modules():
                if not isinstance(module_1, nn.ModuleList) or 'lora' in name_1:
                    continue
                if any(layer in name_1 for layer in layers_to_remove) and len(module_1) > 1:
                    del module_1[-1]

        model = subnet_1.to(subnet_1.device)
        self.log_parameters(model, "Number of total parameters after last task deletion: ")
        return model

    def add_new_subnet_for_new_task(self, subnet_1, *args, **kwargs):
        """Append one new layer to every task-specific (non-LoRA) ``nn.ModuleList`` of ``subnet_1``.

        With ``args.forward_transfer`` set, the appended layer is a deep copy
        of the list's first entry; otherwise it is the first entry of the
        same-named list in a freshly constructed ``SeparableResNet``.

        Args:
            subnet_1: Network to extend in place. Must expose ``device``.
            *args: Unused.
            **kwargs: Unused.

        Returns:
            ``subnet_1`` after being moved to ``subnet_1.device``.
        """
        helper.log_and_print("Adding a new task subnetwork!", self.args.logger, self.args.verbose)
        layers_to_add = {
            "pointwise",
            "depthwise",
            "pre_bn",
            "bns",
            "fcs",
        }
        layers_to_ignore = {
            "lora"
        }

        # NOTE(review): the fresh network is only consumed when forward
        # transfer is disabled. It is still built unconditionally so that any
        # side effects of construction (e.g. random-number consumption during
        # weight initialisation) stay exactly as before.
        subnet_2 = SeparableResNet(
            BasicBlock,
            [int(self.args.n_classes_per_task)] * 1,
            factor=self.args.factor,
            depth=self.args.depth,
            logger=self.args.logger,
            device=subnet_1.device,
            forward_transfer=self.args.forward_transfer
        )

        with torch.no_grad():
            forward_transfer = self.args.forward_transfer
            fresh_modules = None if forward_transfer else dict(subnet_2.named_modules())

            for name_1, module_1 in subnet_1.named_modules():
                if not isinstance(module_1, nn.ModuleList):
                    continue
                if not any(layer in name_1 for layer in layers_to_add):
                    continue
                if any(layer in name_1 for layer in layers_to_ignore):
                    continue
                if forward_transfer:
                    # Copy only the layer that is needed. Deep-copying every
                    # entry of named_modules() (root included) duplicated the
                    # whole network once per module.
                    layer = deepcopy(module_1[0])
                else:
                    layer = fresh_modules[name_1][0]
                module_1.append(layer)

        model = subnet_1.to(subnet_1.device)
        self.log_parameters(model, "Number of total parameters after new task addition: ")
        return model



class ResNet18LoRAModelBuilder(BaseLoRAModelBuilder):
    """LoRA builder for networks made of ``Conv3x3`` modules.

    The builder reads the ``conv`` ``nn.ModuleList`` of every ``Conv3x3``:
    entry 0 is used as the reference filter and the last entry as the filter
    whose delta is factorised.
    """

    @staticmethod
    def _select_rank(singular_values, threshold):
        """Return the smallest rank whose cumulative energy reaches ``threshold``.

        Args:
            singular_values: 1-D NumPy array of singular values, in the
                descending order returned by ``np.linalg.svd``.
            threshold: Fraction of the total squared singular values to keep.
                A value ``>= 1`` selects the full rank.

        Returns:
            The selected rank as a Python ``int``.
        """
        full_rank = singular_values.shape[0]
        if threshold >= 1:
            return full_rank
        energy = singular_values ** 2
        total_energy = np.sum(energy)
        if not total_energy > 0:
            # Zero (or NaN) delta: keep a single component, matching what the
            # NaN comparison path produced before, without dividing 0 by 0.
            return 1
        reached = np.flatnonzero(np.cumsum(energy) / total_energy >= threshold)
        # ``np.argmax`` on an all-False mask returns 0, which used to collapse
        # the rank to 1 when rounding kept the cumulative ratio just below a
        # threshold close to 1. Fall back to the full rank in that case.
        return int(reached[0]) + 1 if reached.size else full_rank

    def build_lora_model(self, model, *args, **kwargs):
        """Factorise the conv-filter deltas with SVD and store them as LoRA parameters.

        For each ``Conv3x3`` in ``model`` the difference between the last and
        the first entry of ``module.conv`` is reshaped to 2-D and decomposed
        with a truncated SVD (rank chosen from ``args.svd_threshold`` when
        positive, otherwise from a threshold derived from the weights). The
        factors are appended to ``module.lora_A`` / ``module.lora_B``, which
        are created on first use.

        Args:
            model: Network to augment. Must expose ``device`` and
                ``reset_last_parameters(args)``.
            *args: Unused.
            **kwargs: Unused.

        Returns:
            ``model`` after being moved to ``model.device`` and after
            ``reset_last_parameters`` has been called on it.
        """
        logger, verbose = self.args.logger, self.args.verbose
        helper.log_and_print("Building LoRA components for resnet-18!", logger, verbose)
        message = "Static SVD threshold in use!" if self.args.svd_threshold > 0 else "Dynamic SVD threshold in use!"
        helper.log_and_print(message, logger, verbose)
        param_difference = 0
        with torch.no_grad():
            for name, module in model.named_modules():
                if not isinstance(module, Conv3x3):
                    continue

                if not hasattr(module, 'lora_A'):
                    module.lora_A = nn.ModuleList()
                    module.lora_B = nn.ModuleList()

                # Truncated SVD of the conv filter delta (last - first).
                w1 = module.conv[-1].weight.flatten().detach().cpu().numpy()
                w2 = module.conv[0].weight.flatten().detach().cpu().numpy()
                # Relative squared distance between the two filters.
                threshold_dynamic = np.sum((w1 - w2) ** 2) / (np.sum(w1 ** 2) + np.sum(w2 ** 2))
                threshold = self.args.svd_threshold if self.args.svd_threshold > 0 else threshold_dynamic

                w_pt_delta = module.conv[-1].weight - module.conv[0].weight
                original_params_count = w_pt_delta.numel()
                w_pt_delta_reshaped = w_pt_delta.view(module.out_channels // module.conv[0].groups * module.kernel_size,
                                                      module.in_channels * module.kernel_size)
                U, S, Vt = np.linalg.svd(w_pt_delta_reshaped.detach().cpu().numpy(), full_matrices=False)
                k = self._select_rank(S, threshold)

                # NOTE(review): both factors are scaled by S, so B @ A equals
                # U diag(S**2) Vt rather than the delta itself. Left unchanged
                # because reset_last_parameters() runs afterwards and may
                # overwrite these values - confirm the intended scaling.
                B = torch.tensor(U[:, :k] @ np.diag(S[:k]), dtype=torch.float32)
                A = torch.tensor(np.diag(S[:k]) @ Vt[:k, :], dtype=torch.float32)
                reduced_params_count = A.numel() + B.numel()
                module.lora_A.append(LoRAParameter(A))
                module.lora_B.append(LoRAParameter(B))

                helper.log_and_print(f"{name} conv filter shape: {w_pt_delta.shape}, reshaped shape: "
                                     f"{w_pt_delta_reshaped.shape}, Total rank: {S.shape}, ", logger, verbose)
                helper.log_and_print(f"Threshold and rank for conv filter in {name}: {threshold}, {k}",
                                     logger, verbose)
                helper.log_and_print(f"A-shape: {A.shape}, B-shape: {B.shape}", logger, verbose)
                helper.log_and_print(f"Parameter difference: {original_params_count - reduced_params_count}", logger, verbose)
                param_difference += original_params_count - reduced_params_count

            helper.log_and_print(f"Parameter difference for pointwise fitlers: {param_difference}", logger, verbose)
            model = model.to(model.device)
            model.reset_last_parameters(self.args)
            self.log_parameters(model, "Number of total parameters after LoRA addition: ")
            return model

    def add_new_subnet_for_new_task(self, subnet_1, *args, **kwargs):
        """Append a deep copy of the first entry to every non-LoRA ``nn.ModuleList`` of ``subnet_1``.

        Args:
            subnet_1: Network to extend in place. Must expose ``device``.
            *args: Unused.
            **kwargs: Unused.

        Returns:
            ``subnet_1`` after being moved to ``subnet_1.device``.
        """
        helper.log_and_print("Adding a new resnet-18 subnetwork!", self.args.logger, self.args.verbose)
        layers_to_ignore = {
            "lora"
        }

        with torch.no_grad():
            for name_1, module_1 in subnet_1.named_modules():
                if not isinstance(module_1, nn.ModuleList):
                    continue
                if any(layer in name_1 for layer in layers_to_ignore):
                    continue
                # Copy only the layer that is appended. Deep-copying every
                # entry of named_modules() (root included) duplicated the
                # whole network once per module.
                module_1.append(deepcopy(module_1[0]))

        model = subnet_1.to(subnet_1.device)
        self.log_parameters(model, "Number of total parameters after new resnet-18 addition: ")
        return model

    def remove_old_subnet_from_prev_task(self, subnet_1, *args, **kwargs):
        """Delete the trailing ``nn.Conv2d`` of every non-LoRA ``nn.ModuleList`` holding more than one entry.

        Lists whose last entry is not an ``nn.Conv2d`` are left untouched.

        Args:
            subnet_1: Network to prune in place. Must expose ``device``.
            *args: Unused.
            **kwargs: Unused.

        Returns:
            ``subnet_1`` after being moved to ``subnet_1.device``.
        """
        helper.log_and_print("Removing previous resnet-18 subnetwork!", self.args.logger, self.args.verbose)

        with torch.no_grad():
            for name_1, module_1 in subnet_1.named_modules():
                if not isinstance(module_1, nn.ModuleList) or 'lora' in name_1:
                    continue
                if len(module_1) > 1 and isinstance(module_1[-1], nn.Conv2d):
                    del module_1[-1]

        model = subnet_1.to(subnet_1.device)
        self.log_parameters(model, "Number of total parameters after last resnet-18 subnetwork deletion: ")
        return model



class ViTLoRAModelBuilder(BaseLoRAModelBuilder):
    """LoRA builder for a pre-trained ViT whose target layers are ``DynamicLoRALayer`` wrappers."""

    @staticmethod
    def _select_rank(singular_values, threshold):
        """Return the smallest rank whose cumulative energy reaches ``threshold``.

        Args:
            singular_values: 1-D NumPy array of singular values, in the
                descending order returned by ``np.linalg.svd``.
            threshold: Fraction of the total squared singular values to keep.
                A value ``>= 1`` selects the full rank.

        Returns:
            The selected rank as a Python ``int``.
        """
        full_rank = singular_values.shape[0]
        if threshold >= 1:
            return full_rank
        energy = singular_values ** 2
        total_energy = np.sum(energy)
        if not total_energy > 0:
            # Zero (or NaN) delta: keep a single component, matching what the
            # NaN comparison path produced before, without dividing 0 by 0.
            return 1
        reached = np.flatnonzero(np.cumsum(energy) / total_energy >= threshold)
        # ``np.argmax`` on an all-False mask returns 0, which used to collapse
        # the rank to 1 when rounding kept the cumulative ratio just below a
        # threshold close to 1. Fall back to the full rank in that case.
        return int(reached[0]) + 1 if reached.size else full_rank

    def cached_modules_to_replace(self, pretrained_model, target_modules):
        """Return the names of the ``nn.Linear`` modules matched by ``target_modules``.

        A module matches when any entry of ``target_modules`` is a substring of
        its qualified name. Names are returned in ``named_modules()`` order.

        The method name is kept for backward compatibility, but the result is
        no longer memoised: ``lru_cache`` on an instance method keys on both
        ``self`` and ``pretrained_model`` and keeps them alive for the life of
        the cache, while ``build_lora_model`` passes a newly constructed model
        on every call, so the cache never hit and only pinned models in memory.

        Args:
            pretrained_model: Model whose modules are scanned.
            target_modules: Iterable of name fragments to match.

        Returns:
            A new list of matching module names.
        """
        return [
            pre_name
            for pre_name, pre_module in pretrained_model.named_modules()
            if isinstance(pre_module, nn.Linear)
            and any(target_module in pre_name for target_module in target_modules)
        ]

    def reset_lora_parameters(self, model, target_modules, task_id):
        """Re-initialise the ``task_id`` LoRA pair of every matching ``DynamicLoRALayer``.

        The first layer of the pair gets Kaiming-uniform weights
        (``a=sqrt(5)``) and the second layer is zeroed.

        Args:
            model: Model containing ``DynamicLoRALayer`` modules.
            target_modules: Iterable of name fragments selecting the layers.
            task_id: Index into each layer's ``task_lora``.
        """
        with torch.no_grad():
            for name, module in model.named_modules():
                if not any(target_module in name for target_module in target_modules):
                    continue
                if not isinstance(module, DynamicLoRALayer):
                    continue
                lora_layer = module.task_lora[task_id]
                # The task-specific LoRA entry is indexed as (w_a, w_b).
                w_a = lora_layer[0]
                w_b = lora_layer[1]
                nn.init.kaiming_uniform_(w_a.weight, a=math.sqrt(5))
                nn.init.zeros_(w_b.weight)
                helper.log_and_print(f"Reset LoRA parameters for {name} for task {task_id + 1}", self.args.logger,False)

    def build_lora_model(self, peft_model, finetuned_model=None, target_modules=None, task_id=None, **kwargs):
        """Add a task-specific LoRA pair to each target layer of ``peft_model``.

        A fresh ``ViTPEARL`` is built as the pre-trained reference. For every
        matching ``nn.Linear`` the delta between the fine-tuned and the
        pre-trained weight is decomposed with SVD to pick a rank (overridden
        by ``args.static_rank`` when positive). A task is then added to the
        corresponding layer of ``peft_model`` and, unless ``args.weight_init``
        is set, the SVD factors are loaded into it. Finally the task's LoRA
        parameters are re-initialised via ``reset_lora_parameters``.

        Args:
            peft_model: Model holding the ``DynamicLoRALayer`` wrappers.
            finetuned_model: Fine-tuned model providing the task weights. Required.
            target_modules: Iterable of name fragments selecting layers. Required.
            task_id: Index of the task being added. Required.
            **kwargs: Unused.

        Returns:
            ``peft_model`` moved to ``args.device``.
        """
        assert finetuned_model is not None and target_modules is not None and task_id is not None
        finetuned_model_named_modules = dict(finetuned_model.named_modules())
        peft_model_named_modules = dict(peft_model.named_modules())
        param_difference = 0  # Track parameter savings

        pretrained_model = ViTPEARL(self.args).to(self.args.device)
        split_qkv(pretrained_model.image_encoder, self.args.logger)
        modules_to_replace = self.cached_modules_to_replace(pretrained_model, tuple(target_modules))
        helper.log_and_print(f"Modules to replace: {modules_to_replace}", self.args.logger, self.args.verbose)

        # Built once; previously this mapping was rebuilt for every module.
        pretrained_model_named_modules = dict(pretrained_model.named_modules())

        for pre_name in modules_to_replace:
            pre_module = pretrained_model_named_modules[pre_name]
            ft_module = finetuned_model_named_modules[pre_name]

            # Relative squared distance between fine-tuned and pre-trained weights.
            w1 = ft_module.weight.flatten().detach().cpu().numpy()
            w2 = pre_module.weight.flatten().detach().cpu().numpy()
            threshold_dynamic = np.sum((w1 - w2) ** 2) / (np.sum(w1 ** 2) + np.sum(w2 ** 2))
            threshold = self.args.svd_threshold if self.args.svd_threshold > 0 else threshold_dynamic

            # Task vector and its SVD.
            w_pt_delta = ft_module.weight - pre_module.weight
            original_params_count = w_pt_delta.numel()
            U, S, Vt = np.linalg.svd(w_pt_delta.detach().cpu().numpy(), full_matrices=False)
            k = self._select_rank(S, threshold)

            # Fix the final rank BEFORE slicing the factors so that A and B
            # agree with the rank handed to add_task(). Previously the static
            # override was applied after slicing, leaving the factors at the
            # SVD-selected rank.
            # NOTE(review): assumes static_rank does not exceed the number of
            # singular values; slicing silently truncates otherwise - confirm.
            k = self.args.static_rank if self.args.static_rank > 0 else k

            # NOTE(review): both factors are scaled by S, so B @ A equals
            # U diag(S**2) Vt rather than the delta itself. Left unchanged
            # because reset_lora_parameters() re-initialises the task's
            # weights below - confirm the intended scaling.
            B = torch.tensor(U[:, :k] @ np.diag(S[:k]), dtype=torch.float32)
            A = torch.tensor(np.diag(S[:k]) @ Vt[:k, :], dtype=torch.float32)

            # Add task-specific parameters
            peft_module = peft_model_named_modules[pre_name]
            if isinstance(peft_module, DynamicLoRALayer):
                peft_module.add_task(r=k, alpha=self.args.alpha)

            if not self.args.weight_init:
                peft_module.load_task_lora_weights(task_id, A, B)

            reduced_params_count = A.numel() + B.numel()
            helper.log_and_print(f"Threshold and rank for {pre_name}: {threshold}, {k}", self.args.logger, self.args.verbose)
            helper.log_and_print(f"Parameter difference for {pre_name}: {original_params_count - reduced_params_count}",
                                 self.args.logger, self.args.verbose)
            param_difference += original_params_count - reduced_params_count

        helper.log_and_print(f"Total parameter difference: {param_difference}", self.args.logger, self.args.verbose)

        self.reset_lora_parameters(peft_model, target_modules, task_id)
        return peft_model.to(self.args.device)

    @staticmethod
    def freeze_parameters_for_task(model, target_modules, task_id):
        """Freeze every parameter except the LoRA layers and classifier of ``task_id``.

        Args:
            model: Model to update in place. If it has a ``classifier_pool``
                attribute, only the classifier at index ``task_id`` is left
                trainable; a missing pool is skipped.
            target_modules: Iterable of name fragments selecting the
                ``DynamicLoRALayer`` modules to consider.
            task_id: Index of the task whose parameters stay trainable.
        """
        # Freeze all parameters initially
        for param in model.parameters():
            param.requires_grad = False

        # Unfreeze LoRA parameters for the specified task_id
        for name, module in model.named_modules():
            if not any(target in name for target in target_modules):
                continue
            if not isinstance(module, DynamicLoRALayer):
                continue
            for i, task_layer in enumerate(module.task_lora):
                for param in task_layer.parameters():
                    param.requires_grad = (i == task_id)

        # Unfreeze classifier_pool parameters for task_id. The getattr default
        # is None, so guard against models without a pool instead of failing
        # with "NoneType is not iterable".
        classifier_pool = getattr(model, 'classifier_pool', None)
        if classifier_pool is not None:
            for i, classifier in enumerate(classifier_pool):
                for param in classifier.parameters():
                    param.requires_grad = (i == task_id)

        print(f"Parameters for task {task_id + 1} are now trainable, others are frozen.")

    @staticmethod
    def add_dynamic_lora(model, target_modules):
        """Wrap every module whose name matches ``target_modules`` in a ``DynamicLoRALayer``.

        Args:
            model: Model to modify in place.
            target_modules: Iterable of name fragments; a module matches when
                any fragment is a substring of its qualified name.
        """
        # Collect names first so the module tree is not mutated while it is
        # being traversed.
        modules_to_replace = [
            name
            for name, _ in model.named_modules()
            if any(target_module in name for target_module in target_modules)
        ]

        for name in modules_to_replace:
            # Walk down to the parent that owns the target attribute.
            *parent_path, target_name = name.split(".")
            parent_module = model
            for sub_name in parent_path:
                parent_module = getattr(parent_module, sub_name)

            target_module = getattr(parent_module, target_name)
            setattr(parent_module, target_name, DynamicLoRALayer(target_module))

        print("Dynamic LoRA layers added successfully.")

    def add_new_subnet_for_new_task(self, subnet_1, *args, **kwargs):
        """Not supported for the pre-trained ViT; always raises ``NotImplementedError``."""
        raise NotImplementedError("This method is not needed for pre-trained model.")

    def remove_old_subnet_from_prev_task(self, subnet_1, *args, **kwargs):
        """Not supported for the pre-trained ViT; always raises ``NotImplementedError``."""
        raise NotImplementedError("This method is not needed for pre-trained model.")

