import pickle
#from helper.ant_ppo.rl_code.PPO import PPO
from PPO import PPO
import numpy as np
import os
from gym.envs.mujoco.ant_v3 import AntEnv
from tqdm import tqdm
import torch
from ant_utils import change_xml
import ipdb


def get_traj_save_path(data_dir, mode, T_support, num_trajectories):
    """Return the pickle path for one generated trajectory dataset.

    The file sits in the ``ant_ppo_8_leg`` sub-directory of ``data_dir`` and is
    named ``<mode>_<T_support>_<num_trajectories>.pkl``.
    """
    sub_dir = os.path.join(data_dir, 'ant_ppo_8_leg')
    file_name = f'{mode}_{T_support}_{num_trajectories}.pkl'
    return os.path.join(sub_dir, file_name)


def to_numpy(arr):
    """Convert a torch tensor to a NumPy array on the host.

    The tensor is first detached from the autograd graph, so this also works on
    tensors that require grad, and it is then moved to CPU (a no-op for CPU
    tensors) before conversion.

    Args:
        arr: a ``torch.Tensor`` on any device.

    Returns:
        An ``np.ndarray`` with the same values and dtype as ``arr``.
    """
    # `.detach()` is preferred over the legacy `.data` attribute; detaching
    # before `.cpu()` also keeps the device copy out of the autograd graph.
    return arr.detach().cpu().numpy()


def _rollout_fixed_actions(env, actions, state_dim=27):
    """Reset ``env`` and replay a fixed open-loop action sequence.

    Args:
        env: environment to roll out in; it is reset at the start.
        actions: sequence of action vectors, applied in order.
        state_dim: number of leading observation entries kept as the state.

    Returns:
        ``(x, y)`` NumPy arrays where row ``t`` of ``x`` is
        ``concat(state_t, action_t)`` and row ``t`` of ``y`` is ``state_{t+1}``.
    """
    xs, ys = [], []
    old_state = env.reset()[:state_dim]
    for action in actions:
        new_state, _, _, _ = env.step(action)
        new_state = new_state[:state_dim]
        xs.append(np.concatenate([old_state, action]))
        ys.append(new_state)
        old_state = new_state
    return np.array(xs), np.array(ys)


def _rollout_policy(env, agent, task_params, T_policy, device, state_dim=27):
    """Reset ``env`` and roll out ``agent`` deterministically for ``T_policy`` steps.

    The agent input is ``concat(state, task_params)``. The recorded ``x`` rows
    hold only ``concat(state, action)``, and the ``y`` rows hold the next state.

    Returns:
        ``(x, y)`` NumPy arrays with ``T_policy`` rows each.
    """
    xs, ys = [], []
    old_state = env.reset()[:state_dim]
    for _ in range(T_policy):
        st = np.expand_dims(np.concatenate((old_state, task_params)), 0)
        st = torch.from_numpy(st).float().to(device)
        at, _, _ = agent.get_action(st, test=True)
        # Convert once; the same action is both executed and recorded.
        action = to_numpy(at)[0]
        # NOTE(review): `done` from env.step is ignored, so stepping continues
        # even if the env signals termination — confirm this is intended.
        new_state, _, _, _ = env.step(action)
        new_state = new_state[:state_dim]
        xs.append(np.concatenate([old_state, action]))
        ys.append(new_state)
        old_state = new_state
    return np.array(xs), np.array(ys)


