"""Script used to play with trained agents."""
import argparse
import os

import numpy as np
import tonic  # noqa
import yaml

from .utils import env_tonic_compat
from matplotlib import pyplot as plt


# Number of episodes play_gym plays (per obstacle course) before it stops and
# reports its statistics.
N = 20


# Height below which a foot body's ``com_pos().y`` counts as ground contact.
_CONTACT_HEIGHT = 0.05
# Gait events are only recorded once ``dof_values()[1]`` exceeds this value.
_GAIT_START_X = 5
# An episode counts as a success when ``dof_values()[1]`` reaches this value.
_SUCCESS_X = 18


def _mean_or_nan(values):
    """Return ``np.mean(values)``, or NaN if *values* is empty.

    Gives the same NaN result as ``np.mean([])`` without its RuntimeWarning.
    """
    if len(values) == 0:
        return np.float64(np.nan)
    return np.mean(values)


def _symmetry_index(left, right):
    """Return the symmetry index ``(left - right) / (0.5 * (left + right))``.

    0 means perfect symmetry; for positive inputs, a positive value means the
    left side is larger.
    """
    return (left - right) / (0.5 * (left + right))


class _FootContactTracker:
    """Records stance durations and lift-off-to-lift-off step lengths for one foot.

    The foot is in contact while its y coordinate is below ``contact_height``.
    On the first update where it is no longer in contact, the stance length
    (number of updates) is stored, as is the x distance covered since the
    previous lift-off.
    """

    def __init__(self, contact_height=_CONTACT_HEIGHT):
        self.contact_height = contact_height
        self.stance_durations = []
        self.step_lengths = []
        self._duration = 0
        self._last_liftoff_x = None

    def update(self, foot_x, foot_y):
        """Advance the tracker by one environment step."""
        if foot_y < self.contact_height:
            self._duration += 1
        elif self._duration != 0:
            # Lift-off: close the current stance phase.
            self.stance_durations.append(self._duration)
            self._duration = 0
            if self._last_liftoff_x is not None:
                self.step_lengths.append(foot_x - self._last_liftoff_x)
            self._last_liftoff_x = foot_x

    def reset_episode(self):
        """Drop in-progress state so no stance or step spans an episode reset."""
        self._duration = 0
        self._last_liftoff_x = None


