import sys, os, time
sys.path.append('./')
import numpy as np
import torch
import gym
from ruamel.yaml import YAML

from main.models.reward import MLPReward
from common.sac_reward_only import ReplayBuffer, SAC

import envs
from utils import system, collect, logger, eval
from utils.plots.train_plot_high_dim import plot_disc
from utils.plots.train_plot import plot_disc as visual_disc

import datetime
import dateutil.tz
import json, copy

def ML_loss(div: str, env, sac_agent, agent_samples, expert_samples, qpos, qvel, reward_func, device, time_step):
    ''' Surrogate max-ent IRL loss where a partial expert demo is completed by a policy rollout.

    NOTE: only for ML ('maxentirl'). With T the agent horizon, the returned value is

        T * ( mean_agent r(s) - (sum_{expert prefix} r(s) + sum_{rollout} r(s)) / T )

    so minimising it w.r.t. the reward parameters raises reward on the expert
    prefix and on the policy rollout, and lowers it on agent states.

    Args:
        div: objective name; must be 'maxentirl'.
        env: environment exposing reset(), set_state(qpos, qvel), state_vector()
            and a 5-tuple step(); it is reset and then moved to (qpos, qvel).
        sac_agent: policy whose get_action() drives the completion rollout.
        agent_samples: tuple whose first element is a numpy array of shape (N, T, d).
        expert_samples: observed expert states, reshapeable to (-1, d).
        qpos, qvel: simulator state the rollout starts from.
        reward_func: reward model; `.r(batch)` scores batched states and the
            model is called directly on each single rollout state.
        device: torch device the reward model lives on.
        time_step: the rollout runs T - time_step - 1 steps. test_experiment
            passes time_step + 1 expert states, so prefix + rollout span T states.

    Returns:
        The loss tensor (differentiable w.r.t. reward_func's parameters).
    '''
    assert div in ['maxentirl']
    sA, _, _ = agent_samples
    _, T, d = sA.shape

    sA_vec = torch.FloatTensor(sA).reshape(-1, d).to(device)
    sE_vec = torch.FloatTensor(expert_samples).reshape(-1, d).to(device)

    # Restart the simulator from the given expert state.
    env.reset()
    env.set_state(qpos, qvel)
    # Drops the first state_vector entry to match the policy input
    # (presumably the root x-position excluded from observations — TODO confirm per env).
    state = env.state_vector()[1:]
    reward_fill = 0.0
    for _ in range(T - time_step - 1):
        action = sac_agent.get_action(state)
        state, _, _, _, _ = env.step(action)
        # The reward model lives on `device`; a CPU tensor would fail for CUDA runs.
        reward_fill = reward_fill + reward_func(torch.FloatTensor(state).to(device))

    t1 = reward_func.r(sA_vec).view(-1)  # E_q[r(tau)]: agent states
    t2 = reward_func.r(sE_vec).view(-1)  # expert prefix states
    surrogate_objective = t1.mean() - (t2.sum() + reward_fill) / T
    return T * surrogate_objective  # same scale

def ML_sa_loss(div: str, agent_samples, expert_samples, reward_func, device):
    ''' Surrogate max-ent IRL loss over state-action pairs.

    NOTE: only for ML_sa. Returns T * (mean_agent r(s, a) - mean_expert r(s, a)),
    so minimising it w.r.t. the reward parameters raises reward on expert pairs
    and lowers it on agent pairs.

    Args:
        div: objective name; currently unused (kept for signature parity with ML_loss).
        agent_samples: tuple (states, actions, ...) where states is (N, T, ds) and
            actions is (N, T, da); they are concatenated along the last axis.
        expert_samples: expert state-action features, reshapeable to (-1, ds + da).
        reward_func: reward model exposing `.r(batch)`.
        device: torch device the reward model lives on.

    Returns:
        Scalar loss tensor (differentiable w.r.t. reward_func's parameters).
    '''
    states, actions, _ = agent_samples
    agent_sa = np.concatenate([states, actions], 2)
    _, T, d = agent_sa.shape

    agent_vec = torch.FloatTensor(agent_sa).reshape(-1, d).to(device)
    expert_vec = torch.FloatTensor(expert_samples).reshape(-1, d).to(device)

    agent_reward = reward_func.r(agent_vec).view(-1)  # E_q[r(s, a)]
    expert_reward = reward_func.r(expert_vec).view(-1)  # E_p[r(s, a)]

    surrogate_objective = agent_reward.mean() - expert_reward.mean()
    return T * surrogate_objective  # same scale



