from copy import deepcopy
import torch
import torch.nn as nn


def get_mainnet_weights(mnet, hnet, ques_emb=None, reparam=False):
    """Return the main network's weights as one tensor, without tracking gradients.

    Args:
        mnet: Object exposing ``get_parameter_vector()``; when ``reparam`` is
            true the call is made on ``mnet.net`` instead. Only used when
            ``hnet`` is ``None``.
        hnet: Optional object whose ``forward(uncond_input=...)`` returns an
            iterable of tensors. When given, those tensors are the weights
            and ``mnet`` / ``reparam`` are ignored.
        ques_emb: Passed to ``hnet.forward`` as ``uncond_input``.
        reparam: If true (and ``hnet`` is ``None``), unwrap ``mnet.net``
            before reading the parameter vector.

    Returns:
        With ``hnet``: every generated tensor flattened and concatenated into
        a single 1-D tensor. Without ``hnet``: whatever
        ``get_parameter_vector()`` returns.
    """
    with torch.no_grad():
        if hnet is None:
            # No hypernetwork: read the weights straight from the main network.
            source = mnet.net if reparam else mnet
            return source.get_parameter_vector()

        generated = hnet.forward(uncond_input=ques_emb)
        pieces = [W.flatten() for W in generated]
        return torch.cat(pieces).flatten()

class AlphaReparam(nn.Module):
    """Scale how far a network's output moves away from a frozen reference.

    ``forward`` returns ``alpha * (net(x) - start(x)) + start(x)`` where
    ``start`` is a reference network whose parameters have ``requires_grad``
    set to ``False`` and whose output is detached from the autograd graph.

    Args:
        net: The wrapped network. It must accept
            ``forward(image_features, text_features, weights=..., params=...)``
            and expose ``get_parameter_vector()``.
        start_net: Optional network to deep-copy as the reference. When
            ``None`` (and no ``custom_start_net`` is given) a deep copy of
            ``net`` is used.
        alpha: Scale applied to the difference between ``net`` and the
            reference output.
        custom_start_net: Optional reference network used as-is, without a
            deep copy. Takes precedence over ``start_net``. Its parameters
            are frozen in place, so the caller's object is modified.
    """

    def __init__(self, net, start_net=None, alpha=0.1, custom_start_net=None):
        super().__init__()
        self.net = net
        self.alpha = alpha
        if custom_start_net is not None:
            self.start_net = custom_start_net
        else:
            # Copy so that freezing the reference never touches the trainable
            # network (or the caller's ``start_net``).
            reference = net if start_net is None else start_net
            self.start_net = deepcopy(reference)
        for param in self.start_net.parameters():
            param.requires_grad = False

    def get_parameter_vector(self):
        """Return the wrapped (trainable) network's parameter vector."""
        return self.net.get_parameter_vector()

    def forward(self, image_features, text_features, weights=None, params=None):
        """Return the reference output plus ``alpha`` times the deviation from it.

        ``weights`` is forwarded to the wrapped network only; ``params`` is
        forwarded to both the wrapped and the reference network.
        """
        net_out = self.net.forward(image_features, text_features, weights=weights, params=params)
        # Evaluate the reference once and reuse it: both occurrences in the
        # formula must be the same tensor, and a second pass is wasted work.
        start_out = self.start_net.forward(image_features, text_features, params=params).detach()
        return self.alpha * (net_out - start_out) + start_out

def alpha_reparam(config, mnet, hnet):
    """Wrap ``mnet`` in an ``AlphaReparam`` and pick the top-level model.

    Args:
        config: Mapping providing ``"weight_change_reparam_alpha"`` (used as
            the wrapper's ``alpha``) and ``"meta_process"``.
        mnet: Main network to wrap.
        hnet: Hypernetwork; returned unchanged.

    Returns:
        ``(wrapped_mnet, hnet, full_model)`` where ``full_model`` is ``hnet``
        when ``config["meta_process"] == "hnet"`` and the wrapped main
        network otherwise.
    """
    wrapped = AlphaReparam(mnet, alpha=config["weight_change_reparam_alpha"])
    full_model = hnet if config["meta_process"] == "hnet" else wrapped
    return wrapped, hnet, full_model

# Config fields that are restored from a checkpoint when the
# load-from-checkpoint argument is used. Every other field is taken from the
# normal config file.
checkpoint_config_fields = [
    # general / meta-process keys
    "hypnettorch",
    "meta_process",
    # weight_change_reparam* keys
    "weight_change_reparam",
    "weight_change_reparam_alpha",
    # hypernet_* and mainnet_* keys
    "hypernet_type",
    "mainnet_hidden_dim",
    "mainnet_use_bias",
    "hypernet_hidden_dim",
    "hypernet_init",
    "hypernet_chunk_size",
    # hyperclip_* keys
    "hyperclip_model",
    "hyperclip_hidden_dim",
    "hyperclip_batch_size",
    # clip key
    "clip_model",
]

# NOTE(review): the name suggests these are the checkpoint-restored config
# keys for the GAN setup — confirm against the checkpoint loader.
checkpoint_config_fields_gan = [
    # checkpoint / sampling keys
    "load_checkpoint_run_id",
    "sample_new_weight",
    # weight_change_reparam* keys
    "weight_change_reparam",
    "weight_change_reparam_alpha",
    # inner-loop schedule keys
    "stop_loss",
    "inner_epochs",
    "inner_batch_size",
    # GAN / hypergan structure keys
    "gan_algo",
    "num_d_update",
    "hypergan_batchsize",
    "hypernet_hidden_dim",
    "hypergan_noise_dim",
    # gan_* optimizer keys
    "gan_optimizer",
    "gan_learning_rate",
    "gan_weight_decay",
    "gan_adam_beta1",
    "gan_adam_beta2",
    "gan_momentum",
    "gan_sgd_dampening",
    "gan_sgd_nesterov",
    "gan_rmsprop_alpha",
    # mainnet_* keys
    "mainnet_hidden_dim",
    "mainnet_use_bias",
    # inner_* optimizer keys
    "inner_optimizer",
    "inner_learning_rate",
    "inner_weight_decay",
    "inner_adam_beta1",
    "inner_adam_beta2",
    "inner_momentum",
    "inner_sgd_dampening",
    "inner_sgd_nesterov",
    "inner_rmsprop_alpha",
    # clip key
    "clip_model",
]

# NOTE(review): the name suggests these are the checkpoint-restored config
# keys for the hypernetwork setup — confirm against the checkpoint loader.
checkpoint_config_fields_hnet = [
    # general / meta-process keys
    "hypnettorch",
    "meta_process",
    # weight_change_reparam* keys
    "weight_change_reparam",
    "weight_change_reparam_alpha",
    # hypernet_* and mainnet_* keys
    "hypernet_type",
    "mainnet_hidden_dim",
    "mainnet_use_bias",
    "hypernet_hidden_dim",
    "hypernet_init",
    "hypernet_chunk_size",
    # clip key
    "clip_model",
]

# NOTE(review): the name suggests these are the checkpoint-restored config
# keys for the hyperclip setup — confirm against the checkpoint loader.
checkpoint_config_fields_hyperclip = [
    # hyperclip_* keys
    "hyperclip_model",
    "hyperclip_hidden_dim",
    "hyperclip_batch_size",
    "train_hyperclip_on",
    # clip key
    "clip_model",
]