"""Implement server code. This will be short, if the server is honest."""

import torch

from .malicious_modifications import ImprintBlock, RecoveryOptimizer, SparseImprintBlock, OneShotBlock
from .malicious_modifications.parameter_utils import introspect_model, replace_module_by_instance
from .malicious_modifications.analytic_transformer_utils import (
    compute_feature_distribution,
    partially_disable_embedding,
    set_MHA,
    set_flow_backward_layer,
    disable_mha_layers,
    equalize_mha_layer,
    partially_norm_position,
    make_imprint_layer,
)
from .models.transformer_dictionary import lookup_module_names

from .aux_training import train_encoder_decoder
from .malicious_modifications.feat_decoders import generate_decoder
from .data import construct_dataloader
import logging

log = logging.getLogger(__name__)


def construct_server(model, loss_fn, cfg_case, setup, external_dataloader=None):
    """Interface function: build the server object requested by ``cfg_case.server.name``.

    Args:
        model: model to be distributed by the server.
        loss_fn: loss function handed to the server.
        cfg_case: case configuration; reads ``server.name``, ``server.has_external_data``, ``data`` and ``impl``.
        setup: dict of tensor kwargs (dtype/device) forwarded to the server.
        external_dataloader: public data for the server. If None and ``server.has_external_data`` is set,
            a full-dataset dataloader over the split the user is not drawing from is constructed.

    Returns:
        An instance of the server class matching ``cfg_case.server.name``.

    Raises:
        ValueError: if ``cfg_case.server.name`` is not a known server type.
    """
    if external_dataloader is None and cfg_case.server.has_external_data:
        # Temporarily switch the shared data config to the other split for the server's data.
        # The config object is shared with the user, so always restore its split, even if loading fails.
        user_split = cfg_case.data.examples_from_split
        cfg_case.data.examples_from_split = "training" if "validation" in user_split else "validation"
        try:
            dataloader = construct_dataloader(cfg_case.data, cfg_case.impl, user_idx=None, return_full_dataset=True)
        finally:
            cfg_case.data.examples_from_split = user_split
    else:
        dataloader = external_dataloader

    server_name = cfg_case.server.name
    if server_name == "honest_but_curious":
        server = HonestServer(model, loss_fn, cfg_case, setup, external_dataloader=dataloader)
    elif server_name == "malicious_model":
        server = MaliciousModelServer(model, loss_fn, cfg_case, setup, external_dataloader=dataloader)
    elif server_name == "malicious_optimized_parameters":
        server = MaliciousParameterOptimizationServer(model, loss_fn, cfg_case, setup, external_dataloader=dataloader)
    elif server_name == "malicious_transformer_parameters":
        server = MaliciousTransformerServer(model, loss_fn, cfg_case, setup, external_dataloader=dataloader)
    else:
        raise ValueError(f"Invalid server type {cfg_case.server} given.")
    return server


