import yaml
import argparse
import joblib
import numpy as np
import os,sys,inspect
import pickle, random

# Put the parent of this script's directory at the front of sys.path
# (presumably so the project-local imports below resolve when the script is
# launched directly -- confirm against the repository layout).
currentdir = os.path.dirname(
    os.path.abspath(inspect.getfile(inspect.currentframe()))
)
parentdir = os.path.split(currentdir)[0]
sys.path[:0] = [parentdir]
print(sys.path)

from gym.spaces import Dict
from rlkit.envs import get_env

import rlkit.torch.pytorch_util as ptu
from rlkit.launchers.launcher_util import setup_logger, set_seed

from rlkit.torch.networks import FlattenMlp
from rlkit.torch.sac.policies import ReparamTanhMultivariateGaussianLfOPolicy
from rlkit.torch.sac.sac_lfo import SoftActorCritic
from rlkit.torch.irl.disc_models.simple_disc_models import MLPDisc
from rlkit.torch.irl.adv_irl_lfo_cycle import AdvIRL_LfO
from rlkit.data_management.env_replay_buffer import EnvReplayBuffer
from rlkit.envs.wrappers import ScaledEnv

import torch


def experiment(variant):
    """Build and run one adversarial-IRL learning-from-observation experiment.

    Reads the expert demonstrations listed in ``demos_listing.yaml``, samples
    ``variant['traj_num']`` trajectories into an expert replay buffer, builds
    the Q/V networks, policy, forward model and discriminator, wires them into
    ``SoftActorCritic`` and ``AdvIRL_LfO``, and starts training.

    Args:
        variant: Experiment specification dict. Keys read here are
            ``expert_name``, ``traj_num``, ``env_specs``,
            ``scale_env_with_demo_stats``, ``policy_net_size``,
            ``policy_num_hidden_layers``, ``disc_num_blocks``,
            ``disc_hid_dim``, ``disc_hid_act``, ``disc_use_bn``,
            ``disc_clamp_magnitude``, ``sac_params`` and ``adv_irl_params``.
            The last two are also forwarded as keyword arguments to
            ``SoftActorCritic`` and ``AdvIRL_LfO`` respectively.

    Returns:
        ``1`` once ``algorithm.train`` returns.

    Raises:
        NotImplementedError: If ``variant['scale_env_with_demo_stats']`` is
            truthy.
        AssertionError: If the observation/action spaces are not flat, or if
            neither ``state_only`` nor ``sas`` is enabled.
    """
    # All configuration comes from ``variant``; the function must not depend
    # on module globals so that it can be imported and called directly.
    adv_irl_params = variant['adv_irl_params']
    wrap_absorbing = adv_irl_params['wrap_absorbing']

    with open('demos_listing.yaml', 'r') as f:
        # safe_load: yaml.load() without an explicit Loader is rejected by
        # PyYAML >= 6 and can construct arbitrary objects on older versions.
        listings = yaml.safe_load(f)
    demos_path = listings[variant['expert_name']]['file_paths'][0]
    print("demos_path", demos_path)
    # NOTE(security): pickle.load can execute arbitrary code; only point the
    # listing at demonstration files from a trusted source.
    with open(demos_path, 'rb') as f:
        traj_list = pickle.load(f)
    traj_list = random.sample(traj_list, variant['traj_num'])

    env_specs = variant['env_specs']
    env = get_env(env_specs)
    env.seed(env_specs['eval_env_seed'])
    training_env = get_env(env_specs)
    training_env.seed(env_specs['training_env_seed'])

    print('\n\nEnv: {}'.format(env_specs['env_name']))
    print('kwargs: {}'.format(env_specs['env_kwargs']))
    print('Obs Space: {}'.format(env.observation_space))
    print('Act Space: {}\n\n'.format(env.action_space))

    if wrap_absorbing:
        print('\n\nUSING ABOSORBING STATES\n\n')

    expert_replay_buffer = EnvReplayBuffer(
        adv_irl_params['replay_buffer_size'],
        env,
        random_seed=np.random.randint(10000)
    )
    for traj in traj_list:
        expert_replay_buffer.add_path(traj, absorbing=wrap_absorbing, env=env)

    if variant['scale_env_with_demo_stats']:
        # An earlier version wrapped env and training_env in ScaledEnv using
        # obs/acts mean and std taken from a joblib-loaded buffer dict; that
        # dict is no longer loaded here, so scaling is unsupported.
        raise NotImplementedError

    obs_space = env.observation_space
    act_space = env.action_space
    assert not isinstance(obs_space, Dict)
    assert len(obs_space.shape) == 1
    assert len(act_space.shape) == 1

    obs_dim = obs_space.shape[0]
    action_dim = act_space.shape[0]

    # With 'qss' the Q-functions take two observation-sized inputs instead of
    # (observation, action).
    if adv_irl_params.get('qss'):
        print('QSS!')
        q_input_dim = obs_dim + obs_dim
    else:
        q_input_dim = obs_dim + action_dim

    # build the policy models
    net_size = variant['policy_net_size']
    num_hidden = variant['policy_num_hidden_layers']
    qf1 = FlattenMlp(
        hidden_sizes=num_hidden * [net_size],
        input_size=q_input_dim,
        output_size=1,
    )
    qf2 = FlattenMlp(
        hidden_sizes=num_hidden * [net_size],
        input_size=q_input_dim,
        output_size=1,
    )
    vf = FlattenMlp(
        hidden_sizes=num_hidden * [net_size],
        input_size=obs_dim,
        output_size=1,
    )
    policy = ReparamTanhMultivariateGaussianLfOPolicy(
        hidden_sizes=num_hidden * [net_size],
        obs_dim=obs_dim,
        action_dim=action_dim,
        state_diff=adv_irl_params.get('state_diff', False),
    )
    forward_model = FlattenMlp(
        hidden_sizes=num_hidden * [net_size],
        input_size=obs_dim + action_dim,
        output_size=obs_dim,
    )

    assert adv_irl_params['state_only'] or adv_irl_params.get('sas'), \
        "should be state only or sas"

    # The discriminator input uses one extra observation entry when absorbing
    # states are enabled; the networks above keep the raw obs_dim.
    disc_obs_dim = obs_dim + 1 if wrap_absorbing else obs_dim
    input_dim = 2 * disc_obs_dim
    if adv_irl_params.get('sas'):
        print('SAS!')
        input_dim = disc_obs_dim + action_dim + disc_obs_dim
    # 'sss' is checked last, so it takes precedence when both are set.
    if adv_irl_params.get('sss'):
        print('SSS!')
        input_dim = disc_obs_dim + disc_obs_dim + disc_obs_dim

    # build the discriminator model
    disc_model = MLPDisc(
        input_dim,
        num_layer_blocks=variant['disc_num_blocks'],
        hid_dim=variant['disc_hid_dim'],
        hid_act=variant['disc_hid_act'],
        use_bn=variant['disc_use_bn'],
        clamp_magnitude=variant['disc_clamp_magnitude']
    )

    # update_both is switched off only when 'union_sp' is present and truthy.
    update_both = not adv_irl_params.get('union_sp')

    # set up the algorithm
    trainer = SoftActorCritic(
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        vf=vf,
        update_both=update_both,
        **variant['sac_params']
    )
    algorithm = AdvIRL_LfO(
        env=env,
        training_env=training_env,
        exploration_policy=policy,
        discriminator=disc_model,
        forward_model=forward_model,
        policy_trainer=trainer,
        expert_replay_buffer=expert_replay_buffer,
        **adv_irl_params
    )

    if ptu.gpu_enabled():
        algorithm.to(ptu.device)
    algorithm.train(pred_obs=True)

    return 1