def generate_traj_dataset(num_trajectories, test, seed):
    """Generate and pickle 8-leg Ant dynamics datasets over random morphologies.

    For each of ``num_trajectories`` tasks, 8 task parameters are sampled
    uniformly in ``[0.2, 0.6)`` and written into the Ant XML via ``change_xml``.
    Three rollouts are then collected from a fresh ``AntEnv``:

    * support: ``T_support`` steps of a fixed random open-loop action sequence,
    * query: ``T_query`` steps of a second fixed random action sequence,
    * policy: ``T_policy`` steps of the pre-trained PPO agent, conditioned on
      the state and the task parameters.

    The two fixed action sequences are drawn once per call and shared by all
    tasks.

    Args:
        num_trajectories: number of tasks (morphologies) to sample.
        test: if True the output files are tagged ``test``, else ``train``.
        seed: seed for the task parameters and the fixed action sequences.

    Side effects:
        Writes ``fixed_<split>_<T_support>_<num_trajectories>.pkl``, a list of
        ``[support_x, support_y, query_x, query_y, task_params]``. Also writes
        ``ppo_<split>_<T_support>_<num_trajectories>.pkl``, a list of
        ``[policy_x, policy_y]``. Both go under ``<data_dir>/ant_ppo_8_leg/``.
    """
    agent_path = '/media/gustaf/039f6885-460f-4da2-92d0-1828ceba36e2/function-shift/helper/ant_ppo/rl_code_8_legs/PPO_model_8.mdl'
    state_dim = 27
    n_task_params = 8
    z_dim = state_dim + n_task_params  # policy input: state + task parameters
    a_dim = 8
    a_max = 1
    device = 'cuda'

    agent = PPO(z_dim, a_dim, a_max, device)
    agent.load_model(agent_path)

    data_dir = '/media/gustaf/039f6885-460f-4da2-92d0-1828ceba36e2/function-shift/data'
    ant_xml_dir = '/media/gustaf/039f6885-460f-4da2-92d0-1828ceba36e2/function-shift/helper/ant_ppo/rl_code_8_legs/ant_xml'
    xml_file = ant_xml_dir + '/ant_tmp_replace.xml'

    T_policy = 100
    T_support = 30
    T_query = 5

    # Draw the fixed action sequences from the seeded generator (not the global
    # np.random) so that `seed` fully determines them.
    rng = np.random.default_rng(seed)
    action_support = rng.uniform(-1, 1, size=(T_support, a_dim))
    action_query = rng.uniform(-1, 1, size=(T_query, a_dim))

    # exist_ok avoids a check-then-create race; makedirs also creates data_dir.
    os.makedirs(os.path.join(data_dir, 'ant_ppo_8_leg'), exist_ok=True)

    postfix = 'test' if test else 'train'
    save_path_policy = get_traj_save_path(data_dir, f'ppo_{postfix}', T_support, num_trajectories)
    save_path_fixed = get_traj_save_path(data_dir, f'fixed_{postfix}', T_support, num_trajectories)
    print("gonna save at", save_path_policy)

    p_min = 0.2
    p_max = 0.6

    policy_tasks = []
    tasks = []
    for _ in tqdm(range(num_trajectories)):
        task_params = rng.uniform(size=n_task_params) * (p_max - p_min) + p_min
        change_xml(xml_file, ant_xml_dir, task_params)
        env = AntEnv(xml_file)
        # NOTE(review): env.reset() noise presumably comes from the env's own
        # RNG, which `seed` does not control — confirm if full determinism is needed.
        try:
            support_x, support_y = _rollout_fixed_actions(env, action_support, state_dim)
            query_x, query_y = _rollout_fixed_actions(env, action_query, state_dim)
            query_policy_x, query_policy_y = _rollout_policy(
                env, agent, task_params, T_policy, device, state_dim)
        finally:
            # One env is built per task; release it before building the next.
            env.close()

        policy_tasks.append([query_policy_x, query_policy_y])
        tasks.append([support_x, support_y, query_x, query_y, task_params])

    with open(save_path_fixed, 'wb') as f:
        pickle.dump(tasks, f)

    with open(save_path_policy, 'wb') as f:
        pickle.dump(policy_tasks, f)
        


if __name__ == '__main__':
    # Build the training split first, then the test split, each with its own
    # seed: (num_trajectories, test, seed).
    splits = (
        (5000, False, 0),
        (1000, True, 1),
    )
    for n_traj, is_test, split_seed in splits:
        generate_traj_dataset(n_traj, test=is_test, seed=split_seed)