class HonestServer:
    """Implement an honest server protocol.

    This class loads and selects the initial model and then sends this model to the (simulated) user.
    If multiple queries are possible, then these have to loop externally over multiple rounds via .run_protocol

    Central output: self.distribute_payload -> Dict[parameters=parameters, buffers=buffers, metadata=DataHyperparams]
    """

    THREAT = "Honest-but-curious"

    def __init__(self, model, loss, cfg_case, setup=None, external_dataloader=None):
        """Initialize the server settings.

        Args:
            model: torch.nn.Module to distribute; it is switched to eval mode here.
            loss: loss function, stored on the server.
            cfg_case: case config; ``cfg_case.server`` and ``cfg_case.data`` are read.
            setup: dict of tensor kwargs (``dtype``/``device``). Defaults to float32 on CPU. A fresh dict is
                built per instance so servers never share (and accidentally mutate) one default object.
            external_dataloader: optional public data available to the server.
        """
        if setup is None:
            setup = dict(dtype=torch.float, device=torch.device("cpu"))
        self.model = model
        self.model.eval()

        self.loss = loss
        self.setup = setup

        self.num_queries = cfg_case.server.num_queries

        # Data configuration has to be shared across all parties to keep preprocessing consistent:
        self.cfg_data = cfg_case.data
        self.cfg_server = cfg_case.server

        self.external_dataloader = external_dataloader

        self.secrets = dict()  # Should be nothing in here

    def __repr__(self):
        return f"""Server (of type {self.__class__.__name__}) with settings:
    Threat model: {self.THREAT}
    Number of planned queries: {self.num_queries}
    Has external/public data: {self.cfg_server.has_external_data}

    Model:
        model specification: {str(self.model.name)}
        model state: {self.cfg_server.model_state}
        {f'public buffers: {self.cfg_server.provide_public_buffers}' if len(list(self.model.buffers())) > 0 else ''}

    Secrets: {self.secrets}
    """

    def reconfigure_model(self, model_state, query_id=0):
        """Reinitialize, continue training or otherwise modify model parameters in a benign way.

        Args:
            model_state: "untrained" (reset every module that has ``reset_parameters``), "trained" (no change),
                "linearized" (BatchNorm2d scale/shift taken from running stats, Conv2d biases shifted by 10)
                or "orthogonal" (reset, then orthogonal init of weights of modules whose name contains
                "conv" or "linear"). Any other value leaves the parameters untouched.
            query_id: index of the current query; unused by the honest server.
        """
        self.model.cpu()  # References might have been used on GPU later on. Return to normal first.
        for name, module in self.model.named_modules():
            if model_state == "untrained":
                if hasattr(module, "reset_parameters"):
                    module.reset_parameters()
            elif model_state == "trained":
                pass  # model was already loaded as pretrained model
            elif model_state == "linearized":
                with torch.no_grad():
                    # Non-affine BatchNorm has weight=None, and without tracked stats running_var is None.
                    if (
                        isinstance(module, torch.nn.BatchNorm2d)
                        and module.weight is not None
                        and module.running_var is not None
                    ):
                        module.weight.data = module.running_var.data.clone()
                        module.bias.data = module.running_mean.data.clone() + 10
                    # Conv2d always has a ``bias`` attribute; it is None when created with bias=False.
                    if isinstance(module, torch.nn.Conv2d) and module.bias is not None:
                        module.bias.data += 10
            elif model_state == "orthogonal":
                # reinit model with orthogonal parameters:
                if hasattr(module, "reset_parameters"):
                    module.reset_parameters()
                # The name match also hits containers (e.g. a Sequential named "conv") and 1-D weights;
                # neither can be initialized orthogonally, so only touch real >=2-D weight tensors.
                weight = getattr(module, "weight", None)
                if ("conv" in name or "linear" in name) and isinstance(weight, torch.Tensor) and weight.ndim >= 2:
                    torch.nn.init.orthogonal_(weight, gain=1)

    def reset_model(self):
        pass

    def distribute_payload(self, query_id=0):
        """Server payload to send to users. These are only references to simplify the simulation.

        Returns:
            dict with ``parameters`` (list of model parameters), ``buffers`` (list of model buffers, or None
            if public buffers are not provided) and ``metadata`` (the shared data config).
        """
        self.reconfigure_model(self.cfg_server.model_state, query_id)
        honest_model_parameters = list(self.model.parameters())  # do not send only the generators
        if self.cfg_server.provide_public_buffers:
            honest_model_buffers = list(self.model.buffers())
        else:
            honest_model_buffers = None
        return dict(parameters=honest_model_parameters, buffers=honest_model_buffers, metadata=self.cfg_data)

    def vet_model(self, model):
        """This server is honest: the given ``model`` is ignored and ``self.model`` is returned unchanged."""
        return self.model

    def queries(self):
        return range(self.num_queries)

    def run_protocol(self, user):
        """Helper function to simulate multiple queries given a user object.

        Returns:
            (shared_user_data, payloads, true_user_data): one shared update and one payload per query, plus the
            user's true data from the last query (None if no query was run).
        """
        # Simulate a simple FL protocol
        shared_user_data = []
        payloads = []
        true_user_data = None  # stays None when num_queries == 0 instead of raising UnboundLocalError
        for query_id in self.queries():
            server_payload = self.distribute_payload(query_id)  # A malicious server can return something "fun" here
            shared_data_per_round, true_user_data = user.compute_local_updates(server_payload)
            # true_data can only be used for analysis
            payloads += [server_payload]
            shared_user_data += [shared_data_per_round]
        return shared_user_data, payloads, true_user_data


