from enum import Enum
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class HyperNetType(Enum):
    """Kind of network a hypernetwork is built for.

    NOTE(review): presumably this is what callers pass as ``hypernet_type``;
    the hypernetwork classes in this file accept that argument but never
    read it.
    """

    ACTOR = 0  # actor / policy network
    CRITIC = 1  # critic / value network

###############################################################
# Minimal HyperNetwork implementations (linear and MLP) and init helpers (Unchanged)
###############################################################
def orthogonal_init(m, gain=1.0):
    """Orthogonally initialize an ``nn.Linear`` weight and zero its bias.

    Modules that are not ``nn.Linear`` are left untouched, and a Linear
    built with ``bias=False`` only has its weight initialized.

    Args:
        m: module to initialize (e.g. passed through ``Module.apply``).
        gain: scaling factor forwarded to ``nn.init.orthogonal_``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.orthogonal_(m.weight, gain=gain)
    bias = m.bias
    if bias is not None:
        nn.init.zeros_(bias)

import torch.nn as nn
import torch

def orthogonal_init_custom(m, gain=1.0, custom_fan_in=None, custom_fan_out=None):
    """
    Orthogonal initialization with custom fan-in and fan-out.

    Intended for a hypernetwork head ``nn.Linear(in_features, fan_in * fan_out)``
    whose output is later viewed as a ``[fan_in, fan_out]`` weight matrix.
    Column ``k`` of ``m.weight`` is the flattened matrix the head emits for the
    basis input ``e_k``. Each such column is reshaped to
    ``[custom_fan_in, custom_fan_out]``, (semi-)orthogonally initialized with
    ``gain``, and written back into the same column.

    If either custom fan is ``None``, the whole weight gets a plain
    ``nn.init.orthogonal_``. The bias is left untouched, and modules that are
    not ``nn.Linear`` are ignored.

    Args:
        m: module to initialize.
        gain: scaling factor forwarded to ``nn.init.orthogonal_``.
        custom_fan_in: number of rows of each generated weight block.
        custom_fan_out: number of columns of each generated weight block.

    Raises:
        RuntimeError: (from ``reshape``) if the weight cannot be split into
            ``[custom_fan_in, custom_fan_out]`` blocks.
    """
    if not isinstance(m, nn.Linear):
        return
    if custom_fan_in is None or custom_fan_out is None:
        nn.init.orthogonal_(m.weight, gain=gain)
        return

    with torch.no_grad():
        # [in_features, out_features]: row k is the flattened output for input e_k.
        weights_t = m.weight.T
        # clone() guarantees we work on separate memory even if reshape returns a view.
        blocks = weights_t.reshape(-1, custom_fan_in, custom_fan_out).clone()
        for i in range(blocks.size(0)):
            nn.init.orthogonal_(blocks[i], gain=gain)
        # Undo BOTH the reshape and the transpose; a bare reshape_as(m.weight)
        # would scatter each block's entries across unrelated columns.
        m.weight.copy_(blocks.reshape(weights_t.shape).T)


class LinearHypernetwork(nn.Module):
    """
    Single-linear-layer hypernetwork.

    Given a list of (in_dim, out_dim) pairs, this network produces a flattened
    weight matrix W and a bias b for every target layer from one conditioning
    row of its input (one row per unique agent when called per agent),
    instead of per sample.

    For each target layer it owns:
      * a weight head ``nn.Linear(hypernet_input_dim, in_dim * out_dim)``,
        block-wise orthogonally initialized (``orthogonal_init_custom``)
        with zero bias;
      * a bias head ``nn.Linear(hypernet_input_dim, out_dim)`` with zero
        weight and bias, so every generated bias starts at exactly zero.
    Each head is wrapped in a one-element ``nn.Sequential`` stored in
    ``weight_mlps`` / ``bias_mlps``.

    Args:
        output_dims: list of (in_dim, out_dim) for the target layers.
        hypernet_input_dim: 1-tuple ``(conditioning_size,)``.
        hypernet_type: accepted for interface compatibility; unused.
        hypernet_hidden_dims: only ``[0]`` is stored; no hidden layer is built.
        init_scale: gain for the weight-head orthogonal init.
        use_bias: whether the heads have bias terms.
    """

    def __init__(
        self,
        output_dims,           # list of (in_dim, out_dim)
        hypernet_input_dim,
        hypernet_type,
        hypernet_hidden_dims=(64,),
        init_scale=math.sqrt(2),
        use_bias=True,
    ):
        super().__init__()
        self.output_dims = output_dims

        # e.g. hypernet_input_dim=(embedding_size,), or (num_agents,) if identity
        self.hypernet_input_dim = hypernet_input_dim[0]
        self.hypernet_hidden_dim = hypernet_hidden_dims[0]
        self.init_scale = init_scale
        self.use_bias = use_bias

        # The heads live inside the *_mlps Sequentials; weight_heads / bias_heads
        # stay empty and are kept only so existing attribute accesses still work.
        self.weight_mlps = nn.ModuleList()
        self.weight_heads = nn.ModuleList()
        self.bias_mlps = nn.ModuleList()
        self.bias_heads = nn.ModuleList()

        for (in_dim, out_dim) in output_dims:
            # Weight head: emits the flattened [in_dim, out_dim] matrix for each input row.
            w_head = nn.Linear(self.hypernet_input_dim, in_dim * out_dim, bias=self.use_bias)
            # Block-wise orthogonal init with fans matching the generated matrix.
            orthogonal_init_custom(w_head, gain=self.init_scale, custom_fan_in=in_dim, custom_fan_out=out_dim)
            if w_head.bias is not None:  # None when use_bias=False
                nn.init.zeros_(w_head.bias)

            # Bias head: all-zero parameters => generated biases start at zero.
            b_head = nn.Linear(self.hypernet_input_dim, out_dim, bias=self.use_bias)
            if b_head.bias is not None:
                nn.init.zeros_(b_head.bias)
            nn.init.zeros_(b_head.weight)

            self.weight_mlps.append(nn.Sequential(w_head))
            self.bias_mlps.append(nn.Sequential(b_head))

    def forward(self, agent_embeds):
        """
        Generate the target-layer parameters for every row of ``agent_embeds``.

        Args:
            agent_embeds: [num_unique_agents_in_batch, hypernet_input_dim]

        Returns:
            weight_outs: list (one per target layer) of
                [num_unique_agents_in_batch, in_dim * out_dim]
            bias_outs: list (one per target layer) of
                [num_unique_agents_in_batch, out_dim]
        """
        weight_outs = [w_mlp(agent_embeds) for w_mlp in self.weight_mlps]
        bias_outs = [b_mlp(agent_embeds) for b_mlp in self.bias_mlps]
        return weight_outs, bias_outs
  
from torch.nn.init import  _calculate_fan_in_and_fan_out
def lecun_normal(tensor):
    """Fill *tensor* in place from N(0, 1 / fan_in) (LeCun normal).

    A ``None`` tensor is ignored. Returns ``None`` either way.
    """
    if tensor is None:
        return
    fan_in = _calculate_fan_in_and_fan_out(tensor)[0]
    std = math.sqrt((1.0 / fan_in))
    nn.init.normal_(tensor, mean=0, std=std)
        
# init for hypermarl hidden layers
def weights_init(m):
    """LeCun-normal weight and zero bias for ``nn.Linear`` modules.

    Modules that are not ``nn.Linear`` are ignored. A Linear created with
    ``bias=False`` (``m.bias is None``) only has its weight initialized,
    matching ``orthogonal_init``; previously this raised on ``zeros_(None)``.
    """
    if isinstance(m, nn.Linear):
        lecun_normal(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
          
class MLPHyperNetwork(nn.Module):
    """
    Hypernetwork with one hidden layer per generated tensor.

    Given a list of (in_dim, out_dim) pairs, this network produces a flattened
    weight matrix W and a bias b for every target layer from one conditioning
    row of its input (one row per unique agent when called per agent),
    instead of per sample.

    For each target layer it owns two ``nn.Sequential`` sub-networks:
      * weight: Linear(input, hidden) -> ReLU -> Linear(hidden, in_dim*out_dim).
        The hidden layer is orthogonal (gain ``init_scale``) with zero bias.
        The head is block-wise orthogonal (``orthogonal_init_custom``) with
        zero bias.
      * bias: Linear(input, hidden) -> ReLU -> Linear(hidden, out_dim).
        The hidden layer is orthogonal with zero bias. The head has zero
        weight and bias, so every generated bias starts at exactly zero.

    Args:
        output_dims: list of (in_dim, out_dim) for the target layers.
        hypernet_input_dim: 1-tuple ``(conditioning_size,)``.
        hypernet_type: accepted for interface compatibility; unused.
        hypernet_hidden_dims: only the first entry is used as the hidden width.
        init_scale: gain for the orthogonal inits.
        use_bias: whether the hypernetwork's Linear layers have bias terms.
    """

    def __init__(
        self,
        output_dims,           # list of (in_dim, out_dim)
        hypernet_input_dim,
        hypernet_type,
        hypernet_hidden_dims=(64,),
        init_scale=math.sqrt(2),
        use_bias=True,
    ):
        super().__init__()
        self.output_dims = output_dims

        # e.g. hypernet_input_dim=(embedding_size,), or (num_agents,) if identity
        self.hypernet_input_dim = hypernet_input_dim[0]
        # NOTE: only the first hidden size is used; further entries are ignored.
        self.hypernet_hidden_dim = hypernet_hidden_dims[0]
        self.init_scale = init_scale
        self.use_bias = use_bias

        # The heads live inside the *_mlps Sequentials; weight_heads / bias_heads
        # stay empty and are kept only so existing attribute accesses still work.
        self.weight_mlps = nn.ModuleList()
        self.weight_heads = nn.ModuleList()
        self.bias_mlps = nn.ModuleList()
        self.bias_heads = nn.ModuleList()

        for (in_dim, out_dim) in output_dims:
            # ---- weight sub-network ----
            w_layer = nn.Linear(self.hypernet_input_dim, self.hypernet_hidden_dim, bias=self.use_bias)
            # Orthogonal hidden init (an earlier variant used weights_init / LeCun normal).
            nn.init.orthogonal_(w_layer.weight, gain=self.init_scale)
            if w_layer.bias is not None:  # None when use_bias=False
                nn.init.zeros_(w_layer.bias)

            # Head emits the flattened [in_dim, out_dim] matrix; block-wise
            # orthogonal init with fans matching that matrix.
            w_head = nn.Linear(self.hypernet_hidden_dim, in_dim * out_dim, bias=self.use_bias)
            orthogonal_init_custom(w_head, gain=self.init_scale, custom_fan_in=in_dim, custom_fan_out=out_dim)
            if w_head.bias is not None:
                nn.init.zeros_(w_head.bias)

            # ---- bias sub-network ----
            b_layer = nn.Linear(self.hypernet_input_dim, self.hypernet_hidden_dim, bias=self.use_bias)
            nn.init.orthogonal_(b_layer.weight, gain=self.init_scale)
            if b_layer.bias is not None:
                nn.init.zeros_(b_layer.bias)

            # All-zero head => generated biases start at zero.
            b_head = nn.Linear(self.hypernet_hidden_dim, out_dim, bias=self.use_bias)
            if b_head.bias is not None:
                nn.init.zeros_(b_head.bias)
            nn.init.zeros_(b_head.weight)

            self.weight_mlps.append(nn.Sequential(w_layer, nn.ReLU(), w_head))
            self.bias_mlps.append(nn.Sequential(b_layer, nn.ReLU(), b_head))

    def forward(self, agent_embeds):
        """
        Generate the target-layer parameters for every row of ``agent_embeds``.

        Args:
            agent_embeds: [num_unique_agents_in_batch, hypernet_input_dim]

        Returns:
            weight_outs: list (one per target layer) of
                [num_unique_agents_in_batch, in_dim * out_dim]
            bias_outs: list (one per target layer) of
                [num_unique_agents_in_batch, out_dim]
        """
        weight_outs = [w_mlp(agent_embeds) for w_mlp in self.weight_mlps]
        bias_outs = [b_mlp(agent_embeds) for b_mlp in self.bias_mlps]
        return weight_outs, bias_outs

###############################################################
# Helper functions (Unchanged)
###############################################################
def get_active_func(name):
    """Return a fresh activation module for *name* (case-insensitive).

    Recognized names are "relu", "tanh" and "identity" (a no-op, e.g. when a
    last layer should have no activation). Any other name falls back to ReLU.
    """
    factories = {
        "relu": nn.ReLU,
        "tanh": nn.Tanh,
        "identity": nn.Identity,
    }
    return factories.get(name.lower(), nn.ReLU)()

def get_init_method(name):
    """Placeholder lookup: no initialization method is resolved.

    *name* is accepted for interface compatibility and ignored; the result
    is always ``None``.
    """
    return None

class EmbeddingOrIdentity(nn.Module):
    """
    Map integer agent ids to the conditioning vectors fed to a hypernetwork.

    * use_agent_id_embeddings=True: a learnable
      ``nn.Embedding(num_agents, embedding_size)`` whose weight is
      orthogonally initialized with gain sqrt(2) (the ReLU gain).
    * use_agent_id_embeddings=False: rows of a fixed
      ``num_agents x num_agents`` identity matrix registered as a buffer
      (i.e. one-hot codes; ``embedding_size`` is ignored).
    """

    def __init__(self, num_agents: int, embedding_size: int, use_agent_id_embeddings: bool):
        super().__init__()
        self.use_agent_id_embeddings = use_agent_id_embeddings

        if not use_agent_id_embeddings:
            # Buffer, not a parameter: follows .to(device) but is never trained.
            self.register_buffer("identity", torch.eye(num_agents))
            return

        self.embedding = nn.Embedding(num_agents, embedding_size)
        with torch.no_grad():
            nn.init.orthogonal_(self.embedding.weight, gain=math.sqrt(2.0))

    def forward(self, agent_id: torch.Tensor) -> torch.Tensor:
        """
        Look up the code for each id.

        agent_id: [batch_size], values in [0..num_agents-1] (cast to long).

        Returns:
          - [batch_size, embedding_size] with learnable embeddings
          - [batch_size, num_agents]     with the identity table
        """
        idx = agent_id.long()
        if not self.use_agent_id_embeddings:
            return self.identity[idx]
        return self.embedding(idx)

###############################################################
# Replacement MLPLayer that uses HyperNetwork + agent embeddings
###############################################################
class HyperMLPLayer(nn.Module):
    """
    MLP layer stack whose weights are generated by a hypernetwork per agent.

    The input's last ``num_agents`` features are decoded (argmax) into an
    agent id. The id is encoded by ``EmbeddingOrIdentity`` and fed to a
    hypernetwork (``MLPHyperNetwork`` or ``LinearHypernetwork``), which
    generates one ``(W, b)`` pair per target layer. The remaining features
    then pass through the generated layers. Each layer computes
    ``act(x @ W + b)``, followed by its LayerNorm (or identity). ``act`` is
    ``final_activation_func`` on the last layer and ``activation_func``
    everywhere else, in both execution paths.
    """

    def __init__(
        self,
        input_dim,
        hidden_sizes,
        initialization_method,
        activation_func,
        num_agents,
        hypernet_type,
        embedding_size=8,
        use_agent_id_embeddings=True,
        hypernet_hidden_dims=(64,),
        generate_per_agent=True,  # <-- Memory-friendly flag
        gain=None,
        use_layer_norm=False,
        use_mlp_hypernet=True,
        final_activation_func=None,
        generates_final_layer=False,
    ):
        """
        Args:
          input_dim (int): input size of the first generated layer, i.e. the
            observation features WITHOUT the trailing agent-id block.
          hidden_sizes (list[int]): output size of each generated layer; the
            last entry is the output size of ``forward``.
          initialization_method: accepted but unused.
          activation_func (str): activation for all but the last layer
            (see ``get_active_func``); also picks the default init gain.
          num_agents (int): size of the trailing agent-id block of the input.
          hypernet_type: forwarded to the hypernetwork.
          embedding_size (int): learnable id-embedding size (ignored when
            ``use_agent_id_embeddings`` is False and one-hot codes are used).
          use_agent_id_embeddings (bool): learnable embeddings vs. one-hot ids.
          hypernet_hidden_dims: hidden sizes, passed to MLPHyperNetwork only.
          generate_per_agent (bool): If True, generate weights once per unique
            agent in the batch and apply them to that agent's rows. Otherwise,
            generate a full weight matrix for every sample.
          gain (float | None): hypernetwork init scale. Defaults to
            ``nn.init.calculate_gain`` for ``activation_func``.
          use_layer_norm (bool): LayerNorm after each generated layer.
          use_mlp_hypernet (bool): MLPHyperNetwork if True, else LinearHypernetwork.
          final_activation_func (str | None): activation for the last layer;
            defaults to ``activation_func``.
          generates_final_layer (bool): if True, the last layer has no LayerNorm.
        """
        super(HyperMLPLayer, self).__init__()

        self.activation_func = get_active_func(activation_func)

        if final_activation_func is not None:
            self.final_activation_func = get_active_func(final_activation_func)
        else:
            self.final_activation_func = self.activation_func

        if gain is None:
            # get_active_func is case-insensitive and knows "identity", but
            # calculate_gain only accepts lower-case names and calls the no-op
            # "linear". Normalize so those names no longer raise here.
            nonlinearity = activation_func.lower()
            if nonlinearity == "identity":
                nonlinearity = "linear"
            self.gain = nn.init.calculate_gain(nonlinearity)
        else:
            self.gain = gain

        self.agent_embeddings = EmbeddingOrIdentity(
            num_agents, embedding_size, use_agent_id_embeddings
        )

        self.num_agents = num_agents
        self.generate_per_agent = generate_per_agent
        self.use_layer_norm = use_layer_norm

        # (in_dim, out_dim) of each generated layer, chained from input_dim.
        self.output_dims = []
        prev_dim = input_dim
        for h in hidden_sizes:
            self.output_dims.append((prev_dim, h))
            prev_dim = h

        print("output dims",self.output_dims, hidden_sizes)

        # Hypernetwork that produces these layer parameters
        hnet_input_size = embedding_size if use_agent_id_embeddings else num_agents
        if use_mlp_hypernet:
            print("Using MLPHyperNetwork")
            self.hypernet = MLPHyperNetwork(
                output_dims=self.output_dims,
                hypernet_input_dim=(hnet_input_size,),
                hypernet_hidden_dims=hypernet_hidden_dims,
                init_scale=self.gain,
                use_bias=True,
                hypernet_type=hypernet_type,
            )
        else:
            print("Using LinearHyperNetwork")
            self.hypernet = LinearHypernetwork(
                output_dims=self.output_dims,
                hypernet_input_dim=(hnet_input_size,),
                init_scale=self.gain,
                use_bias=True,
                hypernet_type=hypernet_type,
            )

        # A LayerNorm (or identity) per generated layer
        if self.use_layer_norm:
            self.layernorms = nn.ModuleList(
                [nn.LayerNorm(outdim) for (_, outdim) in self.output_dims]
            )
        else:
            self.layernorms = nn.ModuleList([nn.Identity() for _ in self.output_dims])

        if generates_final_layer:
            # no layernorm for final layer
            self.layernorms[-1] = nn.Identity()

    def _activate(self, layer_idx, y):
        """Apply layer ``layer_idx``'s activation, then its LayerNorm/identity."""
        if layer_idx < len(self.output_dims) - 1:
            y = self.activation_func(y)
        else:
            y = self.final_activation_func(y)
        return self.layernorms[layer_idx](y)

    def _forward_per_sample(self, x, agent_id):
        """Generate a weight set for every sample and apply it with batched matmuls."""
        agent_embed = self.agent_embeddings(agent_id)  # [batch_size, embed_dim]
        weight_heads, bias_heads = self.hypernet(agent_embed)
        # => each list has [batch_size, in_dim*out_dim] or [batch_size, out_dim]

        batch_size = x.size(0)
        out = x
        for i, ((in_dim, out_dim), w_flat, b) in enumerate(zip(self.output_dims, weight_heads, bias_heads)):
            W = w_flat.view(batch_size, in_dim, out_dim)
            out = torch.bmm(out.unsqueeze(1), W).squeeze(1) + b
            out = self._activate(i, out)
        return out

    def _forward_per_agent(self, x, agent_id):
        """Generate one weight set per *unique* agent and apply it to that agent's rows."""
        # unique_agents: [num_unique]; inv_idx: [batch_size] with
        # unique_agents[inv_idx] == agent_id.
        unique_agents, inv_idx = torch.unique(agent_id, return_inverse=True)

        agent_embed_unique = self.agent_embeddings(unique_agents)  # [num_unique, embed_dim]
        weight_heads_unique, bias_heads_unique = self.hypernet(agent_embed_unique)

        # Row masks per unique agent are the same for every layer, so build them
        # once. Each is non-empty by construction, so no `.any()` check is needed.
        masks = [inv_idx == k for k in range(unique_agents.size(0))]

        out = x  # [batch_size, in_dim]
        for i, ((in_dim, out_dim), w_flat_unique, b_all) in enumerate(zip(self.output_dims, weight_heads_unique, bias_heads_unique)):
            W_all = w_flat_unique.view(-1, in_dim, out_dim)  # [num_unique, in_dim, out_dim]
            new_out = torch.zeros(x.size(0), out_dim, device=x.device, dtype=x.dtype)
            for k, mask in enumerate(masks):
                y = torch.matmul(out[mask], W_all[k]) + b_all[k]  # [n_k, out_dim]
                new_out[mask] = self._activate(i, y)
            out = new_out
        return out

    def forward(self, x):
        """
        x: [batch_size, obs_dim + num_agents] (2-D). The last ``num_agents``
           entries are the agent-id block (assumed one-hot), decoded by argmax.

        Returns:
            [batch_size, hidden_sizes[-1]]
        """
        one_hot_ids = x[..., -self.num_agents:]
        # argmax already yields int64; no round-trip through x.dtype, which
        # could corrupt large ids in half precision.
        agent_id = torch.argmax(one_hot_ids, dim=-1)
        obs = x[..., :-self.num_agents]

        # assert ids are one hot (disabled):
        # assert (one_hot_ids.sum(dim=-1) == 1).all(), "Each agent ID should have exactly one '1'."
        # assert ((one_hot_ids == 0) | (one_hot_ids == 1)).all(), "Agent ID values must be either '0' or '1'."

        if self.generate_per_agent:
            return self._forward_per_agent(obs, agent_id)
        return self._forward_per_sample(obs, agent_id)

