import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from mushroom_rl.policy.policy import Policy

from mushroom_rl_extensions.agents.abstract_setup import AbstractSetup
from mushroom_rl_extensions.core.multi_player_agent import MultiPlayerAgent
from mushroom_rl_extensions.policy.constant import ConstantPolicy

def freeze_adversary_policy(policy):
    """
    Turn off gradient tracking for every parameter of ``policy``.

    Each object yielded by ``policy.parameters()`` has its ``requires_grad``
    attribute set to ``False``. The policy is modified in place and nothing
    is returned.
    """
    for weight in policy.parameters():
        weight.requires_grad = False


def constrain_l2_norm(policy, max_l2_norm=1.0):
    """
    Rescale the trainable parameters of ``policy`` so that their joint L2
    norm is at most ``max_l2_norm``.

    The norm is computed over all parameters with ``requires_grad=True``,
    treated as one flattened vector. Parameters with ``requires_grad=False``
    are neither counted nor rescaled. If the norm already satisfies the
    bound, or the policy has no trainable parameters, nothing is changed.

    Args:
        policy: object exposing ``parameters()`` that yields torch tensors.
        max_l2_norm: upper bound for the joint L2 norm.

    Returns:
        None. The parameters are scaled in place.
    """
    trainable = [param for param in policy.parameters() if param.requires_grad]
    if not trainable:
        # Nothing to constrain (e.g. the policy was frozen). Without this
        # guard the accumulator would stay a plain number with no ``sqrt``.
        return

    # Neither the norm nor the rescaling should be recorded by autograd.
    with torch.no_grad():
        total_norm = torch.sqrt(sum(torch.sum(param ** 2) for param in trainable))

        if total_norm > max_l2_norm:
            scale_factor = max_l2_norm / total_norm
            for param in trainable:
                param.mul_(scale_factor)



class FixedNeuralNetworkPolicy(Policy, nn.Module):
    """
    A fixed neural network policy where weights are manually updated and constrained
    within an L2 norm limit during training.

    The network is a three-layer perceptron (input -> hidden -> hidden ->
    output) with ReLU activations after the first two layers.
    """

    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        """
        Build the three linear layers and initialize their weights.

        Args:
            input_shape: shape of the observation; only its last entry is
                used, as the number of input features.
            output_shape: shape of the action; only its first entry is used,
                as the number of outputs.
            n_features: width of both hidden layers.
            **kwargs: accepted for signature compatibility and ignored.
        """
        super().__init__()

        n_input = input_shape[-1]
        n_output = output_shape[0]

        self._in = nn.Linear(n_input, n_features)
        self._h1 = nn.Linear(n_features, n_features)
        self._out = nn.Linear(n_features, n_output)

        # Initialize weights with Xavier uniform distribution
        self._initialize_weights()

    def _initialize_weights(self):
        """
        Initialize the network weights using Xavier uniform,
        and then constrain the L2 norm of the parameters to be <= 1.

        Only the weight matrices are re-initialized here; the biases keep
        the values assigned by ``nn.Linear``. The norm constraint is applied
        by ``constrain_l2_norm`` to every trainable parameter of the module.
        """
        relu_gain = nn.init.calculate_gain("relu")
        nn.init.xavier_uniform_(self._in.weight, gain=relu_gain)
        nn.init.xavier_uniform_(self._h1.weight, gain=relu_gain)
        nn.init.xavier_uniform_(self._out.weight, gain=nn.init.calculate_gain("linear"))

        # Now, constrain the L2 norm of the weights
        constrain_l2_norm(self, max_l2_norm=1.0)

    def draw_action(self, state):
        """
        Compute the action for ``state`` with a forward pass of the network.

        Args:
            state: observation convertible to a float32 tensor. All
                singleton dimensions are squeezed away before the forward
                pass.

        Returns:
            The output of the last linear layer as a torch tensor. It is
            computed under ``torch.no_grad()``, so it carries no autograd
            history.
        """
        # The policy is not trained through this call, so do not record a graph.
        with torch.no_grad():
            # as_tensor avoids the copy-construct warning when ``state`` is
            # already a tensor; the dtype argument makes a later cast unnecessary.
            in_features = torch.squeeze(torch.as_tensor(state, dtype=torch.float32))

            # A single-feature state is squeezed down to a 0-dim tensor, which
            # nn.Linear cannot consume; restore the feature dimension.
            if in_features.dim() == 0:
                in_features = in_features.unsqueeze(0)

            features1 = F.relu(self._in(in_features))
            features2 = F.relu(self._h1(features1))

            actions = self._out(features2)

        return actions



