'''Script used to play with trained agents.'''

import argparse
import os
import torch

import numpy as np
import yaml

import tonic  # noqa
from matplotlib import pyplot as plt

# Number of evaluation episodes run by ``play_control_suite``; also the number
# of tendon / muscle channels kept per recorded step there.
N = 25


def func(env):
    '''Returns a builder for the environment described by the string ``env``.

    ``env`` is a Python call expression, e.g. ``"SomeEnv(task='run')"``. The
    returned builder evaluates that call with an extra ``identifier=<id>``
    keyword argument appended. Strings containing ``'ostrich'`` are evaluated
    unchanged; their builder accepts and ignores ``identifier``.

    Security: ``env`` is passed to ``eval``; only use trusted strings.

    Raises:
        ValueError: if a non-ostrich ``env`` does not end with ``)``.
    '''
    print(env)
    if 'ostrich' in env:
        # Accept (and ignore) ``identifier`` so both builder kinds share the
        # same call signature.
        return lambda identifier=0: eval(env)

    call = env.rstrip()
    if not call.endswith(')'):
        raise ValueError(f'Environment must be a call expression, got {env!r}')
    head = call[:-1].rstrip()
    # Avoid producing invalid syntax such as ``Foo(, identifier=0)`` or
    # ``Foo(a=1,, identifier=0)``.
    if head.endswith('('):
        separator = ''
    elif head.endswith(','):
        separator = ' '
    else:
        separator = ', '

    def build_env(identifier=0):
        return eval(f'{head}{separator}identifier={identifier})')
    return build_env


def play_gym(agent, environment):
    '''Launches an agent in a Gym-based environment.

    Loops forever. The agent acts in a distributed wrapper around
    ``environment`` and every step is rendered. Episode statistics are
    printed each time worker 0 reports a reset.

    Args:
        agent: Agent exposing ``test_step`` and ``test_update``.
        environment: An already-built environment instance.
    '''
    # Bind the instance to its own name so the builder handed to
    # ``distribute`` cannot observe the rebinding of ``environment`` below.
    base_environment = environment
    environment = tonic.environments.distribute(
        lambda identifier=0: base_environment)

    observations, muscles_dep = environment.start()
    environment.render()
    environment.render_substep()

    # Per-episode statistics, reset after every episode.
    score = 0
    length = 0
    min_reward = float('inf')
    max_reward = -float('inf')
    # Session-wide statistics, never reset.
    global_min_reward = float('inf')
    global_max_reward = -float('inf')
    steps = 0
    episodes = 0
    while True:
        actions = agent.test_step(observations, steps, muscles_dep)
        observations, muscles_dep, infos = environment.step(actions)
        agent.test_update(**infos, steps=steps)
        environment.render()

        steps += 1
        # Only worker 0 is tracked.
        reward = infos['rewards'][0]
        score += reward
        min_reward = min(min_reward, reward)
        max_reward = max(max_reward, reward)
        global_min_reward = min(global_min_reward, reward)
        global_max_reward = max(global_max_reward, reward)
        length += 1

        if infos['resets'][0]:
            term = infos['terminations'][0]
            episodes += 1

            print()
            print(f'Episodes: {episodes:,}')
            print(f'Score: {score:,.3f}')
            print(f'Length: {length:,}')
            print(f'Terminal: {term:}')
            print(f'Min reward: {min_reward:,.3f}')
            print(f'Max reward: {max_reward:,.3f}')
            print(f'Global min reward: {global_min_reward:,.3f}')
            print(f'Global max reward: {global_max_reward:,.3f}')

            score = 0
            length = 0
            min_reward = float('inf')
            max_reward = -float('inf')


