import os
import sys
import click
import json
import torch
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from tqdm import tqdm
import re
# Make the repository root (the parent of this script's directory) importable,
# so the project packages imported below resolve when the script is run directly.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, os.pardir))
sys.path.append(project_root)

from rlkit.envs import ENVS
from rlkit.envs.wrappers import NormalizedBoxEnv
from configs.default import default_config
from rlkit.torch.networks import stochastic_actor2


def get_sample(env, policy, buffer, max_path_length=200, noise_std=1.0, random_steps=29):
    """Roll out one episode with a perturbed expert and append its transitions to *buffer*.

    At every step the policy's deterministic action ``a`` for the current
    observation is queried. The action actually executed in the environment is:

    * a uniformly random ``env.action_space.sample()`` for the first
      ``random_steps`` steps of the episode, then
    * ``a`` plus Gaussian noise with std ``noise_std``, clipped to
      ``[env.action_space.low, env.action_space.high]``.

    Each stored transition records the clean policy action ``a`` under
    ``"action"`` (an expert label for ``"obs"``). ``"reward"``, ``"next_obs"``
    and ``"done"`` come from the executed, perturbed action.

    Args:
        env: Gym-style env exposing ``reset()``, ``step(action)`` returning
            ``(next_obs, reward, done, info)``, and ``action_space`` with
            ``low``/``high``/``sample()``.
        policy: Object with ``select_action(obs, deterministic=True)``
            returning a numpy array.
        buffer: List that transition dicts are appended to (mutated in place).
        max_path_length: Maximum number of steps in the episode.
        noise_std: Std of the Gaussian exploration noise. The default of 1.0
            presumably assumes actions normalized to [-1, 1] (TODO confirm).
        random_steps: Number of initial steps that execute random actions.
            NOTE(review): the default 29 reproduces the previous
            ``env_step < 30`` schedule; confirm whether 30 was intended.
    """
    with torch.no_grad():
        o = env.reset()
        for t in range(max_path_length):
            a = policy.select_action(o, deterministic=True)
            if t < random_steps:
                a_exec = env.action_space.sample()
            else:
                # Only draw noise when it is actually used.
                noise = np.random.normal(0.0, noise_std, size=a.shape)
                a_exec = np.clip(a + noise, env.action_space.low, env.action_space.high)
            next_o, r, d, _ = env.step(a_exec)
            # Record the transition *before* advancing o, so that "obs" is the
            # observation the action label `a` was computed from.
            buffer.append({
                "obs": o,
                "action": a,
                "reward": r,
                "next_obs": next_o,
                "done": d,
            })
            o = next_o
            if d:
                break

def _configure_test_task(env, test_env):
    """Switch *env* to the fixed evaluation task associated with *test_env*.

    Task names not listed here (and not containing 'params') leave *env* unchanged.
    """
    if test_env == 'cheetah-vel':
        env.set_velocity(2.0)  # target velocity +2
    elif test_env == 'cheetah-dir':
        env.set_direction(-1)  # backward direction
    elif test_env == 'ant-goal':
        env.set_goal_position(1.5 * np.pi, 3)  # goal at angle 1.5*pi, radius 3
    elif test_env == 'ant-dir':
        env.set_direction(1.5 * np.pi)  # direction angle 1.5*pi
    elif test_env == 'humanoid-dir':
        env.set_direction(0)  # direction angle 0
    elif 'params' in test_env:
        env.set_test_task()


def _latest_expert_checkpoint(directory):
    """Return the path of the highest-numbered ``expert_policy(N).pt`` file in *directory*.

    Raises:
        FileNotFoundError: if *directory* does not exist or holds no matching file.
    """
    pattern = re.compile(r'expert_policy\((\d+)\)\.pt')
    numbers = [int(m.group(1))
               for m in map(pattern.fullmatch, os.listdir(directory)) if m]
    if not numbers:
        raise FileNotFoundError(
            f"There is no file to read: no expert_policy(N).pt in {directory}")
    max_number = max(numbers)
    print(f'Loading expert_policy({max_number}).pt')
    return os.path.join(directory, f'expert_policy({max_number}).pt')