class SetupConstantThetaAgent(AbstractSetup):
    """
    Creates an adversary agent driven by fixed (frozen) neural network policies.

    ``provide_agents`` builds a pool of ``N_POLICIES`` independently
    initialized policies, ``provide_agent`` builds an agent around a single
    freshly initialized policy, and ``update_adversary_policy`` randomly
    picks one policy from ``agent.policies`` and installs it on the agent.
    Every policy handed out by this class is frozen.
    """
    # Default hidden-layer width, used when "adv_n_features" is not supplied.
    N_FEATURES = 256
    # Number of policies created by provide_agents.
    N_POLICIES = 10

    @classmethod
    def provide_policy(cls, input_dim, output_dim, **kwargs):
        """
        Create one frozen ``FixedNeuralNetworkPolicy``.

        Args:
            input_dim: input shape forwarded to the policy constructor.
            output_dim: output shape forwarded to the policy constructor.
            **kwargs: optional ``adv_n_features`` overrides the hidden-layer
                width; other keys are ignored.

        Returns:
            The newly created policy with all parameters frozen.
        """
        n_features = kwargs.get("adv_n_features", cls.N_FEATURES)
        # Initialize a fixed neural network policy for the adversary
        policy = FixedNeuralNetworkPolicy(input_dim, output_dim, n_features)
        # Freeze the policy initially (it won't be trained by backpropagation)
        freeze_adversary_policy(policy)
        return policy

    @classmethod
    def _resolve_dims(cls, mdp_info, idx_agent, kwargs):
        """Return (input_dim, output_dim), preferring explicit kwargs over mdp_info."""
        if "input_dim" in kwargs:
            input_dim = kwargs["input_dim"]
        else:
            input_dim = mdp_info.observation_space.shape

        if "output_dim" in kwargs:
            output_dim = kwargs["output_dim"]
        else:
            output_dim = mdp_info.action_space[idx_agent].shape

        return input_dim, output_dim

    @classmethod
    def _policy_kwargs(cls, kwargs):
        """Return the subset of kwargs that provide_policy understands."""
        # Forward the key only when it was supplied, so that without it the
        # call to provide_policy is exactly the two-argument form.
        if "adv_n_features" in kwargs:
            return {"adv_n_features": kwargs["adv_n_features"]}
        return {}

    @classmethod
    def provide_agents(cls, mdp_info, idx_agent, **kwargs):
        """
        Build a pool of ``N_POLICIES`` frozen policies.

        Args:
            mdp_info: used for the observation and action shapes unless
                ``input_dim`` / ``output_dim`` are given in ``kwargs``.
            idx_agent: index into ``mdp_info.action_space`` for this agent.
            **kwargs: optional ``input_dim``, ``output_dim`` and
                ``adv_n_features``.

        Returns:
            A list of policies (not agents), each created by ``provide_policy``.
        """
        input_dim, output_dim = cls._resolve_dims(mdp_info, idx_agent, kwargs)
        policy_kwargs = cls._policy_kwargs(kwargs)

        return [
            cls.provide_policy(input_dim, output_dim, **policy_kwargs)
            for _ in range(cls.N_POLICIES)
        ]

    @classmethod
    def provide_agent(cls, mdp_info, idx_agent, **kwargs):
        """
        Build a ``MultiPlayerAgent`` around one freshly created frozen policy.

        Args:
            mdp_info: passed to the agent; also used for the observation and
                action shapes unless ``input_dim`` / ``output_dim`` are given
                in ``kwargs``.
            idx_agent: index of this agent, passed to the agent and used to
                index ``mdp_info.action_space``.
            **kwargs: optional ``input_dim``, ``output_dim`` and
                ``adv_n_features``.

        Returns:
            The created agent.
        """
        input_dim, output_dim = cls._resolve_dims(mdp_info, idx_agent, kwargs)

        # NOTE(review): update_adversary_policy reads ``agent.policies``, which
        # is not set here - confirm where the policy pool is attached.
        selected_policy = cls.provide_policy(
            input_dim, output_dim, **cls._policy_kwargs(kwargs)
        )

        # Initialize the agent with the selected policy
        agent = MultiPlayerAgent(mdp_info, selected_policy, idx_agent)
        return agent

    @classmethod
    def update_adversary_policy(cls, agent):
        """
        Randomly select one of the policies in ``agent.policies`` for the
        adversary to use, install it as ``agent.policy`` and freeze its weights.

        Args:
            agent: object exposing a non-empty ``policies`` sequence and a
                writable ``policy`` attribute.

        Returns:
            The policy now installed on the agent.

        Raises:
            ValueError: if ``agent.policies`` is empty.
        """
        policies = agent.policies
        if len(policies) == 0:
            raise ValueError("agent.policies is empty; no adversary policy to select")

        selected_policy = random.choice(policies)  # Randomly select one policy

        # Update the agent's policy by assigning the selected policy
        agent.policy = selected_policy

        # Ensure the selected policy's weights remain frozen
        freeze_adversary_policy(agent.policy)

        return agent.policy

        