class MaliciousModelServer(HonestServer):
    """Implement a malicious server protocol.

    This server is now also able to modify the model maliciously, before sending out payloads.
    Architectural changes (via self.vet_model) are triggered before instantiation of user objects.
    These architectural changes can also be understood as a 'malicious analyst' and happen first.
    """

    THREAT = "Malicious (Analyst)"

    def __init__(self, model, loss, cfg_case, setup=None, external_dataloader=None):
        """Initialize the server settings.

        ``setup`` defaults to float32 on CPU; a fresh dict is built per instance to avoid a shared mutable default.
        """
        if setup is None:
            setup = dict(dtype=torch.float, device=torch.device("cpu"))
        super().__init__(model, loss, cfg_case, setup, external_dataloader)
        # NOTE(review): no method in this class reads self.model_state; the inherited distribute_payload
        # reconfigures with cfg_server.model_state. Confirm whether this flag is consumed elsewhere.
        self.model_state = "custom"  # Do not mess with model parameters no matter what init is agreed upon
        self.secrets = dict()

    def vet_model(self, model):
        """This server is not honest :>

        Insert the configured malicious block into ``self.model`` (the ``model`` argument is ignored),
        optionally rework the layers preceding it, normalize layer throughput and return the modified model.
        Information needed to invert the block is stored in ``self.secrets["ImprintBlock"]``.

        Raises:
            ValueError: if ``model_modification.type`` is not a known block type.
        """
        modification = self.cfg_server.model_modification
        if modification.type == "ImprintBlock":
            block_fn = ImprintBlock
        elif modification.type == "SparseImprintBlock":
            block_fn = SparseImprintBlock
        elif modification.type == "OneShotBlock":
            block_fn = OneShotBlock
        else:
            raise ValueError("Unknown modification")

        modified_model, secrets = self._place_malicious_block(self.model, block_fn, **modification)
        self.secrets["ImprintBlock"] = secrets

        if modification.position is not None:
            if modification.type == "SparseImprintBlock":
                block_fn = type(None)  # Linearize the full model for SparseImprint
            if modification.handle_preceding_layers == "identity":
                self._linearize_up_to_imprint(modified_model, block_fn)
            elif modification.handle_preceding_layers == "VAE":
                # Train preceding layers to be a VAE up to the target dimension
                modified_model, decoder = self.train_encoder_decoder(modified_model, block_fn)
                self.secrets["ImprintBlock"]["decoder"] = decoder
            else:
                # Otherwise do not modify the preceding layers. The attack then returns the layer input at this position directly
                pass

        # Reduce failures in later layers:
        # Note that this clashes with the VAE option!
        self._normalize_throughput(
            modified_model, gain=self.cfg_server.model_gain, trials=self.cfg_server.normalize_rounds
        )
        self.model = modified_model
        return self.model

    def _place_malicious_block(
        self, modified_model, block_fn, type, position=None, handle_preceding_layers=None, **kwargs
    ):
        """The block is placed directly before the named module. If none is given, the block is placed at the start.

        Args:
            modified_model: model to modify. With a ``position`` it is modified in place; without one it is
                wrapped in a new Sequential that keeps the original ``name`` attribute.
            block_fn: block constructor, called as ``block_fn(input_dim, **kwargs)``.
            type: modification type from the config; consumed here so it is not forwarded to ``block_fn``.
            position: substring of the module name before which the block is inserted, or None.
            handle_preceding_layers: consumed here so it is not forwarded to ``block_fn``.
            **kwargs: remaining configuration entries, forwarded to ``block_fn``.

        Returns:
            (modified_model, secrets) where secrets holds the indices (in ``modified_model.parameters()`` order)
            of the block's first linear weight and bias, the feature shape at the insertion point and the
            block's ``structure``.

        Raises:
            ValueError: if no module name contains ``position``, or the inserted block's linear parameters
                cannot be located in the modified model.
        """
        if position is None:
            input_dim = self.cfg_data.shape[0] * self.cfg_data.shape[1] * self.cfg_data.shape[2]
            block = block_fn(input_dim, **kwargs)
            original_name = modified_model.name
            modified_model = torch.nn.Sequential(
                torch.nn.Flatten(),
                block,
                torch.nn.Unflatten(dim=1, unflattened_size=tuple(self.cfg_data.shape)),
                modified_model,
            )
            modified_model.name = original_name
            secrets = dict(weight_idx=0, bias_idx=1, shape=tuple(self.cfg_data.shape), structure=block.structure)
        else:
            block_found = False
            for name, module in modified_model.named_modules():
                if position in name:  # give some leeway for additional containers.
                    feature_shapes = introspect_model(modified_model, tuple(self.cfg_data.shape))
                    data_shape = feature_shapes[name]["shape"][1:]
                    log.info(f"Block inserted at feature shape {data_shape}.")
                    module_to_be_modified = module
                    block_found = True
                    break

            if not block_found:
                raise ValueError(f"Could not find module {position} in model to insert layer.")
            input_dim = torch.prod(torch.as_tensor(data_shape))
            block = block_fn(input_dim, **kwargs)

            replacement = torch.nn.Sequential(
                torch.nn.Flatten(), block, torch.nn.Unflatten(dim=1, unflattened_size=data_shape), module_to_be_modified
            )
            replace_module_by_instance(modified_model, module_to_be_modified, replacement)
            # Locate the block's parameters by identity; fail loudly instead of hitting an unbound name.
            weight_idx, bias_idx = None, None
            for idx, param in enumerate(modified_model.parameters()):
                if param is block.linear0.weight:
                    weight_idx = idx
                if param is block.linear0.bias:
                    bias_idx = idx
            if weight_idx is None or bias_idx is None:
                raise ValueError("Could not locate the parameters of the inserted block in the modified model.")
            secrets = dict(weight_idx=weight_idx, bias_idx=bias_idx, shape=data_shape, structure=block.structure)

        return modified_model, secrets

    def _linearize_up_to_imprint(self, model, block_fn):
        """This linearization option only works for a ResNet architecture.

        Walks ``self.model`` in ``named_modules`` order until the first module of type ``block_fn`` and,
        before that point: sets BatchNorm2d to identity statistics/affine, sets the first Conv2d to a
        (channel-repeated) Dirac kernel and zeroes later Conv2d weights, re-initializes ``downsample.0``
        layers as repeated Dirac kernels, and replaces ReLUs in ``model`` with Identity.
        """
        first_conv_set = False  # todo: make this nice
        for name, module in self.model.named_modules():
            if isinstance(module, block_fn):
                break
            with torch.no_grad():
                if isinstance(module, torch.nn.BatchNorm2d):
                    torch.nn.init.ones_(module.running_var)
                    torch.nn.init.ones_(module.weight)
                    torch.nn.init.zeros_(module.running_mean)
                    torch.nn.init.zeros_(module.bias)
                if isinstance(module, torch.nn.Conv2d):
                    if not first_conv_set:
                        torch.nn.init.dirac_(module.weight)
                        # Repeat the 3-channel identity kernel across groups of output channels.
                        num_groups = module.out_channels // 3
                        module.weight.data[: num_groups * 3] = torch.cat(
                            [module.weight.data[:3, :3, :, :]] * num_groups
                        )
                        first_conv_set = True
                    else:
                        torch.nn.init.zeros_(module.weight)  # this is the resnet rule
                # Runs after the Conv2d branch above, so downsample convs end up as repeated Dirac kernels.
                if "downsample.0" in name:
                    torch.nn.init.dirac_(module.weight)
                    num_groups = module.out_channels // module.in_channels
                    concat = torch.cat(
                        [module.weight.data[: module.in_channels, : module.in_channels, :, :]] * num_groups
                    )
                    module.weight.data[: num_groups * module.in_channels] = concat
                if isinstance(module, torch.nn.ReLU):
                    replace_module_by_instance(model, module, torch.nn.Identity())

    @torch.inference_mode()
    def _normalize_throughput(self, model, gain=1, trials=1, bn_modeset=False):
        """Reset throughput to be within standard mean and gain-times standard deviation.

        Args:
            model: model whose Conv2d (with bias) and affine BatchNorm2d layers are rescaled in place so that
                their output has approximately zero mean and standard deviation ``gain``.
            gain: target standard deviation of each normalized layer output.
            trials: number of normalization rounds; 0 skips normalization entirely.
            bn_modeset: if True, each round instead runs one forward pass in train mode (which updates
                BatchNorm running statistics) and switches back to eval mode.
        """
        features = dict()

        def named_hook(name):
            def hook_fn(module, input, output):
                features[name] = output

            return hook_fn

        def sample_input():
            # One batch of public data if available, otherwise standard-normal noise of the data shape.
            if self.external_dataloader is not None:
                return next(iter(self.external_dataloader))[0].to(**self.setup)
            return torch.randn(self.cfg_data.batch_size, *self.cfg_data.shape, **self.setup)

        if trials > 0:
            log.info(f"Normalizing model throughput with gain {gain}...")
            model.to(**self.setup)
        for round_idx in range(trials):
            if bn_modeset:
                model.train()
                model(sample_input())
                model.eval()
                continue
            for name, module in model.named_modules():
                if not isinstance(module, (torch.nn.Conv2d, torch.nn.BatchNorm2d)):
                    continue
                if isinstance(module, torch.nn.Conv2d) and module.bias is None:
                    if "downsample.0" in name:
                        module.weight.data.zero_()
                        log.info(f"Reset weight in downsample {name} to zero.")
                    continue
                if "downsample.1" in name:
                    continue
                if module.weight is None or module.bias is None:
                    continue  # non-affine BatchNorm: no parameters to rescale

                hook = module.register_forward_hook(named_hook(name))
                try:
                    model(sample_input())
                finally:
                    hook.remove()  # never leave a stale hook behind, even if the forward pass fails
                std, mu = torch.std_mean(features.pop(name))
                log.info(f"Current mean of layer {name} is {mu.item()}, std is {std.item()} in round {round_idx}.")

                module.weight.data /= std / gain + 1e-8
                module.bias.data -= mu / (std / gain + 1e-8)
        # Free up GPU:
        model.to(device=torch.device("cpu"))

    def train_encoder_decoder(self, modified_model, block_fn):
        """Train a compressed code (with VAE) that will then be found by the attacker.

        The encoder is every child of ``modified_model`` except the last, followed by a Flatten; the decoder
        comes from ``generate_decoder``. ``block_fn`` is currently unused.

        Returns:
            (modified_model, decoder)

        Raises:
            ValueError: if the server has no external dataloader.
        """
        if self.external_dataloader is None:
            raise ValueError("External data is necessary to train an optimal encoder/decoder structure.")

        # Unroll model up to imprint block
        # For now only the last position is allowed:
        layer_cake = list(modified_model.children())
        encoder = torch.nn.Sequential(*(layer_cake[:-1]), torch.nn.Flatten())
        decoder = generate_decoder(modified_model)
        log.info(encoder)
        log.info(decoder)
        # Module-level helper (not this method); trains encoder/decoder in place, returned stats are unused.
        train_encoder_decoder(encoder, decoder, self.external_dataloader, self.setup)
        return modified_model, decoder