def experiment(test_env, variant, n_samples=100000):
    """Collect expert transitions on the test task of *test_env* and save them to parquet.

    The function performs these steps in order:

    1. Build the environment described by *variant* and switch it to a fixed test task.
    2. Load the highest-numbered ``./<test_env>/expert_policy(N).pt`` into a
       ``stochastic_actor2`` policy.
    3. Roll out episodes with ``get_sample`` until at least *n_samples*
       transitions are collected.
    4. Write the transitions to ``./<test_env>/expert_data.parquet``.

    Args:
        test_env: Task name. It selects the test task and the data directory.
        variant: Experiment config. The function reads ``env_name``,
            ``env_params``, ``sr_params['latent_action_dim']``, ``net_size`` and
            ``algo_params['max_path_length']``.
        n_samples: Minimum number of transitions to collect. The final episode
            is kept whole, so slightly more rows may be written.

    Raises:
        FileNotFoundError: if no expert checkpoint exists in ``./<test_env>``.
    """
    env = NormalizedBoxEnv(ENVS[variant['env_name']](**variant['env_params']))
    _configure_test_task(env, test_env)

    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    latent_action_dim = variant['sr_params']['latent_action_dim']
    net_size = variant['net_size']

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    policy = stochastic_actor2(obs_dim,
                               action_dim,
                               net_size,
                               latent_dim=latent_action_dim).to(device)

    data_dir = f'./{test_env}'
    checkpoint = _latest_expert_checkpoint(data_dir)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    policy.load_state_dict(torch.load(checkpoint, map_location=device))

    epi_length = variant['algo_params']['max_path_length']
    buffer = []
    with tqdm(total=n_samples, desc="Getting samples...") as pbar:
        while len(buffer) < n_samples:
            get_sample(env, policy, buffer, epi_length)
            # Advance by the number of transitions actually added this episode,
            # capped at the total so the bar never exceeds 100%.
            pbar.update(min(len(buffer), n_samples) - pbar.n)

    df = pd.DataFrame(buffer)
    pq.write_table(pa.Table.from_pandas(df),
                   os.path.join(data_dir, 'expert_data.parquet'))

def deep_update_dict(fr, to):
    """Recursively merge dict *fr* into dict *to* in place and return *to*.

    Nested dicts present in both are merged key by key. Any other value in
    *fr* replaces the corresponding entry in *to*, and so does a nested dict
    whose counterpart in *to* is missing or is not a dict.

    Args:
        fr: Overrides to apply, e.g. parsed from a JSON config.
        to: Base dict. It is updated in place.

    Returns:
        The same *to* object, after the merge.
    """
    for key, value in fr.items():
        if isinstance(value, dict) and isinstance(to.get(key), dict):
            deep_update_dict(value, to[key])
        else:
            # Covers scalars, and dicts with no dict to merge into: previously
            # this raised KeyError/TypeError.
            to[key] = value
    return to

@click.command()
@click.option('--env', required=True,
              help="Task name; selects ../configs/<env>.json and the ./<env>/ data directory.")
@click.option('--n_samples', default=100000, type=int,
              help="Minimum number of transitions to collect.")
def main(env, n_samples):
    """Collect expert demonstrations for ENV and write ./<env>/expert_data.parquet."""
    import copy  # local import: only needed to avoid mutating default_config

    # Resolved relative to the current working directory, as before.
    config = os.path.join('..', 'configs', f'{env}.json')
    with open(config, encoding='utf-8') as f:
        exp_params = json.load(f)
    # Merge into a copy so the shared module-level default_config stays pristine.
    variant = deep_update_dict(exp_params, copy.deepcopy(default_config))
    experiment(env, variant, n_samples)

if __name__ == "__main__":
    # Run the click command; click parses sys.argv into main's options.
    main()