def play_gym(agent, environment, env_args=None):
    """Launches an agent in a Gym-based (sconegym) environment and reports gait statistics.

    Plays ``N`` episodes (module-level constant) with rendering. Each episode
    prints its reward statistics and whether the largest ``dof_values()[1]``
    reached ``_SUCCESS_X``. Throughout, it tracks the bodies ``bodies()[-2]``
    (left foot) and ``bodies()[-5]`` (right foot):

    * stance durations and step lengths, once ``dof_values()[1]`` exceeds
      ``_GAIT_START_X``;
    * x placement relative to ``bodies()[1]``, every step.

    Args:
        agent: Tonic agent providing ``test_step`` and ``test_update``.
        environment: Environment handed to ``tonic.environments.distribute``.
        env_args: Forwarded to ``tonic.environments.distribute``.

    Returns:
        ``(symmetry_time, symmetry_length, mean_left, mean_right,
        symmetry_placement)``. Each symmetry value is
        ``(left - right) / (0.5 * (left + right))``. Values are NaN when no
        data was collected for a side.
    """
    # Keep the raw environment under its own name: the builder lambda must not
    # close over ``environment``, which is rebound to the distributed wrapper.
    base_environment = environment
    environment = tonic.environments.distribute(
        lambda identifier=0: base_environment, env_args=env_args
    )
    sim_env = environment.environments[0]

    # Other courses: '2d_5cm', '2d_10cm', '2d_15cm', '2d_20cm', '2d_30cm'.
    for course in ['2d']:
        sim_env.set_obstacle_course(course)

        observations, muscles_dep = environment.start()
        environment.render()

        score = 0
        length = 0
        min_reward = float("inf")
        max_reward = -float("inf")
        global_min_reward = float("inf")
        global_max_reward = -float("inf")
        steps = 0
        episodes = 0
        eval_score = []
        largest_x_reached = 0
        left_tracker = _FootContactTracker()
        right_tracker = _FootContactTracker()
        xs_left = []
        xs_right = []

        while True:
            actions = agent.test_step(observations, steps, muscles_dep)
            observations, muscles_dep, infos = environment.step(actions)
            agent.test_update(**infos, steps=steps)
            environment.render()

            # NOTE(review): body indices are model-specific. [-2] is treated as
            # the left foot, [-5] as the right foot and [1] as the placement
            # reference (presumably the pelvis). Confirm against the model.
            bodies = sim_env.unwrapped.model.bodies()
            left_pos = bodies[-2].com_pos()
            right_pos = bodies[-5].com_pos()
            reference_x = bodies[1].com_pos().x
            xs_left.append(left_pos.x - reference_x)
            xs_right.append(right_pos.x - reference_x)

            if sim_env.model.dof_values()[1] > _GAIT_START_X:
                left_tracker.update(left_pos.x, left_pos.y)
                right_tracker.update(right_pos.x, right_pos.y)

            steps += 1
            reward = infos["rewards"][0]
            score += reward
            min_reward = min(min_reward, reward)
            max_reward = max(max_reward, reward)
            global_min_reward = min(global_min_reward, reward)
            global_max_reward = max(global_max_reward, reward)
            length += 1
            largest_x_reached = np.maximum(
                largest_x_reached, sim_env.model.dof_values()[1]
            )

            if infos["resets"][0]:
                term = infos["terminations"][0]
                episodes += 1

                print()
                print(f"Episodes: {episodes:,}")
                print(f"Score: {score:,.3f}")
                print(f"Length: {length:,}")
                print(f"Terminal: {term:}")
                print(f"Min reward: {min_reward:,.3f}")
                print(f"Max reward: {max_reward:,.3f}")
                print(f"Global min reward: {global_min_reward:,.3f}")
                print(f"Global max reward: {global_max_reward:,.3f}")

                score = 0
                length = 0
                min_reward = float("inf")
                max_reward = -float("inf")
                sim_env.model.write_results('sconepy_example')
                print(f'Largest x reached was: {largest_x_reached}')
                if largest_x_reached >= _SUCCESS_X:
                    print('Did it!')
                    eval_score.append(1)
                else:
                    eval_score.append(0)
                largest_x_reached = 0
                # Start the next episode's gait tracking from scratch.
                left_tracker.reset_episode()
                right_tracker.reset_episode()

            if episodes == N:
                print(f'Eval score is {np.mean(eval_score)}')
                break

    # Statistics below come from the last course played.
    left_foot = left_tracker.stance_durations
    right_foot = right_tracker.stance_durations
    print(f'{left_foot=}')
    print(f'{right_foot=}')
    symmetry_time = _symmetry_index(_mean_or_nan(left_foot), _mean_or_nan(right_foot))
    print(f'{symmetry_time=}')
    symmetry_length = _symmetry_index(
        _mean_or_nan(left_tracker.step_lengths),
        _mean_or_nan(right_tracker.step_lengths),
    )
    print(f'{symmetry_length=}')
    mean_left = _mean_or_nan(xs_left)
    mean_right = _mean_or_nan(xs_right)
    print(f'Left foot: mean={mean_left} std= {np.std(xs_left)}')
    print(f'Right foot: mean={mean_right} std= {np.std(xs_right)}')
    symmetry_placement = _symmetry_index(mean_left, mean_right)
    return symmetry_time, symmetry_length, mean_left, mean_right, symmetry_placement