def build_exp_suffix(exp_specs):
    """Return the hyper-parameter suffix appended to the experiment name.

    The function does not modify ``exp_specs``.

    Args:
        exp_specs: Experiment specification dict. Reads ``exp_name``,
            ``adv_irl_params``, ``sac_params`` and, optionally,
            ``decay_ratio`` (treated as ``1.0`` when absent).

    Returns:
        The suffix string. When ``exp_name`` contains ``'sl'`` or
        ``'gailfo-dp'`` a fixed short suffix is returned and every feature
        flag is ignored. Otherwise a base suffix is chosen (``union_sp`` wins
        over ``union``, which wins over the default) and each enabled flag
        prepends its tag.

    Raises:
        KeyError: If a hyper-parameter needed by the selected format is
            missing from the specification.
    """
    adv_params = exp_specs['adv_irl_params']
    sac_params = exp_specs['sac_params']
    exp_name = exp_specs['exp_name']

    # NOTE: these are substring matches on the experiment name.
    if 'sl' in exp_name:
        return '--splr-{}--idlr-{}'.format(
            adv_params['state_predictor_lr'],
            adv_params['inverse_dynamic_lr'],
        )
    if 'gailfo-dp' in exp_name:
        return '--gp-{}--rs-{}'.format(
            adv_params['grad_pen_weight'],
            sac_params['reward_scale'],
        )

    if adv_params.get('union_sp'):
        suffix = '-sp--gp-{}--spalpha-{}--idlr-{}--rs-{}--inviter-{}'.format(
            adv_params['grad_pen_weight'],
            adv_params['state_predictor_alpha'],
            adv_params['inverse_dynamic_lr'],
            sac_params['reward_scale'],
            adv_params['num_inverse_dynamic_updates_per_loop_iter'],
        )
    elif adv_params.get('union'):
        suffix = '--gp-{}--spalpha-{}--idbeta-{}--rs-{}'.format(
            adv_params['grad_pen_weight'],
            adv_params['state_predictor_alpha'],
            adv_params['inverse_dynamic_beta'],
            sac_params['reward_scale'],
        )
    else:
        suffix = '--gp-{}--splr-{}--idlr-{}--rs-{}--decay-{}'.format(
            adv_params['grad_pen_weight'],
            adv_params['state_predictor_lr'],
            adv_params['inverse_dynamic_lr'],
            sac_params['reward_scale'],
            exp_specs.get('decay_ratio', 1.0),
        )

    # Each enabled flag is prepended, so later flags end up further left.
    for flag, tag in (('sas', '--sas'), ('sss', '--sss'), ('qss', '--qss')):
        if adv_params.get(flag):
            suffix = tag + suffix
    if adv_params.get('multi_step'):
        suffix = '--ms-{}'.format(adv_params['step_num']) + suffix
    for flag, tag in (('cycle', '--cycle'), ('inv_buffer', '--biginvbuffer')):
        if adv_params.get(flag):
            suffix = tag + suffix
    # 'wrap_absorbing' is a required key, unlike the optional flags above.
    if adv_params['wrap_absorbing']:
        suffix = '--absorbing' + suffix
    if adv_params.get('state_diff'):
        suffix = '--state_diff' + suffix
    return suffix


