import torch
import torch.nn as nn
from collections import OrderedDict


class MLPManipulator:
    def __init__(self, model, input_dim=2):
        # Remember the concrete model type. type() is used rather than model.__class__
        # because __class__ can be overridden and then report a different class.
        model_type = type(model)
        self.module_class = model_type
        # (layer_number, operation_type, parameter_type) -> parameter name / hyper-parameter,
        # ordered as the parameters are used during the forward pass.
        self.ordered_params = self._check_mlp_parameters_order(model, input_dim=input_dim)


    # check the correct order of execution of the parameters in a specified MLP model
    # and return an OrderedDict with tuple keys of these parameters in execution order:
    # OrderedDict([((layer_number, operation_type, parameter_type), parameter_name / hyper-parameters)])
    def _check_mlp_parameters_order(self, model, input_dim):
        """Return the model's parameters keyed in the order they are used in ``forward``.

        ``state_dict`` order follows module registration, which need not match the order in
        which modules actually run. This method registers every supported leaf module, attaches
        forward hooks, runs one dummy forward pass to record the execution order, and rebuilds
        the parameter listing in that order.

        Supported leaf modules are ``nn.Linear``, ``nn.ReLU``, ``nn.LeakyReLU`` and
        ``nn.BatchNorm1d``. ``nn.Sequential`` containers are traversed. The direct children of
        ``model`` (and their sub-modules) are inspected.

        Args:
            model: the MLP to inspect. NOTE: it is switched to eval mode and left in it.
            input_dim: number of input features of the dummy sample fed to ``model``.

        Returns:
            OrderedDict mapping ``(layer_number, operation_type, parameter_type)`` to the
            ``state_dict`` key (``Linear`` / ``BatchNorm1d``), to ``None`` (``ReLU``), or to the
            ``negative_slope`` value (``LeakyReLU``). ``layer_number`` starts at 1 and is
            incremented after every activation.

        Raises:
            ValueError: if an unsupported module type is found.
            AssertionError: on duplicate modules, state_dict mismatches or inconsistent
                feature dimensions.
        """
        # state_dict is an OrderedDict, but its order follows registration and is NOT
        # necessarily the order the parameters are used during forward propagation
        state_dict = model.state_dict()

        # memory IDs of the registered (supported, non-container) modules
        module_id_set = set()
        # module ID -> {(operation_type, parameter_type): parameter name / hyper-parameter}
        module_id_to_params = {}
        # module IDs in the order their forward hooks fire
        execution_order_of_module_ids = []
        # hook handles, kept so they can always be removed
        hooks = []

        def _create_hook(module_id):
            def hook(module, inputs, output):
                # only registered leaf modules are recorded; containers are ignored
                if module_id in module_id_set:
                    execution_order_of_module_ids.append(module_id)
            return hook

        def _register_module_id(module_name, module):
            module_id = id(module)
            assert module_id not in module_id_set, "Duplicate module found: {} (ID: {})".format(module_name, module_id)
            module_id_set.add(module_id)
            return module_id

        def _check_param(module_name, module, attribute):
            # verify that module.<attribute> is the tensor stored in state_dict under its dotted name
            param_name = module_name + '.' + str(attribute)
            assert param_name in state_dict, "Parameters \"{}\" not found in \"state_dict\"".format(param_name)
            param = getattr(module, attribute)
            assert param is not None, "Attribute \"{}\" not found in module \"{}\"".format(attribute, module_name)
            assert torch.allclose(state_dict[param_name], param.detach()), \
                "{} parameters mismatch: \nstate_dict: {}\nmodule: {}".format(
                    param_name, state_dict[param_name], param.detach()
                )
            return param_name

        def _register_layer(full_module_name, layer):
            # record the parameters / hyper-parameters of one supported leaf module
            if isinstance(layer, nn.Linear):
                module_id = _register_module_id(full_module_name, layer)
                module_id_to_params[module_id] = {}

                assert hasattr(layer, 'weight') and layer.weight is not None, "Unexpected error: the \"weight\" parameter cannot be found for \"torch.nn.Linear\" module."
                module_id_to_params[module_id][('Linear', 'weight')] = _check_param(full_module_name, layer, 'weight')

                # the bias is optional (bias=False)
                assert hasattr(layer, 'bias'), "Unexpected error: the attribute \"bias\" cannot be found for \"torch.nn.Linear\" module."
                if layer.bias is not None:
                    module_id_to_params[module_id][('Linear', 'bias')] = _check_param(full_module_name, layer, 'bias')

            elif isinstance(layer, nn.ReLU):
                module_id = _register_module_id(full_module_name, layer)
                module_id_to_params[module_id] = {('ReLU', None): None}

            elif isinstance(layer, nn.LeakyReLU):
                module_id = _register_module_id(full_module_name, layer)
                assert hasattr(layer, 'negative_slope'), "Unexpected error: the attribute (\"negative_slope\") cannot be found for \"torch.nn.LeakyReLU\" module."
                module_id_to_params[module_id] = {('LeakyReLU', 'negative_slope'): layer.negative_slope}

            elif isinstance(layer, nn.BatchNorm1d):
                module_id = _register_module_id(full_module_name, layer)
                # insertion order (running_mean, running_var, weight, bias) is the order
                # in which consumers of ordered_params apply them
                module_id_to_params[module_id] = OrderedDict()

                assert hasattr(layer, 'track_running_stats'), "Unexpected error: the attribute \"track_running_stats\" cannot be found for \"torch.nn.BatchNorm1d\" module."
                if layer.track_running_stats is True:
                    assert layer.running_mean is not None, \
                        "Unexpected error: the \"running_mean\" parameter cannot be found for \"torch.nn.BatchNorm1d\" module given attribute \"track_running_stats\" is True."
                    module_id_to_params[module_id][('BatchNorm1d', 'running_mean')] = _check_param(full_module_name, layer, 'running_mean')

                    assert layer.running_var is not None, \
                        "Unexpected error: the \"running_var\" parameter cannot be found for \"torch.nn.BatchNorm1d\" module given attribute \"track_running_stats\" is True."
                    module_id_to_params[module_id][('BatchNorm1d', 'running_var')] = _check_param(full_module_name, layer, 'running_var')

                assert hasattr(layer, 'affine'), "Unexpected error: the attribute \"affine\" cannot be found for \"torch.nn.BatchNorm1d\" module."
                if layer.affine is True:
                    assert layer.weight is not None, \
                        "Unexpected error: the \"weight\" parameter cannot be found for \"torch.nn.BatchNorm1d\" module given attribute \"affine\" is True."
                    module_id_to_params[module_id][('BatchNorm1d', 'weight')] = _check_param(full_module_name, layer, 'weight')

                    assert layer.bias is not None, \
                        "Unexpected error: the \"bias\" parameter cannot be found for \"torch.nn.BatchNorm1d\" module given attribute \"affine\" is True."
                    module_id_to_params[module_id][('BatchNorm1d', 'bias')] = _check_param(full_module_name, layer, 'bias')

            else:
                raise ValueError("Unsupported layer type for {}: {}".format(full_module_name, type(layer).__name__))

        # Registration AND the forward pass sit inside the try block, so hooks never leak onto
        # the caller's model, even when an unsupported layer or a failed check aborts midway.
        try:
            for module_name, module in model.named_children():
                for layer_name, layer in module.named_modules():
                    full_module_name = module_name if layer_name == '' else module_name + '.' + layer_name

                    # hooks go on every module; only registered leaf modules are recorded
                    hooks.append(layer.register_forward_hook(_create_hook(id(layer))))

                    # skip container modules
                    if isinstance(layer, nn.Sequential):  # use isinstance(layer, (nn.Sequential, nn.ModuleList, nn.ModuleDict)) for non-MLP
                        continue

                    _register_layer(full_module_name, layer)

            # run a dummy forward pass to record the execution order
            model.eval()
            reference = next(model.parameters(), None)
            if reference is None:
                dummy_input = torch.randn(1, input_dim)
            else:
                # match the model's dtype/device; otherwise e.g. float64 or CUDA models fail
                dummy_input = torch.randn(1, input_dim, dtype=reference.dtype, device=reference.device)
            with torch.no_grad():
                model(dummy_input)
        finally:
            for hook in hooks:
                hook.remove()

        # OrderedDict([((layer_number, operation_type, parameter_type), parameter_name / hyper-parameters)])
        ordered_params = OrderedDict()
        layer_number = 1
        for module_id in execution_order_of_module_ids:
            assert module_id in module_id_to_params, "Unexpected error: module_id {} not found in registry".format(module_id)
            for (operation_type, parameter_type), value in module_id_to_params[module_id].items():
                ordered_params[(layer_number, operation_type, parameter_type)] = value
                # each layer is considered a complete unit of non-linear transformation
                # that ends with the non-linear activation function
                if operation_type in ('ReLU', 'LeakyReLU'):
                    layer_number += 1

        # verify that feature dimensions chain consistently from the input to the output
        current_dim = input_dim
        for (layer_num, operation_type, parameter_type), param_name in ordered_params.items():
            if operation_type == 'Linear' and parameter_type == 'weight':
                parameter = state_dict[param_name]
                assert parameter.dim() == 2 and parameter.shape[1] == current_dim, \
                    "Dimension mismatch at {}: expected input dim {}, got {}".format(param_name, current_dim, parameter.shape[1])
                current_dim = parameter.shape[0]
            elif (operation_type == 'Linear' and parameter_type == 'bias') or operation_type == 'BatchNorm1d':
                parameter = state_dict[param_name]
                assert parameter.dim() == 1 and parameter.shape[0] == current_dim, \
                    "Dimension mismatch at {}: expected {} features, got {}".format(param_name, current_dim, parameter.shape[0])
            elif operation_type == 'ReLU':
                assert (parameter_type is None) and (param_name is None)
            elif operation_type == 'LeakyReLU':
                assert (parameter_type == 'negative_slope') and (param_name is not None)
            else:
                raise AssertionError("Invalid operation: {}".format(operation_type))

        return ordered_params


    # find the affine transformation of the activated linear segment given a specific input
    def find_activated_linear_segment(self, x, model):
        """Return the affine map ``(W, b)`` that ``model`` applies on the linear region containing ``x``.

        Starting from the identity map, every operation recorded in ``self.ordered_params`` is
        folded into ``W`` and ``b``:

        - ``Linear`` layers are composed directly.
        - ``BatchNorm1d`` uses its running statistics: ``(z - mean) / sqrt(var + eps) * weight + bias``.
        - ``ReLU`` / ``LeakyReLU`` are replaced by the per-unit slope they apply at ``x``
          (1 where the pre-activation is positive, otherwise 0 / ``negative_slope``).

        BatchNorm1d layers without running statistics are only handled through their affine part.

        Args:
            x: a single input sample (tensor or array-like). It is flattened to 1-D and cast to
               the dtype/device of the model's parameters.
            model: an instance of ``self.module_class``.

        Returns:
            ``(W_activated, b_activated)`` with ``W_activated @ x + b_activated`` equal (up to
            ``rtol=1e-2``) to ``model(x)`` evaluated in eval mode.

        Raises:
            AssertionError: if ``model`` has the wrong class or the affine map does not
                reproduce the model's output.
            KeyError: if ``self.ordered_params`` contains an unknown operation.
        """
        assert isinstance(model, self.module_class)
        state_dict = model.state_dict()

        # match the parameters' dtype/device so the matmuls below are valid
        # (e.g. a list of ints or a float64 input fed to a float32 model)
        reference = next(model.parameters(), None)
        dtype = reference.dtype if reference is not None else torch.get_default_dtype()
        device = reference.device if reference is not None else None
        x = torch.as_tensor(x, dtype=dtype, device=device).flatten()

        input_dim = x.numel()
        W_activated = torch.eye(input_dim, dtype=dtype, device=device)
        b_activated = torch.zeros(input_dim, dtype=dtype, device=device)

        for key, value in self.ordered_params.items():
            _, operation_type, parameter_type = key
            if operation_type == 'Linear' and parameter_type == 'weight':
                weight = state_dict[value]
                W_activated = weight @ W_activated
                b_activated = weight @ b_activated
            elif operation_type == 'Linear' and parameter_type == 'bias':
                b_activated = b_activated + state_dict[value]
            elif operation_type == 'BatchNorm1d' and parameter_type == 'running_mean':
                b_activated = b_activated - state_dict[value]
            elif operation_type == 'BatchNorm1d' and parameter_type == 'running_var':
                # BatchNorm1d divides by sqrt(running_var + eps), not sqrt(running_var)
                batch_norm = model.get_submodule(value.rpartition('.')[0])
                scale = torch.rsqrt(state_dict[value] + batch_norm.eps)
                W_activated = scale.unsqueeze(1) * W_activated
                b_activated = scale * b_activated
            elif operation_type == 'BatchNorm1d' and parameter_type == 'weight':
                gamma = state_dict[value]
                W_activated = gamma.unsqueeze(1) * W_activated
                b_activated = gamma * b_activated
            elif operation_type == 'BatchNorm1d' and parameter_type == 'bias':
                b_activated = b_activated + state_dict[value]
            elif operation_type in ('ReLU', 'LeakyReLU'):
                if operation_type == 'LeakyReLU':
                    assert parameter_type == 'negative_slope'
                    negative_slope = value
                else:
                    negative_slope = 0.0
                pre_activation_output = W_activated @ x + b_activated
                # per-unit slope of the activation at x (a vector, one entry per unit)
                slopes = torch.where(
                    pre_activation_output > 0,
                    torch.ones_like(pre_activation_output),
                    torch.full_like(pre_activation_output, negative_slope),
                )
                W_activated = slopes.unsqueeze(1) * W_activated
                b_activated = slopes * b_activated
            else:
                raise KeyError("Invalid key {} due to unknown error!".format(key))

        affine_output = W_activated @ x + b_activated

        # Compare against the model in eval mode (the mode the running statistics describe),
        # restoring every submodule's training flag afterwards.
        previous_modes = [(submodule, submodule.training) for submodule in model.modules()]
        model.eval()
        try:
            with torch.no_grad():
                model_output = model(x.unsqueeze(0)).squeeze()
        finally:
            for submodule, mode in previous_modes:
                submodule.training = mode

        assert torch.allclose(affine_output.detach(), model_output, rtol=1e-02), \
            "Unknow error: output of the activated linear segment {} does not align with the MLP model's output {} with the input {}.".format( \
                affine_output, model_output, x
            )

        print("Successfully found activated linear segment!")

        return W_activated, b_activated


    # return parameter given the index of layer and the name of the parameters
    def get_parameters(self, model, layer_number, operation_type, parameter_type, requires_grad=True):
        """Fetch the tensor of ``model`` registered under an ordered-parameter key.

        Args:
            model: an instance of ``self.module_class``.
            layer_number, operation_type, parameter_type: the components of a key of
                ``self.ordered_params``.
            requires_grad: when True, assert that the fetched tensor requires autograd.

        Returns:
            The attribute object of the owning submodule itself (not a copy).
        """
        assert isinstance(model, self.module_class), "Model class mismatch."

        key = (layer_number, operation_type, parameter_type)
        parameter_name = self.ordered_params[key]

        # "a.b.weight" -> owner path "a.b", attribute "weight"; a bare "weight" gives owner ""
        owner_path, _, attribute = parameter_name.rpartition('.')

        assert parameter_type == attribute, "Unexpected error: parameter_name \"{}\" does not contain parameter_type \"{}\"".format(parameter_name, parameter_type)

        owner = model.get_submodule(owner_path)
        parameter = getattr(owner, attribute)

        if requires_grad:
            assert parameter.requires_grad, "Extracted parameter \"{}\" does not require autograd.".format(parameter_name)

        return parameter


    # add new parameters to broaden the hidden layer of the model
    def broaden_hidden_layer(self, model, wA_new, bA_new, wB_new):
        """Return a new model whose first hidden layer has one extra neuron.

        The new neuron gets:

        - incoming weights ``wA_new`` and bias ``bA_new`` in the first ``Linear`` layer;
        - outgoing weights ``wB_new`` in the second ``Linear`` layer.

        Every ``BatchNorm1d`` executed between those two ``Linear`` layers gets an extra
        channel with identity normalization (running_mean 0, running_var 1, weight 1, bias 0).

        Args:
            model: an instance of ``self.module_class``. It is not modified, because only
                entries of its state_dict copy are replaced.
            wA_new: 1-D values, one per input feature of the first Linear layer.
            bA_new: a single bias value (0-d / 1-element tensor or a number).
            wB_new: 1-D values, one per output feature of the second Linear layer.

        Returns:
            A new ``self.module_class`` instance, loaded with the widened state_dict. It is built
            as ``self.module_class(in_size=model.in_size, out_size=model.out_size,
            hidden_dim=old_hidden + 1)``. NOTE(review): this assumes the model class accepts
            these keyword arguments and exposes ``in_size``/``out_size``.
        """
        assert isinstance(model, self.module_class), "Model class mismatch."

        state_dict = model.state_dict()
        wA_key = self.ordered_params[(1, 'Linear', 'weight')]
        bA_key = self.ordered_params[(1, 'Linear', 'bias')]
        wB_key = self.ordered_params[(2, 'Linear', 'weight')]
        WA, bA, WB = state_dict[wA_key], state_dict[bA_key], state_dict[wB_key]

        # torch.cat needs tensors of matching dtype/device; also accept numbers and lists
        wA_new = torch.as_tensor(wA_new, dtype=WA.dtype, device=WA.device).flatten()
        bA_new = torch.as_tensor(bA_new, dtype=bA.dtype, device=bA.device).reshape(1)
        wB_new = torch.as_tensor(wB_new, dtype=WB.dtype, device=WB.device).flatten()

        assert WA.shape[0] == len(bA) and WA.shape[1] == len(wA_new) and WA.shape[0] == WB.shape[1] and WB.shape[0] == len(wB_new), \
            "Unexpected error: shape of parameters mismatch.{} - {} - {}".format(WA.shape, bA.shape, WB.shape)

        hidden_dim = WA.shape[0] + 1

        state_dict[wA_key] = torch.cat([WA, wA_new.unsqueeze(0)], dim=0)
        state_dict[bA_key] = torch.cat([bA, bA_new], dim=0)
        state_dict[wB_key] = torch.cat([WB, wB_new.unsqueeze(1)], dim=1)

        # Only BatchNorm1d layers executed between the first and the second Linear weight act
        # on the hidden features; give the new channel an identity normalization.
        bn_fill_values = {'running_mean': 0.0, 'running_var': 1.0, 'weight': 1.0, 'bias': 0.0}
        in_hidden_layer = False
        for key, param_name in self.ordered_params.items():
            if key == (1, 'Linear', 'weight'):
                in_hidden_layer = True
            elif key == (2, 'Linear', 'weight'):
                break
            elif in_hidden_layer and key[1] == 'BatchNorm1d' and key[2] in bn_fill_values:
                old_values = state_dict[param_name]
                new_value = old_values.new_full((1,), bn_fill_values[key[2]])
                state_dict[param_name] = torch.cat([old_values, new_value], dim=0)

        new_model = self.module_class(in_size=model.in_size, out_size=model.out_size, hidden_dim=hidden_dim)
        new_model.load_state_dict(state_dict)
        return new_model

    # retrieve the older version of model before layer broaden
    def retrieve_old_model(self, model):
        assert isinstance(model, self.module_class), "Model class mismatch."
