from pathlib import Path

from mushroom_rl.core.logger.logger import Logger
from mushroom_rl_extensions.agents.create_agent import SetupAgent
from tqdm import trange

from .abstract_experiment import AbstractExperiment
from mushroom_rl_extensions.agents.other.setup_constant_theta_agent import SetupConstantThetaAgent
from mushroom_rl.policy.policy import Policy
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def freeze_adversary_policy(policy):
    """Turn off gradient tracking for every parameter of ``policy``.

    Each tensor yielded by ``policy.parameters()`` gets its
    ``requires_grad`` flag set to ``False``; the stored values are not
    modified. Calling this repeatedly on the same policy has no further
    effect.
    """
    for weight in policy.parameters():
        weight.requires_grad_(False)


def constrain_l2_norm(policy, max_l2_norm=1.0):
    """
    Rescale the trainable parameters of ``policy`` in place so that their
    joint L2 norm is at most ``max_l2_norm``.

    The norm is taken over all parameters with ``requires_grad=True``,
    treated as a single flattened vector. Parameters with
    ``requires_grad=False`` are neither counted nor rescaled. If the norm
    is already within the limit, nothing is changed.

    Args:
        policy: Object exposing ``parameters()`` that yields torch tensors
            (e.g. an ``nn.Module``).
        max_l2_norm: Upper bound for the joint L2 norm.

    Returns:
        None. The parameters are modified in place.
    """
    trainable = [param for param in policy.parameters() if param.requires_grad]

    # With nothing trainable (e.g. a fully frozen policy) there is no norm
    # to constrain; return instead of calling ``.sqrt()`` on a Python float.
    if not trainable:
        return

    # The whole computation is bookkeeping on the weights, not part of any
    # loss, so keep it out of the autograd graph.
    with torch.no_grad():
        total_norm = torch.sqrt(sum(torch.sum(param ** 2) for param in trainable))

        if total_norm > max_l2_norm:
            scale_factor = max_l2_norm / total_norm
            for param in trainable:
                param.mul_(scale_factor)



class FixedNeuralNetworkPolicy(Policy, nn.Module):
    """
    Fully connected policy network with two ReLU hidden layers and a linear
    output layer.

    The weights are Xavier-initialised at construction time and the
    parameters are then passed through ``constrain_l2_norm`` with a limit
    of 1. No ``forward`` method is defined; the network is evaluated
    directly by ``draw_action``.
    """

    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        """
        Build the three linear layers and initialise them.

        Args:
            input_shape: Shape of the state; its last entry is used as the
                input width.
            output_shape: Shape of the action; its first entry is used as
                the output width.
            n_features: Width of both hidden layers.
            **kwargs: Accepted but not used.
        """
        super().__init__()

        state_dim = input_shape[-1]
        action_dim = output_shape[0]

        self._in = nn.Linear(state_dim, n_features)
        self._h1 = nn.Linear(n_features, n_features)
        self._out = nn.Linear(n_features, action_dim)

        self._initialize_weights()

    def _initialize_weights(self):
        """
        Apply Xavier-uniform initialisation to the three weight matrices
        (ReLU gain for the input and hidden layers, linear gain for the
        output layer), then constrain the parameters' L2 norm to <= 1.
        """
        layer_activations = (
            (self._in, "relu"),
            (self._h1, "relu"),
            (self._out, "linear"),
        )
        for layer, activation in layer_activations:
            nn.init.xavier_uniform_(
                layer.weight, gain=nn.init.calculate_gain(activation)
            )

        constrain_l2_norm(self, max_l2_norm=1.0)

    def draw_action(self, state):
        """
        Run the network on ``state`` and return its output.

        ``state`` is converted to a float32 tensor and all size-1
        dimensions are squeezed away before the forward pass. The result is
        returned as a torch tensor.
        """
        observation = torch.tensor(state, dtype=torch.float32).squeeze()

        hidden = F.relu(self._in(observation))
        hidden = F.relu(self._h1(hidden))

        return self._out(hidden)