if __name__ == '__main__':
    # Arguments
    parser = argparse.ArgumentParser()
    # required=True gives a clear usage error instead of open(None) failing.
    parser.add_argument('-e', '--experiment', required=True,
                        help='experiment specification file')
    parser.add_argument('-g', '--gpu', help='gpu id', type=int, default=0)
    args = parser.parse_args()
    with open(args.experiment, 'r') as spec_file:
        # safe_load: yaml.load() without an explicit Loader is rejected by
        # PyYAML >= 6 and can construct arbitrary objects on older versions.
        exp_specs = yaml.safe_load(spec_file)

    # make all seeds the same.
    exp_specs['env_specs']['eval_env_seed'] = exp_specs['env_specs']['training_env_seed'] = exp_specs['seed']

    if (exp_specs['num_gpu_per_worker'] > 0) and torch.cuda.is_available():
        print('\n\nUSING GPU\n\n')
        ptu.set_gpu_mode(True, args.gpu)

    # Write the defaults back into the spec: it is logged as the variant and
    # 'adv_irl_params' is forwarded to the algorithm by experiment().
    exp_specs.setdefault('decay_ratio', 1.0)
    exp_specs['adv_irl_params'].setdefault('state_diff', False)

    exp_id = exp_specs['exp_id']
    exp_prefix = exp_specs['exp_name'] + build_exp_suffix(exp_specs)
    seed = exp_specs['seed']
    set_seed(seed)
    setup_logger(exp_prefix=exp_prefix, exp_id=exp_id, variant=exp_specs, seed=seed)

    experiment(exp_specs)