class MaliciousTransformerServer(HonestServer):
    """Implement a malicious server protocol.

    This server cannot modify the 'honest' model architecture posed by an analyst,
    but may modify the model parameters freely.
    This variation is designed to leak token information from transformer models for language modelling.
    """

    THREAT = "Malicious (Parameters)"

    def __init__(self, model, loss, cfg_case, setup=None, external_dataloader=None):
        """Initialize the server settings.

        ``setup`` defaults to float32 on CPU; a fresh dict is built per instance to avoid a shared mutable default.
        """
        if setup is None:
            setup = dict(dtype=torch.float, device=torch.device("cpu"))
        super().__init__(model, loss, cfg_case, setup, external_dataloader)
        self.secrets = dict()

    def vet_model(self, model):
        """This server is not honest, but the model architecture stays unchanged: ``self.model`` is returned."""
        return self.model

    def reconfigure_model(self, model_state, query_id=0):
        """Reinitialize, continue training or otherwise modify model parameters.

        Applies the benign ``model_state`` first, then rewrites embedding, attention and linear parameters
        (via the analytic transformer helpers) and records in ``self.secrets["ImprintBlock"]`` the parameter
        indices of the imprinted linear layers plus ``data_shape``, ``v_length`` and ``ff_transposed``.
        """
        super().reconfigure_model(model_state, query_id)  # Load the benign model state first

        # Figure out the names of all layers by lookup:
        # For now this is non-automated. Add a new arch to this lookup function before running it.
        lookup = lookup_module_names(self.model.name, self.model)
        hidden_dim, embedding_dim, ff_transposed = lookup["dimensions"]

        # Define "probe" function / measurement vector:
        # Probe Length is embedding_dim minus v_proportion minus skip node
        measurement_scale = self.cfg_server.param_modification.measurement_scale
        v_length = self.cfg_server.param_modification.v_length
        probe_dim = embedding_dim - v_length - 1
        weights = torch.randn(probe_dim, **self.setup)
        probe_std, probe_mu = torch.std_mean(weights)  # correct sample toward perfect mean and std
        probe = (weights - probe_mu) / probe_std / torch.as_tensor(probe_dim, **self.setup).sqrt() * measurement_scale

        # The probe occupies entries [v_length, embedding_dim - 1); the first v_length and the last entry stay 0.
        measurement = torch.zeros(embedding_dim, **self.setup)
        measurement[v_length:-1] = probe

        # Reset the embedding?:
        if self.cfg_server.param_modification.reset_embedding:
            lookup["embedding"].reset_parameters()
        # Disable these parts of the embedding:
        partially_disable_embedding(lookup["embedding"], v_length)
        if hasattr(lookup["pos_encoder"], "embedding"):
            partially_disable_embedding(lookup["pos_encoder"].embedding, v_length)
            partially_norm_position(lookup["pos_encoder"].embedding, v_length)

        # Modify the first attention mechanism in the model:
        # Set QKV modifications in-place:
        set_MHA(
            lookup["first_attention"],
            lookup["norm_layer0"],
            lookup["pos_encoder"],
            embedding_dim,
            ff_transposed,
            self.cfg_data.shape,
            sequence_token_weight=self.cfg_server.param_modification.sequence_token_weight,
            imprint_sentence_position=self.cfg_server.param_modification.imprint_sentence_position,
            softmax_skew=self.cfg_server.param_modification.softmax_skew,
            v_length=v_length,
        )

        # Take care of second linear layers, and unused mha layers first
        set_flow_backward_layer(
            lookup["second_linear_layers"], ff_transposed=ff_transposed, eps=self.cfg_server.param_modification.eps
        )
        disable_mha_layers(lookup["unused_mha_outs"])

        if self.cfg_data.task == "masked-lm" and not self.cfg_data.disable_mlm:
            equalize_mha_layer(
                lookup["last_attention"],
                ff_transposed,
                equalize_token_weight=self.cfg_server.param_modification.equalize_token_weight,
                v_length=v_length,
            )
        else:
            # Zero the output projection of the last attention layer (layout depends on the lookup "mode").
            if lookup["last_attention"]["mode"] == "bert":
                lookup["last_attention"]["output"].weight.data.zero_()
                lookup["last_attention"]["output"].bias.data.zero_()
            else:
                lookup["last_attention"]["out_proj_weight"].data.zero_()
                lookup["last_attention"]["out_proj_bias"].data.zero_()

        # Evaluate feature distribution of this model
        std, mu = compute_feature_distribution(self.model, lookup["first_linear_layers"][0], measurement, self)
        # And add imprint modification to the first linear layer
        make_imprint_layer(
            lookup["first_linear_layers"], measurement, mu, std, hidden_dim, embedding_dim, ff_transposed
        )
        # This should be all for the attack :>

        # We save secrets for the attack later on: the index (in self.model.parameters() order) of every
        # imprinted layer's weight and bias, in layer order. Looking them up by identity does not depend on
        # parameter ordering and does not stall when a layer has no bias.
        param_index = {id(param): idx for idx, param in enumerate(self.model.parameters())}
        imprint_layers = lookup["first_linear_layers"]
        weight_idx = [param_index[id(layer.weight)] for layer in imprint_layers]
        bias_idx = [param_index[id(layer.bias)] for layer in imprint_layers if layer.bias is not None]

        details = dict(
            weight_idx=weight_idx,
            bias_idx=bias_idx,
            data_shape=self.cfg_data.shape,
            structure="cumulative",
            v_length=v_length,
            ff_transposed=ff_transposed,
        )
        self.secrets["ImprintBlock"] = details