def try_evaluate(expert_samples, samples, env, sac_agent, goal, itr: int, policy_type: str, sac_info, config=None):
    ''' Evaluate the current policy and log KL / real-return metrics.

    Args:
        expert_samples: flattened expert states passed to eval.KL_summary.
        samples: collected agent trajectories; samples[0] is a 3-D array whose
            first axis must equal config['irl']['training_trajs'].
        env: environment used for the real-return evaluations.
        sac_agent: agent whose get_action is evaluated.
        goal: task goal forwarded to eval.evaluate_real_return.
        itr: current IRL iteration (used for the update/env-step counters).
        policy_type: logging prefix; only "Running" is supported.
        sac_info: unused; kept for call compatibility.
        config: experiment config dict. Defaults to the module-level `v`
            (set only when this file runs as a script).

    Returns:
        (real_return_det, real_return_sto): average real returns with the
        deterministic and the stochastic policy respectively.
    '''
    assert policy_type in ["Running"]
    cfg = v if config is None else config

    env_name = cfg['env']['env_name']
    horizon = cfg['env']['T']
    n_eval_episodes = cfg['irl']['eval_episodes']

    update_time = itr * cfg['reward']['gradient_step']
    env_steps = itr * cfg['sac']['epochs'] * horizon

    agent_emp_states = samples[0].copy()
    assert agent_emp_states.shape[0] == cfg['irl']['training_trajs']

    eval.KL_summary(expert_samples, agent_emp_states.reshape(-1, agent_emp_states.shape[2]), env_steps, policy_type)

    # eval real reward: the final flag selects the deterministic (True) or stochastic (False) policy
    real_return_det = eval.evaluate_real_return(env_name, goal, sac_agent.get_action, env,
                                                n_eval_episodes, horizon, True)
    print(f"real det return avg: {real_return_det:.2f}")
    logger.record_tabular("Real Det Return", round(real_return_det, 2))

    real_return_sto = eval.evaluate_real_return(env_name, goal, sac_agent.get_action, env,
                                                n_eval_episodes, horizon, False)
    print(f"real sto return avg: {real_return_sto:.2f}")
    logger.record_tabular("Real Sto Return", round(real_return_sto, 2))

    logger.record_tabular(f"{policy_type} Update Time", update_time)
    logger.record_tabular(f"{policy_type} Env Steps", env_steps)

    return real_return_det, real_return_sto

def _load_expert_split(kind, env_name, task_number, indices, num_trajs):
    '''Load one saved expert array ('states', 'qpos' or 'qvel') for a task.

    Returns (trajs, samples): `trajs` is the saved 3-D array restricted to the
    first `num_trajs` episodes and to `indices` along the last axis; `samples`
    is a copy of it flattened to (-1, len(indices)).
    '''
    path = f'expert_data/{kind}/{env_name}/Task{task_number}_eval_set.pt'
    # map_location='cpu' so tensors saved from a GPU run can still be turned into numpy arrays
    trajs = torch.load(path, map_location='cpu').numpy()[:, :, indices]
    trajs = trajs[:num_trajs, :, :]  # select first expert_episodes
    samples = trajs.copy().reshape(-1, len(indices))
    return trajs, samples