def play_control_suite(agent, environment):
    """Launches an agent in a DeepMind Control Suite-based environment.

    The Tonic environment is wrapped so ``dm_control.viewer`` can drive it.
    The viewer's policy callback forwards to ``agent.test_step`` and reports
    the previous transition through ``agent.test_update``. Reward statistics
    are printed whenever an episode ends.
    """

    from dm_control import viewer

    class Wrapper:
        """Wrapper used to plug a Tonic environment in a dm_control viewer."""

        def __init__(self, environment):
            self.environment = environment
            self.unwrapped = environment.unwrapped
            self.action_spec = self.unwrapped.environment.action_spec
            self.physics = self.unwrapped.environment.physics
            self.infos = None
            self.steps = 0
            self.episodes = 0
            # Per-episode statistics; reset() re-initializes them.
            self.score = 0
            self.length = 0
            self.min_reward = float("inf")
            self.max_reward = -float("inf")
            # Statistics accumulated over all episodes.
            self.global_min_reward = float("inf")
            self.global_max_reward = -float("inf")

        def reset(self):
            """Mimics a dm_control reset for the viewer."""
            # Add a batch dimension, as the agent expects batched observations.
            self.observations = self.environment.reset()[None]
            self.muscles_dep = self.environment.muscles_dep

            self.score = 0
            self.length = 0
            self.min_reward = float("inf")
            self.max_reward = -float("inf")

            return self.unwrapped.last_time_step

        def step(self, actions):
            """Mimics a dm_control step for the viewer."""
            # Debug guard: stop immediately if the policy produced NaNs.
            assert not np.isnan(actions.sum())
            ob, rew, term, _ = self.environment.step(actions[0])

            self.score += rew
            self.length += 1
            self.min_reward = min(self.min_reward, rew)
            self.max_reward = max(self.max_reward, rew)
            self.global_min_reward = min(self.global_min_reward, rew)
            self.global_max_reward = max(self.global_max_reward, rew)
            timeout = self.length == self.environment.max_episode_steps
            done = term or timeout

            if done:
                self.episodes += 1
                print()
                print(f"Episodes: {self.episodes:,}")
                print(f"Score: {self.score:,.3f}")
                print(f"Length: {self.length:,}")
                print(f"Terminal: {term:}")
                print(f"Min reward: {self.min_reward:,.3f}")
                print(f"Max reward: {self.max_reward:,.3f}")
                print(f"Global min reward: {self.global_min_reward:,.3f}")
                print(f"Global max reward: {self.global_max_reward:,.3f}")

            self.observations = ob[None]
            self.muscles_dep = self.environment.muscles_dep
            # Batched transition consumed by agent.test_update in policy().
            self.infos = dict(
                observations=ob[None],
                rewards=np.array([rew]),
                resets=np.array([done]),
                terminations=np.array([term]),
            )

            return self.unwrapped.last_time_step

    # Wrap the environment for the viewer.
    environment = Wrapper(environment)

    def policy(timestep):
        """Mimics a dm_control policy for the viewer."""
        # Report the previous transition (if any) before choosing the next action.
        if environment.infos is not None:
            agent.test_update(**environment.infos, steps=environment.steps)
            environment.steps += 1
        return agent.test_step(
            environment.observations, environment.steps, environment.muscles_dep
        )

    # Launch the viewer with the wrapped environment and policy.
    viewer.launch(environment, policy)


def _find_checkpoint(path, checkpoint):
    """Return the checkpoint to load for the experiment in *path*, or None.

    Args:
        path: Experiment directory. Checkpoints are files in
            ``<path>/checkpoints`` named ``step_<id>.<ext>``.
        checkpoint: ``"last"`` for the highest step id, otherwise a step id
            (anything ``int()`` accepts).

    Returns:
        ``<path>/checkpoints/step_<id>`` (without extension). Returns None,
        after logging an error, when the directory, any checkpoint, or the
        requested checkpoint is missing.
    """
    checkpoint_dir = os.path.join(path, "checkpoints")
    if not os.path.isdir(checkpoint_dir):
        # Return early: os.listdir(None) would silently list the current directory.
        tonic.logger.error(f"{checkpoint_dir} is not a directory")
        return None

    checkpoint_ids = [
        int(file.split(".")[0][5:])
        for file in os.listdir(checkpoint_dir)
        if file[:5] == "step_"
    ]
    if not checkpoint_ids:
        tonic.logger.error(f"No checkpoint found in {checkpoint_dir}")
        return None

    if checkpoint == "last":
        checkpoint_id = max(checkpoint_ids)
    else:
        checkpoint_id = int(checkpoint)
        if checkpoint_id not in checkpoint_ids:
            tonic.logger.error(
                f"Checkpoint {checkpoint_id} not found in {checkpoint_dir}"
            )
            return None
    return os.path.join(checkpoint_dir, f"step_{checkpoint_id}")


