from pathlib import Path

from mushroom_rl.core.logger.logger import Logger
from mushroom_rl_extensions.agents.create_agent import SetupAgent
from tqdm import trange
from mushroom_rl.utils.callbacks import CollectDataset
from .abstract_experiment import AbstractExperiment

from mushroom_rl_extensions.utils.dataset import (
    compute_J,
    compute_quadruped_success_rate,
)
import numpy as np
class RARLExperiment(AbstractExperiment):
    """
    Runs a robust adversarial reinforcement learning training algorithm
    Based on http://proceedings.mlr.press/v70/pinto17a/pinto17a.pdf

    Two agents share one multi-agent core: the protagonist (agent index 0)
    and the adversary (agent index 1). The class offers three entry points:

    * ``train_protagonist`` alternates adversary and protagonist optimization
      and stores agents and training statistics under ``results_dir``.
    * ``load_models`` restores saved agents and runs a short evaluation.
    * ``evaluate_robustness_change`` sweeps a grid of two environment metrics
      and records the protagonist's mean reward for each grid cell.
    """

    def _provide_configured_mdp(self):
        """Create the MDP and apply the adversary max-force override, if any.

        The override is applied only when ``self.new_adv_max_force`` is a
        float; any other value leaves the MDP as returned by
        ``provide_mdp``.
        """
        mdp = self.provide_mdp()
        if isinstance(self.new_adv_max_force, float):
            self.update_adversary(mdp, self.new_adv_max_force)
        return mdp

    def _make_agent_logger(self, log_name):
        """Build a logger writing to ``<results_dir>/Logging`` for one agent."""
        return Logger(
            log_name=log_name,
            results_dir=Path(self.results_dir) / "Logging",
            log_console=True,
            seed=self.seed,
            console_log_level=30,
        )

    def _setup_protagonist(self, mdp):
        """Create the protagonist (agent 0) with its logger attached.

        Returns:
            The tuple ``(protagonist, logger)``.
        """
        protagonist = SetupAgent(
            self.agent, mdp.info, idx_agent=0, use_cuda=self.use_cuda
        )
        logger = self._make_agent_logger("Protagonist")
        protagonist.set_logger(logger)
        return protagonist, logger

    def _setup_learning_adversary(self, mdp):
        """Create the trainable adversary (agent 1) with its logger attached.

        Quadruped domains additionally pass ``adv_n_features=64`` to the
        agent factory.

        Returns:
            The tuple ``(adversary, logger)``.
        """
        extra_kwargs = {}
        if "quadruped" in self._domain_name:
            extra_kwargs["adv_n_features"] = 64
        adversary = SetupAgent(
            self.agent,
            mdp.info,
            idx_agent=1,
            use_cuda=self.use_cuda,
            **extra_kwargs
        )
        logger = self._make_agent_logger("Adversary")
        adversary.set_logger(logger)
        return adversary, logger

    def _load_agents(self, mdp, protagonist_path, adversary_path, constant):
        """Restore the protagonist and build or restore the adversary.

        Args:
            mdp: environment whose ``info`` is used to construct the agents.
            protagonist_path: location of the saved protagonist.
            adversary_path: location of the saved adversary; only read when
                ``constant`` is false.
            constant: when true, a ``"constant"`` adversary is created and
                ``adversary_path`` is ignored.

        Returns:
            The list ``[protagonist, adversary]``.
        """
        protagonist, _ = self._setup_protagonist(mdp)
        print('protagonist path: ', protagonist_path)
        # The object returned by ``load`` replaces the freshly built agent.
        protagonist = protagonist.load(protagonist_path)

        if constant:
            adversary = SetupAgent("constant", mdp.info, idx_agent=1)
        else:
            adversary, _ = self._setup_learning_adversary(mdp)
            adversary = adversary.load(adversary_path)
            print('loading adversary: ', adversary)

        return [protagonist, adversary]

    def evaluate_robustness_change(self, protagonist_path, adversary_path, n_episodes_per_metric_value, constant=True):
        """
        Evaluate the protagonist's return across a grid of robustness metrics.

        Every combination of ``self.metric_ranges[0]`` (first metric) and
        ``self.metric_ranges[1]`` (second metric) is applied to the
        environment physics, and the protagonist is then evaluated without
        adversary.

        Args:
            protagonist_path: location of the saved protagonist.
            adversary_path: location of the saved adversary; only used when
                ``constant`` is false.
            n_episodes_per_metric_value: number of evaluation episodes run
                for each grid cell.
            constant: use a ``"constant"`` adversary instead of loading one.

        Returns:
            Array of shape ``(len(metric_ranges[0]), len(metric_ranges[1]), 3)``
            whose last axis holds
            ``[first_metric_value, second_metric_value, mean_reward]``.
        """
        mdp = self._provide_configured_mdp()
        agents = self._load_agents(
            mdp, protagonist_path, adversary_path, constant
        )
        callbacks = CollectDataset()

        first_metric_values = self.metric_ranges[0]
        second_metric_values = self.metric_ranges[1]
        mean_reward_per_metric_value = np.zeros(
            (first_metric_values.shape[0], second_metric_values.shape[0], 3)
        )
        print('metric range: ', first_metric_values, second_metric_values)

        for first_metric_idx, first_metric_value in enumerate(
            first_metric_values
        ):
            for second_metric_idx, second_metric_value in enumerate(
                second_metric_values
            ):
                mdp.env.physics.change_first_metric(first_metric_value)
                mdp.env.physics.change_second_metric(second_metric_value)

                # A core is built for every grid cell, as before.
                # NOTE(review): could be hoisted out of the loops if
                # provide_core is known to be free of side effects.
                core = self.provide_core(
                    "multi-agent", agents, mdp, callback_step=callbacks
                )

                # Honour the caller-supplied episode count; it used to be
                # silently replaced by a hard-coded single episode.
                mean_reward_without_adversary = self.evaluate_without_adversary(
                    core.agent[0], n_episodes=n_episodes_per_metric_value
                )
                print('mean reward: ', mean_reward_without_adversary)
                mean_reward = np.mean(mean_reward_without_adversary)
                mean_reward_per_metric_value[
                    first_metric_idx, second_metric_idx
                ] = [first_metric_value, second_metric_value, mean_reward]
                core.callback_step.clean()

        print(mean_reward_per_metric_value)

        return mean_reward_per_metric_value

    def load_models(self, protagonist_path, adversary_path, constant=True):
        """
        Restore saved agents and run a short evaluation with them.

        The protagonist is evaluated for 5 episodes against the adversary and
        for 5 episodes without it; both results are discarded.

        Args:
            protagonist_path: location of the saved protagonist.
            adversary_path: location of the saved adversary; only used when
                ``constant`` is false.
            constant: use a ``"constant"`` adversary instead of loading one.

        Returns:
            The tuple ``(protagonist, adversary, mdp)``.
        """
        mdp = self._provide_configured_mdp()
        protagonist, adversary = self._load_agents(
            mdp, protagonist_path, adversary_path, constant
        )

        core = self.provide_core(
            "multi-agent",
            [protagonist, adversary],
            mdp,
            callback_step=CollectDataset(),
        )

        # Evaluation vs adversary. The return value is not needed here; a
        # stray trailing comma previously wrapped it in a 1-tuple.
        self.evaluate_vs_adversary(core.agent[0], core.agent[1], n_episodes=5)

        # Evaluation without adversary.
        self.evaluate_without_adversary(core.agent[0], n_episodes=5)

        return protagonist, adversary, mdp

    def train_protagonist(self):
        """
        Train protagonist and adversary in alternation and save the results.

        Each of the ``n_total_iterations`` outer iterations first runs
        ``n_iterations_per_agent`` learning calls fitting the adversary, then
        the same number fitting the protagonist, and finally evaluates the
        protagonist with and without the adversary. After training, the
        agents and the collected statistics are written below
        ``<results_dir>/Training``.

        Returns:
            The tuple ``(protagonist, adversary)`` taken from the core.
        """
        mdp = self._provide_configured_mdp()
        protagonist, prot_logger = self._setup_protagonist(mdp)
        adversary, adv_logger = self._setup_learning_adversary(mdp)
        core = self.provide_core("multi-agent", [protagonist, adversary], mdp)

        # Train agents
        mean_reward_vs_adversary_progress = []
        mean_reward_without_adversary_progress = []
        for i in trange(self.n_total_iterations, leave=False):
            # Adversary (index 1) is optimized first, then protagonist (0).
            for idx_agent in (1, 0):
                for _ in range(self.n_iterations_per_agent):
                    core.learn(
                        n_steps=self.n_steps_per_iteration,
                        n_episodes=self.n_episodes_per_iteration,
                        n_steps_per_fit_per_agent=self.get_n_steps_per_fit_per_agent(
                            len(core.agent), idx_agent=idx_agent
                        ),
                        quiet=False,
                        render=self.bool_render,
                    )

            # In-training evaluation uses a tenth of the configured episodes.
            n_eval_episodes = int(self.n_evaluation_episodes / 10)

            # Vs adversary
            mean_reward_vs_adversary = self.evaluate_vs_adversary(
                core.agent[0],
                core.agent[1],
                n_episodes=n_eval_episodes,
            )
            mean_reward_vs_adversary_progress.append(mean_reward_vs_adversary)
            prot_logger.info(
                "Experiment iteration {}:  \t Mean reward vs adversary: {}".format(
                    i, mean_reward_vs_adversary
                )
            )

            # Without adversary
            mean_reward_without_adversary = self.evaluate_without_adversary(
                core.agent[0], n_episodes=n_eval_episodes
            )
            mean_reward_without_adversary_progress.append(
                mean_reward_without_adversary
            )
            prot_logger.info(
                "Experiment iteration {}:  \t Mean reward without adversary: {}".format(
                    i, mean_reward_without_adversary
                )
            )

            # Save best agents; the adversary is scored with the negated
            # protagonist reward.
            prot_logger.log_best_agent(core.agent[0], mean_reward_vs_adversary)
            adv_logger.log_best_agent(core.agent[1], -mean_reward_vs_adversary)

        training_dir = Path(self.results_dir) / "Training"
        self.save_agents(
            core.agent,
            training_dir,
            self.seed,
            full_save=False,
        )

        # Extract data; every key is prefixed with the experiment seed.
        prefix = "exp_" + str(self.seed)
        trained_protagonist = core.agent[0]
        trained_adversary = core.agent[1]
        data = {
            # Mean reward per iteration
            prefix
            + "_mean_reward_vs_adversary_per_iteration": mean_reward_vs_adversary_progress,
            prefix
            + "_mean_reward_without_adversary_per_iteration": mean_reward_without_adversary_progress,
            # Protagonist data
            prefix
            + "_prot_temp_per_training_step": trained_protagonist.temperature_data,
            prefix
            + "_prot_batch_mean_entropy_per_training_step": trained_protagonist.entropy_data,
            prefix
            + "_prot_actor_loss_per_training_step": trained_protagonist.actor_loss_data,
            prefix
            + "_prot_critic_loss_per_training_step": trained_protagonist.critic_loss_data,
            # Adversary data
            prefix
            + "_adv_temp_per_training_step": trained_adversary.temperature_data,
            prefix
            + "_adv_batch_mean_entropy_per_training_step": trained_adversary.entropy_data,
            prefix
            + "_adv_actor_loss_per_training_step": trained_adversary.actor_loss_data,
            prefix
            + "_adv_critic_loss_per_training_step": trained_adversary.critic_loss_data,
            # Action data
            prefix + "_prot_action_norm_per_training_step": core.action_norms[0],
            prefix + "_adv_action_norm_per_training_step": core.action_norms[1],
        }

        self.save_data(training_dir, **data)

        return core.agent[0], core.agent[1]
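

# NOTE: Everything below this line is commented-out code. It is an
# alternative, currently disabled variant of this module (it additionally
# defines generate_dataset and save_agents, and its train_protagonist does not
# optimize the adversary). It is never executed.
# from pathlib import Path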

# from mushroom_rl.core.logger.logger import Logger
# from mushroom_rl_extensions.agents.create_agent import SetupAgent
# from tqdm import trange

# from .abstract_experiment import AbstractExperiment
# from mushroom_rl.utils.callbacks import CollectDataset
# from mushroom_rl_extensions.utils.dataset import (
#     compute_J,
#     compute_quadruped_success_rate,
# )
# import numpy as np
# import pickle
# class RARLExperiment(AbstractExperiment):
#     """
#     Runs a robust adversarial reinforcement learning training algorithm
#     Based on http://proceedings.mlr.press/v70/pinto17a/pinto17a.pdf
#     """
#     def generate_dataset(self, protagonist_path, adversary_path, constant = True):
        
#         mdp = self.provide_mdp()
#         if type(self.new_adv_max_force) == float:
#             self.update_adversary(mdp, self.new_adv_max_force)

#         protagonist = SetupAgent(
#             self.agent, mdp.info, idx_agent=0, use_cuda=self.use_cuda
#         )
#         prot_logger = Logger(
#             log_name="Protagonist",
#             results_dir=Path(self.results_dir) / "Logging",
#             log_console=True,
#             seed=self.seed,
#             console_log_level=30,
#         )
#         protagonist.set_logger(prot_logger)
#         protagonist = protagonist.load(protagonist_path)
#         if constant == False:
#             if "quadruped" in self._domain_name:
#                 adversary = SetupAgent(
#                     self.agent,
#                     mdp.info,
#                     idx_agent=1,
#                     use_cuda=self.use_cuda,
#                     adv_n_features=64,
#                 )
#             else:
#                 adversary = SetupAgent(
#                     self.agent, mdp.info, idx_agent=1, use_cuda=self.use_cuda
#                 )
#             adv_logger = Logger(
#                 log_name="Adversary",
#                 results_dir=Path(self.results_dir) / "Logging",
#                 log_console=True,
#                 seed=self.seed,
#                 console_log_level=30,
#             )
#             adversary.set_logger(adv_logger)
#             adversary = adversary.load(adversary_path)
            
#         else:
#             adversary = SetupAgent("constant", mdp.info, idx_agent=1)
        
#         collect_dataset = CollectDataset()
#         callbacks = collect_dataset
#         agents = [protagonist, adversary]
    

#         core = self.provide_core(
#                 "multi-agent", agents, mdp, callback_step=callbacks
#             )

#         trajs = core.generate_data(n_episodes=1100, render=False)

#         with open(Path(self.results_dir) / 'train_data', 'wb') as file:
#             pickle.dump(trajs[:1000], file)
       
#         with open(Path(self.results_dir) / 'test_data', 'wb') as file:
#             pickle.dump(trajs[1000:], file)
        


#         return protagonist, adversary, mdp
    
#     def load_models(self, protagonist_path, adversary_path, constant = True):
        
#         mdp = self.provide_mdp()
#         if type(self.new_adv_max_force) == float:
#             self.update_adversary(mdp, self.new_adv_max_force)

#         protagonist = SetupAgent(
#             self.agent, mdp.info, idx_agent=0, use_cuda=self.use_cuda
#         )
#         prot_logger = Logger(
#             log_name="Protagonist",
#             results_dir=Path(self.results_dir) / "Logging",
#             log_console=True,
#             seed=self.seed,
#             console_log_level=30,
#         )
#         protagonist.set_logger(prot_logger)
#         protagonist = protagonist.load(protagonist_path)
#         if constant == False:
#             if "quadruped" in self._domain_name:
#                 adversary = SetupAgent(
#                     self.agent,
#                     mdp.info,
#                     idx_agent=1,
#                     use_cuda=self.use_cuda,
#                     adv_n_features=64,
#                 )
#             else:
#                 adversary = SetupAgent(
#                     self.agent, mdp.info, idx_agent=1, use_cuda=self.use_cuda
#                 )
#             adv_logger = Logger(
#                 log_name="Adversary",
#                 results_dir=Path(self.results_dir) / "Logging",
#                 log_console=True,
#                 seed=self.seed,
#                 console_log_level=30,
#             )
#             adversary.set_logger(adv_logger)
#             adversary = adversary.load(adversary_path)
            
#         else:
#             adversary = SetupAgent("constant", mdp.info, idx_agent=1)
        
#         collect_dataset = CollectDataset()
#         callbacks = collect_dataset
#         agents = [protagonist, adversary]
    

#         core = self.provide_core(
#                 "multi-agent", agents, mdp, callback_step=callbacks
#             )

#         core.evaluate(n_episodes=5, render=False)

#         cumulative_reward_per_episode = compute_J(core.callback_step.get(), idx_agent=0)
#         mean_reward = np.mean(cumulative_reward_per_episode)

#         print('mean reward: ', mean_reward)
#         return protagonist, adversary, mdp

#     def save_agents(self, agents, results_dir, seed, full_save=False):
#         protagonist_filename = "exp_" + str(seed) + "_protagonist.zip"
#         agents[0].save(Path(results_dir) / protagonist_filename, full_save=full_save)

#         if self.constant == False:
#             adversary_filename = "exp_" + str(seed) + "_adversary.zip"
#             agents[1].save(Path(results_dir) / adversary_filename, full_save=full_save)


#     def train_protagonist(self):
#         def setup():
#             print('train self.metric_ranges', self.metric_ranges)
#             mdp = self.provide_mdp()

#             mdp.env.physics.change_first_metric(self.first_metric_value)
#             mdp.env.physics.change_second_metric(self.second_metric_value)

#             print('self: ', self.first_metric_value, self.second_metric_value)

            
#             if type(self.new_adv_max_force) == float:
#                 self.update_adversary(mdp, self.new_adv_max_force)

                

#             protagonist = SetupAgent(
#                 self.agent, mdp.info, idx_agent=0, use_cuda=self.use_cuda
#             )
#             prot_logger = Logger(
#                 log_name="Protagonist",
#                 results_dir=Path(self.results_dir) / "Logging",
#                 log_console=True,
#                 seed=self.seed,
#                 console_log_level=30,
#             )
#             protagonist.set_logger(prot_logger)

#             if self.constant == False:
#                 if "quadruped" in self._domain_name:
#                     adversary = SetupAgent(
#                         self.agent,
#                         mdp.info,
#                         idx_agent=1,
#                         use_cuda=self.use_cuda,
#                         adv_n_features=64,
#                         )
#                 else:
#                     adversary = SetupAgent(
#                         self.agent, mdp.info, idx_agent=1, use_cuda=self.use_cuda
#                     )
                
                
#             else:
#                 adversary = SetupAgent("constant", mdp.info, idx_agent=1)

#             adv_logger = Logger(
#                     log_name="Adversary",
#                     results_dir=Path(self.results_dir) / "Logging",
#                     log_console=True,
#                     seed=self.seed,
#                     console_log_level=30,
#                 )
#             adversary.set_logger(adv_logger)
            
#             if self.constant == False:
#                 agents = [protagonist, adversary]
#             else:
#                 agents = [protagonist]

#             core = self.provide_core("multi-agent", agents, mdp)
            

#             return core, prot_logger, adv_logger

#         core, prot_logger, adv_logger = setup()

#         # Train agents
#         mean_reward_vs_adversary_progress = []
#         mean_reward_without_adversary_progress = []
#         for i in trange(self.n_total_iterations, leave=False):
#             # Optimization of adversary # do not optimize adversary
#             # for _ in range(self.n_iterations_per_agent):
#             #     core.learn(
#             #         n_steps=self.n_steps_per_iteration,
#             #         n_episodes=self.n_episodes_per_iteration,
#             #         n_steps_per_fit_per_agent=self.get_n_steps_per_fit_per_agent(
#             #             len(core.agent), idx_agent=1
#             #         ),
#             #         quiet=False,
#             #         render=self.bool_render,
#             #     )
#             # Optimization of protagonist
#             for _ in range(self.n_iterations_per_agent):
#                 core.learn(
#                     n_steps=self.n_steps_per_iteration,
#                     n_episodes=self.n_episodes_per_iteration,
#                     n_steps_per_fit_per_agent=self.get_n_steps_per_fit_per_agent(
#                         len(core.agent), idx_agent=0
#                     ),
#                     quiet=False,
#                     render=self.bool_render,
#                 )
#             # # Evaluation of iteration
#             # # Vs adversary
#             # mean_reward_vs_adversary = self.evaluate_vs_adversary(
#             #     core.agent[0],
#             #     core.agent[1],
#             #     n_episodes=int(self.n_evaluation_episodes / 10),
#             # )
#             # mean_reward_vs_adversary_progress.append(mean_reward_vs_adversary)
#             # msg_vs_adv = (
#             #     "Experiment iteration {}:  \t Mean reward vs adversary: {}".format(
#             #         i, mean_reward_vs_adversary
#             #     )
#             # )
#             # prot_logger.info(msg_vs_adv)

#             # # Without adversary
#             # mean_reward_without_adversary = self.evaluate_without_adversary(
#             #     core.agent[0], n_episodes=int(self.n_evaluation_episodes / 10)
#             # )
#             # mean_reward_without_adversary_progress.append(mean_reward_without_adversary)
#             # msg_without_adv = (
#             #     "Experiment iteration {}:  \t Mean reward without adversary: {}".format(
#             #         i, mean_reward_without_adversary
#             #     )
#             # )
#             # prot_logger.info(msg_without_adv)

#             # # Save best agents
#             # prot_logger.log_best_agent(core.agent[0], mean_reward_vs_adversary)
#             # adv_logger.log_best_agent(core.agent[1], -mean_reward_vs_adversary)

#         self.save_agents(
#             core.agent,
#             Path(self.results_dir) / "Training",
#             self.seed,
#             full_save=False,
#         )

#         # ## Extract data
#         # data = {}
#         # # Mean reward per iteration
#         # data[
#         #     "exp_" + str(self.seed) + "_mean_reward_vs_adversary_per_iteration"
#         # ] = mean_reward_vs_adversary_progress
#         # data[
#         #     "exp_" + str(self.seed) + "_mean_reward_without_adversary_per_iteration"
#         # ] = mean_reward_without_adversary_progress

#         # # Protagonist data
#         # data["exp_" + str(self.seed) + "_prot_temp_per_training_step"] = core.agent[
#         #     0
#         # ].temperature_data
#         # data[
#         #     "exp_" + str(self.seed) + "_prot_batch_mean_entropy_per_training_step"
#         # ] = core.agent[0].entropy_data
#         # data[
#         #     "exp_" + str(self.seed) + "_prot_actor_loss_per_training_step"
#         # ] = core.agent[0].actor_loss_data
#         # data[
#         #     "exp_" + str(self.seed) + "_prot_critic_loss_per_training_step"
#         # ] = core.agent[0].critic_loss_data

#         # # Adversary data
#         # data["exp_" + str(self.seed) + "_adv_temp_per_training_step"] = core.agent[
#         #     1
#         # ].temperature_data
#         # data[
#         #     "exp_" + str(self.seed) + "_adv_batch_mean_entropy_per_training_step"
#         # ] = core.agent[1].entropy_data
#         # data[
#         #     "exp_" + str(self.seed) + "_adv_actor_loss_per_training_step"
#         # ] = core.agent[1].actor_loss_data
#         # data[
#         #     "exp_" + str(self.seed) + "_adv_critic_loss_per_training_step"
#         # ] = core.agent[1].critic_loss_data

#         # # Action data
#         # data[
#         #     "exp_" + str(self.seed) + "_prot_action_norm_per_training_step"
#         # ] = core.action_norms[0]
#         # data[
#         #     "exp_" + str(self.seed) + "_adv_action_norm_per_training_step"
#         # ] = core.action_norms[1]

#         # self.save_data(Path(self.results_dir) / "Training", **data)
#         if self.constant == False: 
#             return core.agent[0], core.agent[1]
#         else:
#             return core.agent