class MaliciousParameterOptimizationServer(HonestServer):
    """Implement a malicious server protocol.

    This server cannot modify the 'honest' model architecture posed by an analyst,
    but may modify the model parameters freely."""

    THREAT = "Malicious (Parameters)"

    def __init__(self, model, loss, cfg_case, setup=None, external_dataloader=None):
        """Initialize the server settings.

        If ``cfg_case.server.param_modification`` contains an "optimization" entry, a ``RecoveryOptimizer`` is
        built and its configured ``layers`` are recorded in ``self.secrets["layers"]``; otherwise
        ``self.parameter_algorithm`` stays None. ``setup`` defaults to float32 on CPU (fresh dict per instance).
        """
        if setup is None:
            setup = dict(dtype=torch.float, device=torch.device("cpu"))
        super().__init__(model, loss, cfg_case, setup, external_dataloader)
        self.secrets = dict()
        self.parameter_algorithm = None

        if "optimization" in cfg_case.server.param_modification.keys():
            self.parameter_algorithm = RecoveryOptimizer(
                self.model,
                self.loss,
                self.cfg_data,
                cfg_case.impl,
                cfg_optim=cfg_case.server.param_modification["optimization"],
                setup=setup,
                external_data=external_dataloader,
            )
            self.secrets["layers"] = cfg_case.server.param_modification.optimization.layers

    def vet_model(self, model):
        """This server is not honest, but the model architecture stays normal: ``self.model`` is returned."""
        return self.model

    def reconfigure_model(self, model_state, query_id=0):
        """Reinitialize, continue training or otherwise modify model parameters.

        Applies the benign ``model_state`` first, then runs the recovery optimizer.

        Raises:
            ValueError: if no "optimization" entry was configured, so there is no optimizer to run.
        """
        super().reconfigure_model(model_state, query_id)  # Load the benign model state first

        if self.parameter_algorithm is None:
            raise ValueError(
                "No parameter optimization configured: cfg_case.server.param_modification has no 'optimization' entry."
            )
        # Then do fun things:
        self.parameter_algorithm.optimize_recovery()