def play_control_suite(agent, environment):
    '''Evaluates an agent in a DeepMind Control Suite-based environment.

    Runs ``N`` episodes headlessly; the dm_control viewer is not launched.

    - After ``int(20 / 0.025)`` steps of each episode, it records foot x/z
      positions relative to ``qpos[0]`` / ``qpos[2]`` and binary ground
      contacts (foot ``z`` below 0.1).
    - Duty factors are printed.
    - All recordings are saved as ``.npy`` files in the current working
      directory.

    Args:
        agent: Agent exposing ``test_step`` and ``test_update``.
        environment: Tonic-wrapped environment exposing ``unwrapped``,
            ``muscles_dep`` and ``max_episode_steps``. Its physics must
            provide ``horizontal_velocity``, ``tendon_states`` and
            ``muscle_activations``.

    Returns:
        list: The maximum horizontal velocity reached in each episode.
    '''

    from dm_control import viewer  # noqa: F401  (only for viewer.launch)

    class Wrapper:
        '''Wrapper used to plug a Tonic environment in a dm_control viewer.

        Also tracks per-episode and global reward statistics, the maximum
        horizontal velocity, and per-step tendon / muscle / action traces.
        '''

        def __init__(self, environment):
            self.environment = environment
            self.unwrapped = environment.unwrapped
            self.action_spec = self.unwrapped.environment.action_spec
            self.physics = self.unwrapped.environment.physics
            self.infos = None
            self.done = False
            self.steps = 0
            self.episodes = 0
            self.min_reward = float('inf')
            self.max_reward = -float('inf')
            # Extremes over all episodes; not reset by ``reset``.
            self.global_min_reward = float('inf')
            self.global_max_reward = -float('inf')

        def reset(self):
            '''Mimics a dm_control reset for the viewer.'''
            # Leading axis of size 1: the agent is queried with batched
            # observations in ``policy`` below.
            self.observations = self.environment.reset()[None]
            self.muscles_dep = self.environment.muscles_dep

            self.score = 0
            self.length = 0
            self.min_reward = float('inf')
            self.max_reward = -float('inf')
            self.max_vel = -100
            self.lce = []
            self.m_act = []
            self.actions = []

            return self.unwrapped.last_time_step

        def step(self, actions):
            '''Mimics a dm_control step for the viewer.

            ``actions`` carries a leading batch axis; only ``actions[0]`` is
            sent to the environment.
            '''
            assert not np.isnan(actions.sum())
            ob, rew, term, _ = self.environment.step(actions[0])
            self.lce.append(self.physics.tendon_states()[:N].copy())
            self.m_act.append(self.physics.muscle_activations()[:N].copy())
            # NOTE(review): ``actions`` is batched, so ``[:N]`` slices the
            # batch axis, not the first N action channels -- confirm intent.
            self.actions.append(actions[:N].copy())
            self.score += rew
            self.length += 1
            self.min_reward = min(self.min_reward, rew)
            self.max_reward = max(self.max_reward, rew)
            self.global_min_reward = min(self.global_min_reward, rew)
            self.global_max_reward = max(self.global_max_reward, rew)
            velocity = self.physics.horizontal_velocity()
            if velocity > self.max_vel:
                self.max_vel = velocity
            timeout = self.length == self.environment.max_episode_steps
            done = term or timeout

            if done:
                self.episodes += 1
                print()
                print(f'Episodes: {self.episodes:,}')
                print(f'Score: {self.score:,.3f}')
                print(f'Length: {self.length:,}')
                print(f'Terminal: {term:}')
                print(f'Min reward: {self.min_reward:,.3f}')
                print(f'Max reward: {self.max_reward:,.3f}')
                print(f'Global min reward: {self.global_min_reward:,.3f}')
                print(f'Global max reward: {self.global_max_reward:,.3f}')
                print(f'Max velocity: {self.max_vel:,.3f}')

            self.observations = ob[None]
            self.muscles_dep = self.environment.muscles_dep
            self.infos = dict(
                observations=ob[None], rewards=np.array([rew]),
                resets=np.array([done]), terminations=np.array([term]))
            self.done = done
            return self.unwrapped.last_time_step

    # Wrap the environment for the viewer.
    environment = Wrapper(environment)

    def flatten(observation):
        '''Turns OrderedDict observations into vectors.'''
        observation = [np.array([o]) if np.isscalar(o) else o.ravel()
                       for o in observation.values()]
        return np.concatenate(observation, axis=0)

    def policy(timestep):
        '''Mimics a dm_control policy for the viewer.

        ``timestep`` is ignored; the agent reads the wrapper's latest
        observations. The previous transition is fed to ``test_update`` first.
        '''
        if environment.infos is not None:
            agent.test_update(**environment.infos, steps=environment.steps)
            environment.steps += 1
        return agent.test_step(
            environment.observations, environment.steps,
            environment.muscles_dep)

    # Recording starts after this many steps of each episode
    # (presumably 20 s at a 0.025 s control timestep -- TODO confirm).
    warmup_steps = int(20 / 0.025)
    # Foot height below which a foot counts as touching the ground.
    contact_height = 0.1
    # Label used in output file names; other runs used 'td4_fast' and
    # 'depmpo_fast'.
    descriptor = 'mpo_fast'

    n_episodes = N
    rets = []
    motion = []
    ground_left, ground_right = [], []
    foot_left, foot_right = [], []
    x_lefts, z_lefts = [], []
    x_rights, z_rights = [], []

    for _ in range(n_episodes):
        state = flatten(environment.reset().observation)
        ep_steps = 0
        while True:
            action = policy(state)
            timestep = environment.step(action)
            next_state = flatten(timestep.observation)
            if environment.done:
                distance = environment.physics.data.qpos[0]
                print(distance > 22)
                print(f'made it {distance} meters')
                rets.append(environment.max_vel)
                print(rets)
                break
            state = next_state.copy()
            motion.append(environment.physics.data.qpos[:3].copy())
            if ep_steps >= warmup_steps:
                xpos = environment.physics.named.data.xpos
                root_x = environment.physics.data.qpos[0]
                root_z = environment.physics.data.qpos[2]
                left_dx = xpos['l_pes', 'x'].copy() - root_x
                right_dx = xpos['r_pes', 'x'].copy() - root_x
                foot_left.append(left_dx)
                foot_right.append(right_dx)
                ground_right.append(
                    1 if xpos['r_pes', 'z'] < contact_height else 0)
                ground_left.append(
                    1 if xpos['l_pes', 'z'] < contact_height else 0)
                x_lefts.append(left_dx)
                z_lefts.append(xpos['l_pes', 'z'].copy() - root_z)
                x_rights.append(right_dx)
                z_rights.append(xpos['r_pes', 'z'].copy() - root_z)
            ep_steps += 1

    print('Duty factor right')
    print(np.mean(ground_right))
    print('Duty factor left')
    print(np.mean(ground_left))

    outputs = {
        f'ground_right_{descriptor}_ever.npy': ground_right,
        f'ground_left_{descriptor}_ever.npy': ground_left,
        f'foot_deviation_left_{descriptor}_ever.npy': foot_left,
        f'foot_deviation_right_{descriptor}_ever.npy': foot_right,
        f'foot_locus_left_x_{descriptor}_ever.npy': x_lefts,
        f'foot_locus_left_z_{descriptor}_ever.npy': z_lefts,
        f'foot_locus_right_x_{descriptor}_ever.npy': x_rights,
        f'foot_locus_right_z_{descriptor}_ever.npy': z_rights,
        f'speed_{descriptor}.npy': rets,
    }
    for filename, values in outputs.items():
        np.save(filename, values)

    return rets



