import math

import numpy as np
import torch
import torch.nn as nn
from torch.func import functional_call
from functorch import make_functional_with_buffers
from torch import vmap

def get_weight_chunk_dims(num_target_parameters: int, num_embeddings: int):
    """Return how many target weights each embedding must account for.

    The result is ``ceil(num_target_parameters / num_embeddings)``, i.e. the
    smallest chunk size such that ``num_embeddings`` chunks cover every
    target parameter.

    Args:
        num_target_parameters: Total number of parameters to be produced.
        num_embeddings: Number of chunks the parameters are split into.

    Returns:
        The chunk size as an ``int``.

    Raises:
        ZeroDivisionError: If ``num_embeddings`` is 0.
    """
    # The previous implementation also adjusted a local copy of
    # ``num_embeddings`` that was never returned or read afterwards; that
    # dead computation has been dropped without changing the result.
    return math.ceil(num_target_parameters / num_embeddings)


def count_params(module: nn.Module, input_shape=None, inputs=None):
    """Return the total number of scalar entries across ``module``'s parameters.

    Args:
        module: Module whose ``parameters()`` are counted.
        input_shape: Unused; kept so existing call sites keep working.
        inputs: Unused; kept so existing call sites keep working.

    Returns:
        The parameter count as a plain ``int`` (0 for a module without
        parameters).
    """
    # ``numel()`` is always an exact Python int, including for 0-dim
    # parameters, where ``np.prod`` of the empty size would yield float 1.0.
    return sum(p.numel() for p in module.parameters())


class _StatelessCall:
    """Run a module with externally supplied parameters and buffers.

    Instances are called as ``fn(params, buffers, *args, **kwargs)`` where
    ``params`` and ``buffers`` are sequences ordered like the names given at
    construction time.
    """

    def __init__(self, module, param_names, buffer_names):
        self.module = module
        self.param_names = param_names
        self.buffer_names = buffer_names

    def __call__(self, params, buffers, *args, **kwargs):
        # Pair every tensor with its qualified name so ``functional_call``
        # can substitute it for the module's own state during this call.
        state = dict(zip(self.param_names, params))
        state.update(zip(self.buffer_names, buffers))
        return functional_call(self.module, state, args, kwargs)


class FunctionalParamVectorWrapper(nn.Module):
    """
    This wraps a module so that it takes params in the forward pass

    The parameters are passed as one flat vector which is cut into
    consecutive slices, one per parameter in ``named_parameters()`` order,
    and each slice is viewed with that parameter's shape.

    Attributes:
        target_weight_shapes: Mapping of parameter name to its ``torch.Size``.
        named_params: One-element list holding the tuple of parameter
            tensors of the wrapped snapshot (used for their shapes).
        custom_buffers: Tuple of the snapshot's buffers, forwarded on every
            call.
        functional: One-element list holding the callable
            ``fn(params, buffers, *args, **kwargs)``.
    """

    def __init__(self, module: nn.Module):
        super(FunctionalParamVectorWrapper, self).__init__()
        import copy

        # Work on a private copy so later changes to ``module`` do not leak
        # into this wrapper.
        stateless = copy.deepcopy(module)
        param_dict = dict(stateless.named_parameters())
        buffer_dict = dict(stateless.named_buffers())

        self.target_weight_shapes = {k: param_dict[k].size() for k in param_dict}
        self.custom_buffers = tuple(buffer_dict.values())
        # Everything below is kept inside plain lists so nn.Module does not
        # register it: the wrapper itself must expose no parameters.
        self.named_params = [tuple(param_dict.values())]
        self.functional = [
            _StatelessCall(stateless, tuple(param_dict), tuple(buffer_dict))
        ]  # remove params from being counted

    def forward(self, param_vector: torch.Tensor, *args, **kwargs):
        """Run the wrapped module using weights taken from ``param_vector``.

        Args:
            param_vector: Flat tensor holding all parameter values back to
                back. Entries beyond the total parameter count are ignored;
                a vector that is too short makes ``view`` fail.
            *args: Positional inputs forwarded to the wrapped module.
            **kwargs: Keyword inputs forwarded to the wrapped module.

        Returns:
            Whatever the wrapped module's forward returns.
        """
        params = []
        start = 0
        for p in self.named_params[0]:
            # ``numel()`` is an exact int even for 0-dim parameters, so the
            # slice bounds are always valid indices.
            end = start + p.numel()
            params.append(param_vector[start:end].view(p.size()))
            start = end
        return self.functional[0](params, self.custom_buffers, *args, **kwargs)

def create_functional_target_network(target_network: nn.Module):
    """Return ``target_network`` wrapped in a ``FunctionalParamVectorWrapper``."""
    return FunctionalParamVectorWrapper(target_network)

class target_net(nn.Module):
    """Three stacked linear layers with LeakyReLU after the first two."""

    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        # Input and output width of each linear layer, in layer order.
        self.in_feats = [in_features] + [hidden_features] * 2
        self.out_feats = [hidden_features] * 2 + [out_features]

        # Layers are created in order so their registration order stays
        # linear1, linear2, linear3, leakyrelu.
        self.linear1 = nn.Linear(in_features, hidden_features)
        self.linear2 = nn.Linear(hidden_features, hidden_features)
        self.linear3 = nn.Linear(hidden_features, out_features)
        self.leakyrelu = nn.LeakyReLU()

    def forward(self, x):
        """Apply the two activated hidden layers, then the output layer."""
        hidden = x
        for layer in (self.linear1, self.linear2):
            hidden = self.leakyrelu(layer(hidden))
        return self.linear3(hidden)

    def get_in_dims(self):
        """Return the list of per-layer input widths."""
        return self.in_feats

    def get_out_dims(self):
        """Return the list of per-layer output widths."""
        return self.out_feats

    def get_submodules(self):
        """Return the three linear layers in forward order."""
        layers = [self.linear1, self.linear2, self.linear3]
        return layers

# NOTE: the triple-quoted block below is a bare string literal, so the example
# inside it is never executed. It sketches evaluating a wrapped target network
# for a batch of flat weight vectors with ``vmap``. A target_net(64, 256, 1)
# has 82,689 parameters, fewer than the 100000 entries per weight vector used
# in the example.
'''
target_network = target_net(in_features=64, hidden_features=256, out_features=1)  
functional_target_network = create_functional_target_network(
            target_network
)

generated_weights = torch.randn(10, 100000)  # Assuming 10 different weight vectors

# Create input data and reshape for vmap
inputs = torch.randn(10, 1, 64)  # Assuming 10 different input samples

# Define a batched version of the forward function
batched_forward = vmap(functional_target_network, in_dims=(0, 0))

# Perform batched forward pass
outputs = batched_forward(generated_weights, inputs)

print(outputs)
'''