def test_experiment(task_number):
    ''' Run MERIT-IRL on one task, reading the YAML config given as sys.argv[1].

    Each IRL iteration trains SAC on the current learned reward, collects agent
    trajectories, takes `reward.gradient_step` gradient steps on ML_loss, then
    evaluates real returns. Only the 'maxentirl' objective is runnable: the
    'maxentirl_sa' loss needs expert state-action samples that this function
    never loads, so it is rejected up front.

    Returns:
        (return_det_set, return_sto_set): one deterministic / stochastic real
        return per IRL iteration.
    '''
    return_det_set = []
    return_sto_set = []
    yaml = YAML()
    with open(sys.argv[1]) as config_file:
        v = yaml.load(config_file)

    # common parameters
    env_name = v['env']['env_name']
    state_indices = v['env']['state_indices']
    seed = v['seed']
    num_expert_trajs = v['irl']['expert_episodes']

    # system: device, threads, seed, pid
    device = torch.device(f"cuda:{v['cuda']}" if torch.cuda.is_available() and v['cuda'] >= 0 else "cpu")
    torch.set_num_threads(1)
    np.set_printoptions(precision=3, suppress=True)
    system.reproduce(seed)
    pid = os.getpid()

    # assumptions
    assert v['obj'] in ['maxentirl', 'maxentirl_sa']
    assert v['IS'] == False
    if v['obj'] == 'maxentirl_sa':
        # Fail before any logging/training: this loop would otherwise reach an
        # undefined expert state-action array after a full SAC training round.
        raise NotImplementedError(
            "obj 'maxentirl_sa' requires expert state-action samples, which test_experiment does not load")

    # logs
    exp_id = f"logs/{env_name}/exp-{num_expert_trajs}/{v['obj']}"  # task/obj/date structure
    os.makedirs(exp_id, exist_ok=True)

    now = datetime.datetime.now(dateutil.tz.tzlocal())
    log_folder = exp_id + '/' + now.strftime('%Y_%m_%d_%H_%M_%S')
    logger.configure(dir=log_folder)
    print(f"Logging to directory: {log_folder}")
    os.system(f'cp main/MERIT-IRL.py {log_folder}')
    os.system(f'cp {sys.argv[1]} {log_folder}/variant_{pid}.yml')
    with open(os.path.join(logger.get_dir(), 'variant.json'), 'w') as f:
        json.dump(v, f, indent=2, sort_keys=True)
    print('pid', pid)
    os.makedirs(os.path.join(log_folder, 'plt'), exist_ok=True)
    os.makedirs(os.path.join(log_folder, 'model'), exist_ok=True)

    # environment
    env_kwargs = {'ctrl_cost_weight': 0.0}
    if env_name in ['Walker2d-v4', 'Hopper-v4']:
        env_kwargs['terminate_when_unhealthy'] = False
    env_fn = lambda: gym.make(env_name, **env_kwargs)
    gym_env = env_fn()
    state_size = gym_env.observation_space.shape[0]
    action_size = gym_env.action_space.shape[0]
    if state_indices == 'all':
        state_indices = list(range(state_size))
    qpos_indices = list(range(len(gym_env.data.qpos)))
    qvel_indices = list(range(len(gym_env.data.qvel)))

    # load expert samples from trained policy
    _, expert_samples = _load_expert_split('states', env_name, task_number, state_indices, num_expert_trajs)
    _, expert_qpos_samples = _load_expert_split('qpos', env_name, task_number, qpos_indices, num_expert_trajs)
    _, expert_qvel_samples = _load_expert_split('qvel', env_name, task_number, qvel_indices, num_expert_trajs)

    # Reward network initialised from the per-environment meta-learned prior.
    reward_func = MLPReward(len(state_indices), **v['reward'], device=device).to(device)
    reward_func.load_state_dict(torch.load(f'{env_name}_meta_prior.pkl', map_location=device))
    reward_optimizer = torch.optim.Adam(reward_func.parameters(), lr=v['reward']['lr'],
        weight_decay=v['reward']['weight_decay'], betas=(v['reward']['momentum'], 0.999))

    max_real_return_det, max_real_return_sto = -np.inf, -np.inf
    goal_CR = np.loadtxt('expert_data/states/' + str(env_name) + '/Task' + str(task_number) + '_goal.txt')
    goal = goal_CR[0]
    for itr in range(v['irl']['n_itrs']):
        if v['sac']['reinitialize'] or itr == 0:
            # Reset SAC agent with old policy, new environment, and new replay buffer
            print("Reinitializing sac")
            replay_buffer = ReplayBuffer(
                state_size,
                action_size,
                device=device,
                size=v['sac']['buffer_size'])

            sac_agent = SAC(env_fn, replay_buffer,
                steps_per_epoch=v['env']['T'],
                update_after=v['env']['T'] * v['sac']['random_explore_episodes'],
                max_ep_len=v['env']['T'],
                seed=seed,
                start_steps=v['env']['T'] * v['sac']['random_explore_episodes'],
                reward_state_indices=state_indices,
                device=device,
                sa=False,
                **v['sac']
            )

        sac_agent.reward_function = reward_func.get_scalar_reward  # only need to change reward in sac
        sac_info = sac_agent.learn_mujoco(print_out=False)

        start = time.time()
        samples = collect.collect_trajectories_policy_single(gym_env, sac_agent,
                        n=v['irl']['training_trajs'], state_indices=state_indices)
        print(f'collect trajs {time.time() - start:.0f}s', flush=True)

        # optimization w.r.t. reward: expert prefix up to step itr, then a
        # policy rollout from the expert's (qpos, qvel) at step itr
        losses = []
        for _ in range(v['reward']['gradient_step']):
            expert_prefix = expert_samples[0:itr + 1]
            qpos = expert_qpos_samples[itr]
            qvel = expert_qvel_samples[itr]
            loss = ML_loss(v['obj'], gym_env, sac_agent, samples, expert_prefix, qpos, qvel, reward_func, device, itr)

            losses.append(loss.item())
            print(f"{v['obj']} loss: {loss}")
            reward_optimizer.zero_grad()
            loss.backward()
            reward_optimizer.step()

        # evaluating the learned reward on a fresh env that is always released
        eval_env = env_fn()
        try:
            real_return_det, real_return_sto = try_evaluate(expert_samples, samples, eval_env, sac_agent, goal, itr, "Running", sac_info)
        finally:
            eval_env.close()
        if real_return_det > max_real_return_det and real_return_sto > max_real_return_sto:
            max_real_return_det, max_real_return_sto = real_return_det, real_return_sto
            torch.save(reward_func.state_dict(), os.path.join(logger.get_dir(),
                    f"model/reward_model_itr{itr}_det{max_real_return_det:.0f}_sto{max_real_return_sto:.0f}.pkl"))

        logger.record_tabular("Itration", itr)
        # last gradient-step loss; NaN when gradient_step == 0 (no loss was computed)
        logger.record_tabular("Loss", losses[-1] if losses else float('nan'))
        if v['sac']['automatic_alpha_tuning']:
            logger.record_tabular("alpha", sac_agent.alpha.item())

        if v['irl']['save_interval'] > 0 and (itr % v['irl']['save_interval'] == 0 or itr == v['irl']['n_itrs'] - 1):
            torch.save(reward_func.state_dict(), os.path.join(logger.get_dir(), f"model/reward_model_{itr}.pkl"))

        logger.dump_tabular()
        return_det_set.append(real_return_det)
        return_sto_set.append(real_return_sto)

    gym_env.close()
    return return_det_set, return_sto_set

