import numpy as np
from mushroom_rl_extensions.core.core import Core
from tqdm import tqdm
import pickle

class MultiAgentCore(Core):
    

    def generate_data(
        self,
        initial_states=None,
        n_steps=None,
        n_episodes=None,
        render=False,
        quiet=False,
    ):
        """
        This function moves the agent in the environment using its policy.
        The agent is moved for a provided number of steps, episodes, or from
        a set of initial states for the whole episode. By default, the
        environment is reset.

        Args:
            initial_states (np.ndarray, None): the starting states of each
                episode;
            n_steps (int, None): number of steps to move the agent;
            n_episodes (int, None): number of episodes to move the agent;
            render (bool, False): whether to render the environment or not;
            quiet (bool, False): whether to show the progress bar or not.

        """
        fit_condition_per_agent = list()
        for idx_agent in range(len(self.agent)):
            fit_condition_per_agent.append(lambda: False)
        
        return self._collect_data(
            n_steps, n_episodes, fit_condition_per_agent, render, quiet, initial_states
        )
    def _collect_data(
        self,
        n_steps,
        n_episodes,
        fit_condition_per_agent,
        render,
        quiet,
        initial_states=None,
    ):
        assert (
            n_episodes is not None
            and n_steps is None
            and initial_states is None
            or n_episodes is None
            and n_steps is not None
            and initial_states is None
            or n_episodes is None
            and n_steps is None
            and initial_states is not None
        )

        self._n_episodes = (
            len(initial_states) if initial_states is not None else n_episodes
        )

        if n_steps is not None:
            move_condition = lambda: self._total_steps_counter < n_steps

            steps_progress_bar = tqdm(
                total=n_steps, dynamic_ncols=True, disable=quiet, leave=False
            )
            episodes_progress_bar = tqdm(disable=True)
        else:
            move_condition = lambda: self._total_episodes_counter < self._n_episodes

            steps_progress_bar = tqdm(disable=True)
            episodes_progress_bar = tqdm(
                total=self._n_episodes, dynamic_ncols=True, disable=quiet, leave=False
            )
        return self._run_data(
            move_condition,
            fit_condition_per_agent,
            steps_progress_bar,
            episodes_progress_bar,
            render,
            initial_states,
        )
    def _run_data(
        self,
        move_condition,
        fit_condition_per_agent,
        steps_progress_bar,
        episodes_progress_bar,
        render,
        initial_states,
    ):
        self._total_episodes_counter = 0
        self._total_steps_counter = 0
        self._current_episodes_counter_per_agent = np.zeros(len(self.agent), dtype=int)
        self._current_steps_counter_per_agent = np.zeros(len(self.agent), dtype=int)

        dataset_per_agent = [list() for _ in range(len(self.agent))]
        last = True

        trajs = []

        total_states = []
        total_actions = []
        total_adv_act = []
        total_next_states = []
        total_rewards = []
        total_absorb = []
        total_dones = []

        state_space = self.mdp.observation_space
        low, high = state_space.low, state_space.high



        
        while move_condition():
            if last:
                self.reset(initial_states)
                
                if len(total_states) != 0:

                    total_states = np.array(total_states)
                    total_actions = np.array(total_actions)
                    total_adv_act = np.array(total_adv_act)
                    total_next_states = np.array(total_next_states)
                    total_rewards = np.array(total_rewards)
                    total_dones = np.array(total_dones)
                    total_absorb = np.array(total_absorb)
                    print(total_states.shape[0])
                    random_index = np.random.randint(total_states.shape[0])
                    
                    # Sample a random state within the state space bounds
                    query_state = total_states[random_index, :]
                    optimal_action = total_actions[random_index]

                    traj = {
                        'query_state': query_state,
                        'optimal_action': optimal_action,
                        'context_states': total_states,
                        'context_actions': total_actions,
                        'context_next_states': total_next_states,
                        'context_rewards': total_rewards,
                        'dones': total_dones,
                        'absorb':total_absorb,
                        'context_adv_act':total_adv_act
                    }
                    trajs.append(traj)

                    total_states = []
                    total_actions = []
                    total_adv_act = []
                    total_next_states = []
                    total_rewards = []
                    total_absorb = []
                    total_dones = []


            sample = self._step(render)

            

            # Quadruped Success
            if "Quadruped" in type(self.mdp.env.task).__name__:
                goal_pos = np.array([10, 0])
                torso_pos = self.mdp.env.physics.named.data.geom_xpos["torso"][:2]
                dist = np.linalg.norm(goal_pos - torso_pos)
                success = dist < 1.0

            self.callback_step([sample])

            self._total_steps_counter += 1
            self._current_steps_counter_per_agent += 1
            steps_progress_bar.update(1)

            last = sample[-1] # state, action, reward, next_state, absorbing, done
            total_states.append(sample[0])
            total_actions.append(sample[1][0])
            total_adv_act.append(sample[1][1])
            total_rewards.append(sample[2])
            total_next_states.append(sample[3])
            total_absorb.append(sample[4])
            total_dones.append(sample[5])
            
            
            if last:
                self._total_episodes_counter += 1
                self._current_episodes_counter_per_agent += 1
                episodes_progress_bar.update(1)

            [
                dataset.append(sample)
                for idx_agent, dataset in enumerate(dataset_per_agent)
            ]

            for idx_agent, fit_condition in enumerate(fit_condition_per_agent):
                if fit_condition():
                    self.agent[idx_agent].fit(dataset_per_agent[idx_agent])
                    self._current_episodes_counter_per_agent[idx_agent] = 0
                    self._current_steps_counter_per_agent[idx_agent] = 0

                    if idx_agent == 0:
                        for c in self.callbacks_fit:
                            c(dataset_per_agent[0])
                    else:
                        pass
                        # ToDo: Introduce callbacks for adversary (?)
                    dataset_per_agent[
                        idx_agent
                    ] = (
                        list()
                    )  # fit stores data in agent replay buffer, so core's replay buffer can be reset

        for agent in self.agent:
            agent.stop()
        self.mdp.stop()

        steps_progress_bar.close()
        episodes_progress_bar.close()

        ######################################3333
        total_states = np.array(total_states)
        total_actions = np.array(total_actions)
        total_adv_act = np.array(total_adv_act)
        total_next_states = np.array(total_next_states)
        total_rewards = np.array(total_rewards)
        total_dones = np.array(total_dones)
        total_absorb = np.array(total_absorb)
        random_index = np.random.randint(total_states.shape[0])
        
        # Sample a random state within the state space bounds
        query_state = total_states[random_index, :]
        optimal_action = total_actions[random_index]

        traj = {
            'query_state': query_state,
            'optimal_action': optimal_action,
            'context_states': total_states,
            'context_actions': total_actions,
            'context_next_states': total_next_states,
            'context_rewards': total_rewards,
            'dones': total_dones,
            'absorb':total_absorb,
            'context_adv_act':total_adv_act
        }
        ##########################################3
        trajs.append(traj)
        print('len trajs: ', len(trajs))

        return trajs # just protagonist dataset

    def _run_impl(
        self,
        move_condition,
        fit_condition_per_agent,
        steps_progress_bar,
        episodes_progress_bar,
        render,
        initial_states,
    ):
        self._total_episodes_counter = 0
        self._total_steps_counter = 0
        self._current_episodes_counter_per_agent = np.zeros(len(self.agent), dtype=int)
        self._current_steps_counter_per_agent = np.zeros(len(self.agent), dtype=int)

        dataset_per_agent = [list() for _ in range(len(self.agent))]
        last = True

     

        
        while move_condition():
            if last:
                self.reset(initial_states)

            sample = self._step(render)

            

            # Quadruped Success
            if "Quadruped" in type(self.mdp.env.task).__name__:
                goal_pos = np.array([10, 0])
                torso_pos = self.mdp.env.physics.named.data.geom_xpos["torso"][:2]
                dist = np.linalg.norm(goal_pos - torso_pos)
                success = dist < 1.0

            self.callback_step([sample])

            self._total_steps_counter += 1
            self._current_steps_counter_per_agent += 1
            steps_progress_bar.update(1)

            last = sample[-1] # state, action, reward, next_state, absorbing, done
         
            
            if last:
                self._total_episodes_counter += 1
                self._current_episodes_counter_per_agent += 1
                episodes_progress_bar.update(1)

            [
                dataset.append(sample)
                for idx_agent, dataset in enumerate(dataset_per_agent)
            ]

            for idx_agent, fit_condition in enumerate(fit_condition_per_agent):
                if fit_condition():
                    self.agent[idx_agent].fit(dataset_per_agent[idx_agent])
                    self._current_episodes_counter_per_agent[idx_agent] = 0
                    self._current_steps_counter_per_agent[idx_agent] = 0

                    if idx_agent == 0:
                        for c in self.callbacks_fit:
                            c(dataset_per_agent[0])
                    else:
                        pass
                        # ToDo: Introduce callbacks for adversary (?)
                    dataset_per_agent[
                        idx_agent
                    ] = (
                        list()
                    )  # fit stores data in agent replay buffer, so core's replay buffer can be reset

        for agent in self.agent:
            agent.stop()
        self.mdp.stop()

        steps_progress_bar.close()
        episodes_progress_bar.close()
        return dataset_per_agent[0]  # just protagonist dataset

    def _step(self, render):
        action = list()
        for idx_agent in range(len(self.agent)):
            action.append(self.agent[idx_agent].draw_action(self._state))
            self.action_norms[idx_agent].append(np.linalg.norm(action[idx_agent]))

        next_state, reward, absorbing, info = self.mdp.step(action)

        self._episode_steps += 1

        if render:
            render_info = {"action": action, "reward": reward}
            self.mdp.render(render_info)

        last = not (self._episode_steps < self.mdp.info.horizon and not absorbing)

        state = self._state
        next_state = self._preprocess(next_state.copy())
        self._state = next_state

        return state, action, reward, next_state, absorbing, last

    def reset(self, initial_states=None):
        """
        Reset the state of the mdp and agents.

        """
        if initial_states is None or self._total_episodes_counter == self._n_episodes:
            initial_state = None
        else:
            initial_state = initial_states[self._total_episodes_counter]

        self._state = self._preprocess(self.mdp.reset(initial_state).copy())

        for agent in self.agent:
            agent.episode_start()
            agent.next_action = None
        self._episode_steps = 0
