import os
import torch
import torch.nn as nn
import model as module_arch


def compute_regular_loss(weights):
    """Return the sum of the L2 norms of the given tensors.

    Each tensor is flattened to one dimension before its 2-norm is taken,
    so the result does not depend on the tensors' original shapes.

    Args:
        weights: iterable of tensors (e.g. the per-layer outputs of the
            hypernetwork).

    Returns:
        A zero-dimensional tensor holding the summed norms, or the plain
        float ``0.0`` when ``weights`` is empty.
    """
    flat_norms = (torch.norm(tensor.reshape(-1), 2) for tensor in weights)
    # Folding from the float 0.0 keeps the empty-input result a Python float.
    return sum(flat_norms, 0.0)


class ModifierNetwork(nn.Module):
    """Hypernetwork made of a shared ReLU MLP trunk and one linear head per target shape.

    Every head ("branch") maps the trunk output to ``prod(shape)`` features,
    which ``forward`` then reshapes to ``shape``.
    """

    def __init__(self, input_dim=512, latent_dim=1024, output_dim=None, num_shared_layers=1):
        """Build the trunk and the per-shape branches.

        Args:
            input_dim: feature size of the input embedding.
            latent_dim: width of every shared layer.
            output_dim: sequence of target shapes, one per branch. Each shape
                may be a 1-D integer tensor, a ``torch.Size``, a tuple or a
                list. ``None`` is treated as "no branches".
            num_shared_layers: number of shared linear layers in the trunk.
        """
        super().__init__()
        # ``None`` used to crash on ``len(None)``; treat it as an empty list.
        self.output_dim = [] if output_dim is None else output_dim

        # Shared trunk: the first layer maps input_dim -> latent_dim, the
        # remaining ones latent_dim -> latent_dim.
        trunk = []
        in_features = input_dim
        for _ in range(num_shared_layers):
            trunk.append(nn.Linear(in_features, latent_dim))
            in_features = latent_dim
        self.shared_layers = nn.ModuleList(trunk)

        # N branches corresponding to n linear layers for generating weight shifts.
        # ``in_features`` is the width actually fed to the branches: latent_dim
        # when a trunk exists, input_dim when num_shared_layers == 0.
        # out_features is converted to a plain int so nn.Linear never receives
        # a tensor as a dimension.
        self.branches = nn.ModuleList(
            [nn.Linear(in_features, torch.Size(self._as_int_shape(shape)).numel())
             for shape in self.output_dim])

    @staticmethod
    def _as_int_shape(shape):
        """Convert a shape given as tensor / torch.Size / sequence to a tuple of ints."""
        return tuple(int(dim) for dim in shape)

    def forward(self, x):
        """Run the trunk on ``x`` and return one reshaped tensor per branch.

        Each branch output is viewed as exactly its target shape, so the
        number of elements must match: ``x`` has to contain a single row in
        total (e.g. shape ``(1, input_dim)``), otherwise ``view`` raises.

        Returns:
            A list of tensors, the i-th having shape ``output_dim[i]``.
        """
        for layer in self.shared_layers:
            x = torch.relu(layer(x))

        return [branch(x).view(self._as_int_shape(shape))
                for branch, shape in zip(self.branches, self.output_dim)]  # return a list

    def get_model_name(self):
        """Return the class name; callers use it to build checkpoint paths."""
        return self.__class__.__name__