###############################################################
# The MLPBase that uses HyperMLPLayer
###############################################################
class MLPBase(nn.Module):
    """A MLP base module.

    Optionally LayerNorm-normalizes the observation features, then runs them,
    with the agent-id block re-attached, through a ``HyperMLPLayer`` whose
    weights are generated once per unique agent in the batch.
    """

    def __init__(self, args, obs_shape, hypernet_type, final_activation_func=None, use_layer_norm=True, generates_final_layer=False):
        """
        Args:
            args (dict): config. Reads "hypermarl", "hidden_sizes" and
                "num_agents", plus optional "use_feature_normalization",
                "initialization_method" and "activation_func".
            obs_shape: observation shape; ``obs_shape[0]`` includes the
                trailing ``num_agents`` agent-id features.
            hypernet_type: forwarded to ``HyperMLPLayer``.
            final_activation_func (str | None): last-layer activation name.
            use_layer_norm (bool): LayerNorm after each generated layer.
            generates_final_layer (bool): drop the LayerNorm on the last layer.
        """
        super(MLPBase, self).__init__()
        print("HYPERMARL!!!!!!!!!!!!!!!")
        print(args)
        hypermarl = args["hypermarl"]
        print(hypermarl)

        self.use_feature_normalization = args.get("use_feature_normalization",False)
        self.initialization_method = args.get("initialization_method","orthogonal")
        self.activation_func = args.get("activation_func","relu")
        self.hidden_sizes = args["hidden_sizes"]
        self.num_agents = args["num_agents"]

        # shape without ids
        obs_dim = obs_shape[0] - self.num_agents

        if self.use_feature_normalization:
            self.feature_norm = nn.LayerNorm(obs_dim)

        self.mlp = HyperMLPLayer(
            input_dim=obs_dim,
            hidden_sizes=self.hidden_sizes,
            initialization_method=self.initialization_method,
            activation_func=self.activation_func,
            num_agents=self.num_agents,
            embedding_size=hypermarl["AGENT_ID_EMBEDDING_DIM"],
            use_agent_id_embeddings=hypermarl["USE_AGENT_ID_EMBEDDINGS"],
            hypernet_hidden_dims=hypermarl["HYPERNET_HIDDEN_DIMS"],
            generate_per_agent=True,  # <--- Turn on the memory-friendly approach
            hypernet_type=hypernet_type,
            use_layer_norm=use_layer_norm,
            use_mlp_hypernet=hypermarl.get("USE_MLP_HYPERNET", True),
            final_activation_func=final_activation_func,
            generates_final_layer=generates_final_layer,
        )

    def forward(self, x):
        """
        x shape: [batch_size, obs_dim + num_agents].
                 The last 'num_agents' dims are the agent-id block (one-hot
                 or similar); HyperMLPLayer derives agent_id from them by argmax.
        """
        if not self.use_feature_normalization:
            # cat([obs, ids]) would just rebuild x, so pass it through as-is.
            return self.mlp(x)

        obs = self.feature_norm(x[..., :-self.num_agents])
        one_hot_ids = x[..., -self.num_agents:]
        return self.mlp(torch.cat([obs, one_hot_ids], dim=-1))