def _load_experiment(path, checkpoint, seed, header, agent, environment):
    """Build the agent and environment of the experiment stored in *path*.

    ``header``, ``agent`` and ``environment`` override the values in
    ``<path>/config.yaml`` when given. If ``agent`` is given or ``checkpoint``
    is ``"none"``, no weights are loaded.

    Returns:
        ``(agent, environment, config)``: the initialized agent (with weights
        when a checkpoint was found), the built environment, and the
        configuration as an ``argparse.Namespace``.

    Raises:
        ValueError: If neither the arguments nor the config specify an agent.
    """
    if checkpoint == "none" or agent is not None:
        tonic.logger.log("Not loading any weights")
        checkpoint_path = None
    else:
        checkpoint_path = _find_checkpoint(path, checkpoint)

    # NOTE(security): yaml.FullLoader, exec and eval execute what the experiment
    # files describe; only load experiments from trusted sources.
    with open(os.path.join(path, "config.yaml"), "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    config = argparse.Namespace(**config)
    print(config)
    header = header or config.header
    agent = agent or config.agent
    environment = environment or config.test_environment or config.environment

    # Run the header (e.g. framework imports) and build the agent in one
    # explicit namespace so eval sees what the header imported. Implicit
    # function locals no longer carry exec's bindings on Python 3.13+ (PEP 667).
    namespace = dict(globals())
    if header:
        exec(header, namespace)

    if not agent:
        raise ValueError("No agent specified.")
    agent = eval(agent, namespace)

    environment = env_tonic_compat(environment)()
    environment.seed(seed)

    agent.initialize(
        observation_space=environment.observation_space,
        action_space=environment.action_space,
        seed=seed,
    )
    if checkpoint_path:
        agent.load(checkpoint_path, play=True)
    return agent, environment, config


def _play_experiment(agent, environment, config):
    """Play *agent* in *environment*; return play_gym's statistics, or None.

    ControlSuite environments open the dm_control viewer and produce no
    statistics (None is returned).
    """
    if isinstance(environment, tonic.environments.wrappers.ActionRescaler):
        environment_type = environment.env.__class__.__name__
    else:
        environment_type = environment.__class__.__name__

    if environment_type == "ControlSuiteEnvironment":
        play_control_suite(agent, environment)
        return None
    return play_gym(agent, environment, getattr(config, "env_args", None))


def play(path, checkpoint, seed, header, agent, environment):
    """Compare foot placement between the DEP and no-DEP experiment groups.

    For each seed of the two groups, this:

    * reloads the agent and environment from the experiment folder;
    * plays them with ``play_gym``;
    * records the mean left and right foot placement.

    The per-seed values are saved to ``dep_left.npy``, ``dep_right.npy``,
    ``nodep_left.npy`` and ``nodep_right.npy``. They are also shown as a box
    plot, centred on each group's overall mean.

    Args:
        path: Root directory holding the experiment folders. Defaults to
            ``"folder"`` when None.
        checkpoint: ``"last"`` (also used when None), ``"none"`` or a step id.
        seed: Seed for environments and agents.
        header: Optional override of each config's ``header``.
        agent: Optional override of each config's ``agent`` (disables
            weight loading).
        environment: Optional override of each config's environment.
    """
    root = "folder" if path is None else path
    checkpoint = checkpoint or "last"
    # Portable relative paths of the experiment runs, one list per group.
    experiment_groups = [
        [os.path.join(f"working_dep_seed{i}", "cluster_myosuite") for i in range(5)],
        [os.path.join(f"working_dep_no_dep_seed{i}", "cluster_myosuite") for i in range(5)],
    ]

    final_data = []
    for run_dirs in experiment_groups:
        values_left = []
        values_right = []
        for run_dir in run_dirs:
            experiment_path = os.path.join(root, run_dir)
            print(experiment_path)
            tonic.logger.log(f"Loading experiment from {experiment_path}")
            built_agent, built_environment, config = _load_experiment(
                experiment_path, checkpoint, seed, header, agent, environment
            )
            results = _play_experiment(built_agent, built_environment, config)
            if results is not None:
                _, _, mean_left, mean_right, _ = results
                values_left.append(mean_left)
                values_right.append(mean_right)
        final_data.append((np.asarray(values_left), np.asarray(values_right)))

    (dep_left, dep_right), (nodep_left, nodep_right) = final_data
    # Centre each group's values on that group's overall mean before plotting.
    mean_dep = np.mean(np.concatenate([dep_left, dep_right]))
    mean_nodep = np.mean(np.concatenate([nodep_left, nodep_right]))
    # NOTE: Matplotlib 3.9 renamed boxplot's ``labels`` to ``tick_labels``;
    # ``labels`` is deprecated there.
    plt.boxplot(
        [
            dep_left - mean_dep,
            dep_right - mean_dep,
            nodep_left - mean_nodep,
            nodep_right - mean_nodep,
        ],
        patch_artist=True,
        labels=['left', 'right', 'left', 'right'],
    )
    np.save('dep_left.npy', dep_left)
    np.save('dep_right.npy', dep_right)
    np.save('nodep_left.npy', nodep_left)
    np.save('nodep_right.npy', nodep_right)
    plt.show()


if __name__ == "__main__":
    # Declare the command-line options as (flags, keyword options) pairs and
    # forward the parsed values to play() by keyword.
    cli = argparse.ArgumentParser()
    option_specs = (
        (("--path",), {}),
        (("--checkpoint",), {"default": "last"}),
        (("--seed",), {"type": int, "default": 0}),
        (("--header",), {}),
        (("--agent",), {}),
        (("--environment", "--env"), {}),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    parsed = cli.parse_args()
    play(**vars(parsed))