class MainNetUnified(nn.Module):  # a unified net
    """Couple a main network with a hypernetwork that rewrites selected layer weights.

    On construction, the weight of every layer named in
    ``cfg.main_model.args.modified_layers`` is removed from that layer's
    registered parameters and kept aside. On each forward pass the person
    encoder embeds ``p``, the hypernetwork turns that embedding into one
    tensor per hooked layer, and each layer weight is rebuilt from its
    retained original and that tensor before ``main_net`` is run.
    """

    def __init__(self, cfg, main_net, device):
        """Hook the target layers and build the hypernetwork and person encoder.

        Args:
            cfg: configuration object. Reads ``cfg.main_model.args``
                (``modified_layers``, ``num_shared_layers``, ``regularization``,
                ``regular_w`` and the optional ``input_dim``, ``latent_dim``,
                ``embed_dim``, ``predict``, ``modify``, ``resume``),
                ``cfg.person_specific`` (``type``, ``args``,
                ``checkpoint_path``), ``cfg.mode`` and, in test mode,
                ``cfg.trainer.saving_checkpoint_dir``.
            main_net: module providing ``named_modules()`` and
                ``obtain_shapes(layer_names)``.
            device: device on which the hooked layers and helper tensors are
                placed; also passed to the person encoder constructor.

        Raises:
            AssertionError: if a required checkpoint path is missing.
            ValueError: if a hooked module has neither ``weight`` nor
                ``in_proj_weight``.
        """
        super().__init__()
        self.main_net = main_net  # instance of our diffusion model
        args = cfg.main_model.args
        self.modified_layers = args.modified_layers
        num_shared_layers = args.num_shared_layers
        input_dim = args.get("input_dim", 512)
        latent_dim = args.get("latent_dim", 1024)
        embed_dim = args.get("embed_dim", 512)  # embed dim in Transformer

        # our hypernet outputs shift [w' = w + delta_w] or offset [w' = w * (1 + delta_w)] or w' = delta_w
        self.hypernet_predict = args.get("predict", "shift")
        # we modify W_q, or W_k, or W_v?
        self.crossattn_modify = args.get("modify", "all")

        # All modules whose weight will be rewritten, in named_modules() order.
        # They are moved to ``device`` explicitly: once a weight is removed from
        # ``_parameters`` it no longer follows later ``.to()`` calls on the
        # model, so it must already live next to tensor_0 / tensor_1.
        hooked_modules = {
            name: module.to(device)
            for name, module in self.main_net.named_modules()
            if name in self.modified_layers
        }

        weight_shapes = self.main_net.obtain_shapes(self.modified_layers)

        self.weight_shapes = []
        for layer_name in hooked_modules:
            shape = weight_shapes[layer_name]
            if 'multihead_attn' in layer_name and self.crossattn_modify == 'kv':
                # Only the key/value two-thirds of the stacked projection are
                # generated, e.g. (1536 ==> 1024, 512).
                shape = torch.tensor([2 * torch.div(shape[0], 3, rounding_mode='trunc'),
                                      shape[1]])
            self.weight_shapes.append(shape)

        self.hypernet = ModifierNetwork(
            input_dim=input_dim,
            latent_dim=latent_dim,
            output_dim=self.weight_shapes,
            num_shared_layers=num_shared_layers,
        )

        # TODO: debug: we add L2 or L1 norm loss to constrain the weight deviation generated by our hypernet.
        self.regularization = args.regularization
        self.regular_w = args.regular_w

        # Listener personal-specific encoder
        self.person_encoder = getattr(module_arch, cfg.person_specific.type)(device, **cfg.person_specific.args)
        model_path = cfg.person_specific.checkpoint_path
        assert os.path.exists(model_path), (
            "Miss checkpoint for audio encoder: {}.".format(model_path))
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location='cpu')
        self.person_encoder.load_state_dict(checkpoint['state_dict'])

        # Load hypernet_path for inference
        if cfg.mode == 'test':
            model_path = args.get("resume")
            assert model_path is not None, "The model path should be provided."
            model_name = self.hypernet.get_model_name()
            save_path = os.path.join(cfg.trainer.saving_checkpoint_dir, model_name, model_path)
            assert os.path.exists(save_path), "Miss checkpoint for hypernet: {}.".format(save_path)
            checkpoint = torch.load(save_path, map_location=torch.device('cpu'))
            self.hypernet.load_state_dict(checkpoint['state_dict'])
            print("Successfully load model for inference: {}, {}".format(model_name, model_path))

        def weight_attr(module):
            # Name of the attribute that holds the module's rewritable weight.
            if hasattr(module, 'weight'):
                return 'weight'
            if hasattr(module, 'in_proj_weight'):
                return 'in_proj_weight'
            raise ValueError("The module has either weight or in_proj_weight attribute.")

        # Resolve every attribute name first so that an unsupported module
        # aborts before any layer has been modified.
        weight_attrs = [weight_attr(module) for module in hooked_modules.values()]

        # retain all original weights of hooked modules.
        original_weights = [getattr(module, attr)
                            for module, attr in zip(hooked_modules.values(), weight_attrs)]

        # Necessary: discard the weight from the _parameters dictionary so a
        # plain (non-Parameter) tensor can be assigned under the same name.
        for module, attr in zip(hooked_modules.values(), weight_attrs):
            del module._parameters[attr]

        def combine(base, delta_w):
            # Build the effective weight according to the prediction mode.
            if self.hypernet_predict == 'shift':
                return base + delta_w
            if self.hypernet_predict == 'offset':
                return base * (self.tensor_1 + delta_w)
            if self.hypernet_predict == 'weight':  # directly generate the model weight
                return delta_w
            # Previously an unknown mode silently left the layer without any
            # weight; fail with an explicit message instead.
            raise ValueError(
                "Unknown hypernet predict mode: {!r} "
                "(expected 'shift', 'offset' or 'weight').".format(self.hypernet_predict))

        def new_forward():  # modify model weights
            for i, (name, module) in enumerate(hooked_modules.items()):
                delta_w = self.kernel[i]  # delta_w generated from our hypernet
                attr = weight_attrs[i]  # same attribute that was removed above

                # whether we modify W_q, or W_k, or W_v.
                if (attr == 'in_proj_weight' and 'multihead_attn' in name
                        and self.crossattn_modify == 'kv'):
                    # only generate delta_w for key and value in cross-attn:
                    # prepend zeros for the query rows.
                    # NOTE(review): in 'weight' mode this makes the query rows
                    # all zeros rather than keeping the original ones - confirm
                    # this is intended.
                    delta_w = torch.cat((self.tensor_0, delta_w), dim=0)

                setattr(module, attr, combine(original_weights[i], delta_w))

        self.new_forward = new_forward
        self.tensor_0 = torch.zeros(size=(embed_dim, embed_dim)).to(device)  # for rewriting W_k and W_v
        self.tensor_1 = torch.tensor(1.0).to(device)  # model offset

    def forward(self, x, p):
        """Rewrite the hooked weights from ``p`` and run the main network on ``x``.

        Args:
            x: mapping of keyword arguments forwarded to ``main_net``.
            p: input of the person encoder; the second value it returns is
                used as the embedding fed to the hypernetwork
                (presumably of shape (b, 512) - TODO confirm).

        Returns:
            A tuple ``(output, norm_loss)`` where ``output`` is whatever
            ``main_net`` returns and ``norm_loss`` is ``regular_w`` times the
            summed L2 norms of the hypernetwork outputs when regularization
            is enabled, otherwise a zero tensor.
        """
        _, p = self.person_encoder(p)
        self.kernel = self.hypernet(p)  # p[0:1]

        self.new_forward()  # rewrite model weights

        output = self.main_net(**x)

        if self.regularization:
            # TODO: debug: we add L2 or L1 norm loss to constrain the weight deviation
            #  generated by our hypernet.
            norm_loss = self.regular_w * compute_regular_loss(self.kernel)
            return output, norm_loss
        return output, torch.tensor(0.0)