def _find_checkpoint(path, checkpoint):
    '''Resolves ``checkpoint`` to a ``<path>/checkpoints/step_<id>`` path.

    Args:
        path: Experiment directory.
        checkpoint: ``'last'`` for the highest step id, or a step id
            (int or numeric string).

    Returns:
        The checkpoint path. Returns None, after logging an error, when the
        checkpoints directory, any checkpoint, or the requested id is missing.
    '''
    checkpoints_dir = os.path.join(path, 'checkpoints')
    # Bail out early: ``os.listdir(None)`` would silently list the current
    # working directory instead of failing.
    if not os.path.isdir(checkpoints_dir):
        tonic.logger.error(f'{checkpoints_dir} is not a directory')
        return None

    # Checkpoints are named ``step_<id>``, optionally with an extension.
    checkpoint_ids = [
        int(file.split('.')[0][5:])
        for file in os.listdir(checkpoints_dir)
        if file.startswith('step_')]
    if not checkpoint_ids:
        tonic.logger.error(f'No checkpoint found in {checkpoints_dir}')
        return None

    if checkpoint == 'last':
        checkpoint_id = max(checkpoint_ids)
    else:
        checkpoint_id = int(checkpoint)
        if checkpoint_id not in checkpoint_ids:
            tonic.logger.error(f'Checkpoint {checkpoint_id} '
                               f'not found in {checkpoints_dir}')
            return None
    return os.path.join(checkpoints_dir, f'step_{checkpoint_id}')