if __name__ == "__main__":
    yaml = YAML()
    # `v` must stay a module-level global: try_evaluate reads it by default.
    with open(sys.argv[1]) as config_file:
        v = yaml.load(config_file)

    return_det_sets = []
    return_sto_sets = []
    for task_number in range(50, 60):
        return_det_set, return_sto_set = test_experiment(task_number)
        return_det_sets.append(return_det_set)
        return_sto_sets.append(return_sto_set)

    # rows = tasks, columns = IRL iterations; aggregate across tasks per iteration
    det_returns = np.array(return_det_sets)
    sto_returns = np.array(return_sto_sets)
    return_det_mean = det_returns.mean(axis=0)
    return_det_std = det_returns.std(axis=0)
    return_sto_mean = sto_returns.mean(axis=0)
    return_sto_std = sto_returns.std(axis=0)

    env_name = v['env']['env_name']
    # Create the output directory up front; np.savetxt does not, and failing
    # here would discard every task's training run.
    results_dir = os.path.join("results", str(env_name))
    os.makedirs(results_dir, exist_ok=True)
    np.savetxt(os.path.join(results_dir, "MERIT_return_det_mean_file.txt"), return_det_mean, delimiter=',')
    np.savetxt(os.path.join(results_dir, "MERIT_return_det_std_file.txt"), return_det_std, delimiter=',')
    np.savetxt(os.path.join(results_dir, "MERIT_return_sto_mean_file.txt"), return_sto_mean, delimiter=',')
    np.savetxt(os.path.join(results_dir, "MERIT_return_sto_std_file.txt"), return_sto_std, delimiter=',')