class Fix_RARLExperiment(AbstractExperiment):
    """
    Trains a protagonist against a pool of fixed adversary policies.

    The setup follows robust adversarial reinforcement learning, see
    http://proceedings.mlr.press/v70/pinto17a/pinto17a.pdf
    In this experiment the adversary is not optimised: its policy is drawn
    at random from a pool of frozen ``FixedNeuralNetworkPolicy`` networks
    and swapped after every outer training iteration.
    """

    def load_models(self, protagonist_path, adversary_path, constant=True):
        """
        Build the MDP and load a saved protagonist.

        Args:
            protagonist_path: Location passed to the protagonist's ``load``.
            adversary_path: Currently unused; adversary loading is disabled
                in this experiment.
            constant: Currently unused, for the same reason.

        Returns:
            Tuple ``(protagonist, mdp)``.
        """
        mdp = self.provide_mdp()
        # isinstance also accepts float subclasses (e.g. numpy.float64),
        # which an exact ``type(...) == float`` comparison would skip.
        if isinstance(self.new_adv_max_force, float):
            self.update_adversary(mdp, self.new_adv_max_force)

        protagonist = SetupAgent(
            self.agent, mdp.info, idx_agent=0, use_cuda=self.use_cuda
        )
        prot_logger = Logger(
            log_name="Protagonist",
            results_dir=Path(self.results_dir) / "Logging",
            log_console=True,
            seed=self.seed,
            console_log_level=30,
        )
        protagonist.set_logger(prot_logger)
        print('protagonist path: ', protagonist_path)

        # ``protagonist`` is rebound to whatever ``load`` returns.
        # NOTE(review): whether the logger attached above carries over to
        # the loaded object is not visible from this file -- confirm.
        protagonist = protagonist.load(protagonist_path)

        return protagonist, mdp

    def save_agents(self, agents, results_dir, seed, full_save=False):
        """
        Save the protagonist (``agents[0]``) as
        ``exp_<seed>_protagonist.zip`` inside ``results_dir``.

        Only the first agent is written; the adversary is not saved.
        """
        protagonist_filename = f"exp_{seed}_protagonist.zip"
        agents[0].save(Path(results_dir) / protagonist_filename, full_save=full_save)

    def train_protagonist(self):
        """
        Train the protagonist while the adversary acts with randomly
        selected frozen policies.

        Each of the ``n_total_iterations`` outer iterations runs
        ``n_iterations_per_agent`` calls to ``core.learn`` and then assigns
        the adversary a policy drawn at random from the pool. After
        training, the protagonist is saved under ``<results_dir>/Training``.

        Returns:
            Tuple ``(protagonist, adversary)`` taken from ``core.agent``.
        """

        def setup():
            """Create the MDP, both agents, the policy pool and the core."""
            mdp = self.provide_mdp()
            if isinstance(self.new_adv_max_force, float):
                self.update_adversary(mdp, self.new_adv_max_force)

            protagonist = SetupAgent(
                'sac', mdp.info, idx_agent=0, use_cuda=self.use_cuda
            )
            prot_logger = Logger(
                log_name="Protagonist",
                results_dir=Path(self.results_dir) / "Logging",
                log_console=True,
                seed=self.seed,
                console_log_level=30,
            )
            protagonist.set_logger(prot_logger)

            input_dim = mdp.info.observation_space.shape
            output_dim = mdp.info.action_space[1].shape

            # Pool of 10 independently initialised adversary networks,
            # frozen as soon as they are created.
            policies = []
            for _ in range(10):
                policy = FixedNeuralNetworkPolicy(input_dim, output_dim, n_features=256)
                freeze_adversary_policy(policy)
                policies.append(policy)

            adversary = SetupAgent(
                "constant_theta", mdp.info, idx_agent=1, use_cuda=self.use_cuda
            )
            # Every pool member is already frozen, so no re-freeze is needed.
            adversary.policy = np.random.choice(policies)

            adv_logger = Logger(
                log_name="Adversary",
                results_dir=Path(self.results_dir) / "Logging",
                log_console=True,
                seed=self.seed,
                console_log_level=30,
            )
            adversary.set_logger(adv_logger)

            agents = [protagonist, adversary]
            core = self.provide_core("multi-agent", agents, mdp)

            core.mdp.env.physics.change_first_metric(self.first_metric_value)
            core.mdp.env.physics.change_second_metric(self.second_metric_value)
            print('update metric')

            return core, policies

        core, policies = setup()

        for _ in trange(self.n_total_iterations, leave=False):
            # Optimization of protagonist
            for _ in range(self.n_iterations_per_agent):
                core.learn(
                    n_steps=self.n_steps_per_iteration,
                    n_episodes=self.n_episodes_per_iteration,
                    n_steps_per_fit_per_agent=self.get_n_steps_per_fit_per_agent(
                        len(core.agent), idx_agent=0
                    ),
                    quiet=False,
                    render=self.bool_render,
                )

            # The adversary is never optimised: swap in another fixed policy
            # from the pool for the next iteration.
            selected_policy = np.random.choice(policies)
            # Cheap, idempotent guard: only sets requires_grad=False again.
            freeze_adversary_policy(selected_policy)
            core.agent[1].policy = selected_policy

        self.save_agents(
            core.agent,
            Path(self.results_dir) / "Training",
            self.seed,
            full_save=False,
        )

        return core.agent[0], core.agent[1]