def play(path, checkpoint, seed, header, agent, environment):
    '''Reloads an agent and an environment from a previous experiment.

    Args:
        path: Experiment directory containing ``config.yaml`` and optionally
            a ``checkpoints`` sub-directory.
        checkpoint: ``'last'``, ``'none'`` or a step id (as a string).
        seed: Seed for the environment and the agent.
        header: Python code executed before building the agent.
        agent: Python expression building the agent.
        environment: Python expression building the environment.

    Raises:
        ValueError: if no agent expression is available.

    Security: ``header``, ``agent`` and ``environment`` are passed to
    ``exec``/``eval``, and the config is parsed with ``yaml.FullLoader``.
    Only use trusted experiment folders.
    '''
    # NOTE(review): passing an agent discards all three overrides, so
    # everything comes from config.yaml (and fails without --path).
    # Kept as-is; confirm this is intended.
    if agent is not None:
        agent = None
        header = None
        environment = None
    checkpoint_path = None

    if path:
        tonic.logger.log(f'Loading experiment from {path}')

        # Use no checkpoint, the agent is freshly created.
        if checkpoint == 'none' or agent is not None:
            tonic.logger.log('Not loading any weights')
        else:
            checkpoint_path = _find_checkpoint(path, checkpoint)

        # Load the experiment configuration.
        arguments_path = os.path.join(path, 'config.yaml')
        with open(arguments_path, 'r') as config_file:
            config = yaml.load(config_file, Loader=yaml.FullLoader)
        config = argparse.Namespace(**config)
        print(config)
        header = header or getattr(config, 'header', None)
        agent = agent or config.agent
        environment = (environment
                       or getattr(config, 'test_environment', None)
                       or config.environment)

    # Run the header and build the agent in one explicit namespace. Names
    # bound by ``exec(header)`` (e.g. imports) then stay visible to
    # ``eval(agent)`` without relying on function ``locals()``, whose
    # behavior under exec changed in Python 3.13 (PEP 667).
    namespace = dict(globals())
    if header:
        exec(header, namespace)

    # Build the agent.
    if not agent:
        raise ValueError('No agent specified.')
    agent = eval(agent, namespace)
    # Build the environment.
    environment = func(environment)()
    environment.seed(seed)

    # Initialize the agent.
    agent.initialize(
        observation_space=environment.observation_space,
        action_space=environment.action_space, seed=seed)

    # Load the weights of the agent from a checkpoint.
    if checkpoint_path:
        agent.load(checkpoint_path, play=True)

    # Play with the agent in the environment.
    rets = np.array(play_control_suite(agent, environment), dtype=np.float32)
    print(np.mean(rets))


if __name__ == '__main__':
    # Command-line entry point: each option maps one-to-one onto a keyword
    # argument of ``play``.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--path')
    arg_parser.add_argument('--checkpoint', default='last')
    arg_parser.add_argument('--seed', default=0, type=int)
    arg_parser.add_argument('--header')
    arg_parser.add_argument('--agent')
    arg_parser.add_argument('--environment', '--env')
    cli_options = arg_parser.parse_args()
    play(**vars(cli_options))