# class MainNet(nn.Module):
#     def __init__(self, cfg, main_net, device):
#         super().__init__()
#         self.main_net = main_net  # our diffusion model
#         modified_layers = cfg.main_model.args.modified_layers
#         num_shared_layers = cfg.main_model.args.num_shared_layers
#         input_dim = cfg.main_model.args.input_dim
#         latent_dim = cfg.main_model.args.latent_dim
#
#         self.weight_shapes = self.main_net.obtain_shapes(modified_layers)
#
#         self.hypernet = ModifierNetwork(
#             input_dim=input_dim,
#             latent_dim=latent_dim,
#             output_dim=self.weight_shapes,
#             num_shared_layers=num_shared_layers,
#         )
#         self.regularization = cfg.main_model.args.regularization
#         self.regular_w = cfg.main_model.args.regular_w
#
#         # Listener personal-specific encoder
#         self.person_encoder = getattr(module_arch, cfg.person_specific.type)(device, **cfg.person_specific.args)
#         model_path = cfg.person_specific.checkpoint_path
#         assert os.path.exists(model_path), (
#             "Miss checkpoint for audio encoder: {}.".format(model_path))
#         checkpoint = torch.load(model_path, map_location='cpu')
#         state_dict = checkpoint['state_dict']
#         self.person_encoder.load_state_dict(state_dict)
#
#         # Load hypernet_path for inference
#         if cfg.mode == 'test':
#             model_path = cfg.main_model.args.get("resume")
#             assert model_path is not None, "The model path should be provided."
#             model_name = self.hypernet.get_model_name()
#             save_path = os.path.join(cfg.trainer.saving_checkpoint_dir, model_name, model_path)
#             assert os.path.exists(save_path), "Miss checkpoint for hypernet: {}.".format(save_path)
#             checkpoint = torch.load(save_path, map_location=torch.device('cpu'))
#             state_dict = checkpoint['state_dict']
#             self.hypernet.load_state_dict(state_dict)
#             print("Successfully load model for inference: ", model_name)
#
#         # Use to load the original weight for the current modified layer
#         self.cur_index = 0
#         self.num_modified_layers = len(modified_layers)
#
#         def target_modules():
#             return [p.cuda() for n, p in self.main_net.named_modules()
#                     if n in modified_layers]
#
#         def target_weights():
#             return [getattr(module, 'weight') for module in hooked_modules
#                     if getattr(module, 'weight') is not None]
#
#         hooked_modules = target_modules()
#
#         original_weights = target_weights()
#
#         for i in range(len(hooked_modules)):
#             del hooked_modules[i]._parameters['weight']
#
#         old_forwards = [module.forward for module in hooked_modules]
#
#         def new_forward(x):
#             hooked_modules[self.cur_index].weight = original_weights[self.cur_index] + self.kernel[self.cur_index]
#             result = old_forwards[self.cur_index](x)
#
#             # The cur_index is updated for the next modified layer
#             # so that the next modified layer can load its original weight for updating
#             if self.cur_index == self.num_modified_layers - 1:
#                 self.cur_index = 0
#             else:
#                 self.cur_index = self.cur_index + 1
#
#             return result
#
#         for module in hooked_modules:
#             module.forward = new_forward
#
#     def forward(self, x, p):
#         _, p = self.person_encoder(p)
#         # extract personal-specific embedding with shape of (b, 512)
#         self.kernel = self.hypernet(p)  # p[0:1]
#
#         output = self.main_net(**x)
#
#         if self.regularization:
#             norm_loss = self.regular_w * compute_regular_loss(self.kernel)
#             return output, norm_loss
#         else:
#             return output, torch.tensor(0.0)
