'''
f-IRL + PAGAR: learn a reward (together with protagonist/antagonist SAC policies) from specified expert samples
'''
import sys, os, time
sys.path.append(
    os.path.dirname(os.path.dirname(__file__))
)


import numpy as np
import torch
import gym
from ruamel.yaml import YAML

from firl.divs.f_div_disc import f_div_disc_loss
from firl.divs.f_div import maxentirl_loss
from firl.divs.ipm import ipm_loss
from firl.models.reward import MLPReward
from firl.models.discrim import SMMIRLDisc as Disc
from firl.models.discrim import SMMIRLCritic as Critic
from common.sac import ReplayBuffer, SAC

import envs
from utils import system, collect, logger, eval
from utils.plots.train_plot_high_dim import plot_disc
from utils.plots.train_plot import plot_disc as visual_disc

import datetime
import dateutil.tz
import json, copy

def try_evaluate(itr: int, policy_type: str, sac_info):
    """Evaluate the current antagonist policy and record metrics for iteration ``itr``.

    Reads module-level state created in the ``__main__`` section: ``v``, ``expert_samples``,
    ``antagonist_samples``, ``antagonist_sac_agent``, ``env_fn``, ``env_name``, ``reward_func``,
    ``log_folder``, ``gym_env`` and, depending on ``v['obj']``, ``critic_loss`` or
    ``disc`` / ``disc_loss``.

    Args:
        itr: Current IRL iteration; used to derive the update-time and env-step counters.
        policy_type: Prefix for the logged metric names; only ``"Running"`` is accepted.
        sac_info: SAC training info, forwarded to the PointMaze visualisation.

    Returns:
        Tuple ``(real_return_det, real_return_sto)`` as reported by
        ``eval.evaluate_real_return`` for the deterministic and stochastic runs.
    """
    assert policy_type in ["Running"]
    update_time = itr * v['reward']['gradient_step']
    env_steps = itr * v['sac']['epochs'] * v['env']['T']
    agent_emp_states = antagonist_samples[0].copy()
    assert agent_emp_states.shape[0] == v['irl']['training_trajs']

    # Compare expert states with the antagonist's states flattened to (n*T, state_dim).
    metrics = eval.KL_summary(expert_samples, agent_emp_states.reshape(-1, agent_emp_states.shape[2]),
                              env_steps, policy_type)

    # Real environment return of the antagonist. As the metric names indicate, the boolean
    # argument selects the deterministic ("Det") or stochastic ("Sto") evaluation.
    real_returns = {}
    for deterministic, tag in ((True, "Det"), (False, "Sto")):
        real_return = eval.evaluate_real_return(antagonist_sac_agent.get_action, env_fn(),
                                                v['irl']['eval_episodes'], v['env']['T'], deterministic)
        metrics[f'Real {tag} Return'] = real_return
        print(f"real {tag.lower()} return avg: {real_return:.2f}")
        logger.record_tabular(f"Real {tag} Return", round(real_return, 2))
        real_returns[tag] = real_return

    if v['obj'] in ["emd"]:
        # Negated mean over the last 10% of critic losses. With fewer than 10 entries
        # eval_len is 0, and `[-0:]` then selects the whole list.
        eval_len = int(0.1 * len(critic_loss["main"]))
        emd = -np.array(critic_loss["main"][-eval_len:]).mean()
        metrics['emd'] = emd
        logger.record_tabular(f"{policy_type} EMD", emd)

    # `disc` / `disc_loss` only exist for discriminator-based objectives (obj is neither
    # 'emd' nor 'maxentirl'). Without this check, PointMaze runs with those objectives
    # crashed with NameError.
    if "PointMaze" in env_name and v['obj'] not in ("emd", "maxentirl"):
        visual_disc(agent_emp_states, reward_func.get_scalar_reward, disc.log_density_ratio, v['obj'],
                    log_folder, env_steps, gym_env.range_lim,
                    sac_info, disc_loss, metrics)

    logger.record_tabular(f"{policy_type} Update Time", update_time)
    logger.record_tabular(f"{policy_type} Env Steps", env_steps)

    return real_returns["Det"], real_returns["Sto"]

def _finite(t):
    """Return the finite entries of tensor ``t`` as a 1-D tensor (NaN/inf dropped)."""
    return t[torch.isfinite(t)]


def _finite_abs_max(t):
    """Return ``max(|t|)`` over the finite entries of ``t``; a zero scalar if none are finite."""
    finite = _finite(t)
    if finite.numel() == 0:
        return torch.zeros((), dtype=t.dtype, device=t.device)
    return finite.abs().max()


def get_pagar_loss(protagonist_samples, antagonist_samples, protagonist_agent, antagonist_agent, reward_func, device, clip_param = 0.2, gamma = 0.99):
    """Compute the PAGAR regulariser for the reward network from protagonist/antagonist rollouts.

    Args:
        protagonist_samples, antagonist_samples: ``(states, actions, log_probs)`` tuples with
            ``states`` of shape (N, T, state_dim) and ``actions`` of shape (N, T, action_dim).
            ``log_probs`` is flattened to one value per step, i.e. (N*T, 1). Numpy arrays
            or tensors are accepted.
        protagonist_agent, antagonist_agent: agents exposing ``.ac.log_prob(states, actions)``.
        reward_func: reward model exposing ``.r(states)``.
        device: torch device that the inputs are moved to.
        clip_param: clipping range of the PPO-style clipped surrogate terms (3 and 4).
        gamma: constant in the ``4 * gamma / (1 - gamma)`` coefficient of the divergence
            penalty in terms 1 and 2 (was hard-coded to 0.99).

    Returns:
        Scalar tensor ``pair_loss1 + pair_loss2 (+ pair_loss4) (+ pair_loss3)``. The last two
        terms are only added when they are finite.
    """
    protagonist_s, protagonist_a, protagonist_log_a = protagonist_samples
    antagonist_s, antagonist_a, antagonist_log_a = antagonist_samples

    protagonist_N, protagonist_T, protagonist_s_d = protagonist_s.shape
    antagonist_N, antagonist_T, antagonist_s_d = antagonist_s.shape
    _, _, protagonist_a_d = protagonist_a.shape
    _, _, antagonist_a_d = antagonist_a.shape

    assert protagonist_s_d == antagonist_s_d
    assert protagonist_a_d == antagonist_a_d

    protagonist_n = protagonist_N * protagonist_T
    antagonist_n = antagonist_N * antagonist_T

    def to_tensor(x):
        return torch.as_tensor(x, dtype=torch.float32, device=device)

    # Flatten (N, T, d) -> (N*T, d) and convert once; the same tensors feed the policy
    # log-likelihoods and the reward.
    protagonist_s_t = to_tensor(protagonist_s.reshape(-1, protagonist_s_d))
    antagonist_s_t = to_tensor(antagonist_s.reshape(-1, antagonist_s_d))
    protagonist_a_t = to_tensor(protagonist_a.reshape(-1, protagonist_a_d))
    # Actions must be reshaped with the *action* dimension (previously the state dim was used).
    antagonist_a_t = to_tensor(antagonist_a.reshape(-1, antagonist_a_d))

    protagonist_log_a_vec = to_tensor(protagonist_log_a).reshape(-1, 1)
    antagonist_log_a_vec = to_tensor(antagonist_log_a).reshape(-1, 1)

    # Cross log-likelihoods: each policy scores the *other* policy's state-action pairs.
    antagonist_protagonist_log_a_vec = protagonist_agent.ac.log_prob(antagonist_s_t, antagonist_a_t).view(antagonist_n, 1)
    protagonist_antagonist_log_a_vec = antagonist_agent.ac.log_prob(protagonist_s_t, protagonist_a_t).view(protagonist_n, 1)

    protagonist_r_vec = reward_func.r(protagonist_s_t).view(protagonist_n, 1)  # (N*T, 1)
    antagonist_r_vec = reward_func.r(antagonist_s_t).view(antagonist_n, 1)  # (N*T, 1)

    kl_coef = 4 * gamma / (1 - gamma)

    # Term 1 (protagonist rollouts): r1 = -log(p / r - p), with p the antagonist's likelihood
    # of the protagonist's actions, importance-weighted by the (detached) likelihood ratio.
    pair_r1 = -((protagonist_antagonist_log_a_vec.exp() / protagonist_r_vec - protagonist_antagonist_log_a_vec.exp()).log())
    pair_ratio1 = torch.exp(protagonist_antagonist_log_a_vec - protagonist_log_a_vec).detach()
    pair_loss1 = _finite(pair_r1 * pair_ratio1).mean()
    # Mean-squared difference of log-likelihoods, used as a (detached) divergence proxy.
    pair_kl1 = torch.nn.functional.mse_loss(protagonist_log_a_vec, protagonist_antagonist_log_a_vec).detach().item()
    # Max is taken over finite entries only, consistent with the other reductions here.
    pair_loss1 = pair_loss1 + pair_kl1 * kl_coef * _finite_abs_max(pair_r1)
    pair_loss1 = pair_loss1 - _finite(pair_r1).mean()

    # Term 2 (antagonist rollouts): r2 = +log(p / r - p), with p the antagonist's own likelihood.
    pair_r2 = (antagonist_log_a_vec.exp() / antagonist_r_vec - antagonist_log_a_vec.exp()).log()
    pair_ratio2 = torch.exp(antagonist_protagonist_log_a_vec - antagonist_log_a_vec).detach()
    pair_loss2 = _finite(pair_r2 * pair_ratio2).mean()
    pair_kl2 = torch.nn.functional.mse_loss(antagonist_log_a_vec, antagonist_protagonist_log_a_vec).detach().item()
    # NOTE(review): the penalty is added in term 1 but subtracted here; confirm the sign is intended.
    pair_loss2 = pair_loss2 - pair_kl2 * kl_coef * _finite_abs_max(pair_r2)
    pair_loss2 = pair_loss2 - _finite(pair_r2).mean()

    # Term 3: PPO-style clipped surrogate on the antagonist's r2.
    pair_ratio3 = torch.exp(pair_r2 - antagonist_log_a_vec.detach())
    pair_clipped_ratio3 = torch.clamp(pair_ratio3, 1 - clip_param, 1 + clip_param)
    pair_loss3 = -torch.min(pair_r2 * pair_ratio3, pair_r2 * pair_clipped_ratio3).mean()
    # NOTE(review): term 3 is built from r2 but subtracts the mean of r1; possibly a copy-paste slip.
    pair_loss3 = pair_loss3 - _finite(pair_r1).mean()

    # Term 4: PPO-style clipped surrogate on the protagonist's -r1.
    pair_ratio4 = torch.exp(-pair_r1 - protagonist_log_a_vec.detach())
    pair_clipped_ratio4 = torch.clamp(pair_ratio4, 1 - clip_param, 1 + clip_param)
    pair_loss4 = -torch.min(-pair_r1 * pair_ratio4, -pair_r1 * pair_clipped_ratio4).mean()
    pair_loss4 = pair_loss4 - _finite(pair_r1).mean()

    # Terms 4 and 3 are unfiltered means, so they are only added when finite.
    pair_loss = pair_loss1 + pair_loss2
    if torch.isfinite(pair_loss4).all():
        pair_loss = pair_loss + pair_loss4
    if torch.isfinite(pair_loss3).all():
        pair_loss = pair_loss + pair_loss3
    return pair_loss
    


if __name__ == "__main__":
    import shutil  # snapshot script/config without shelling out to `cp`

    yaml = YAML()
    with open(sys.argv[1]) as config_file:
        v = yaml.load(config_file)

    # common parameters
    env_name = v['env']['env_name']
    state_indices = v['env']['state_indices']
    seed = v['seed']
    num_expert_trajs = v['irl']['expert_episodes']

    # system: device, threads, seed, pid
    device = torch.device(f"cuda:{v['cuda']}" if torch.cuda.is_available() and v['cuda'] >= 0 else "cpu")
    torch.set_num_threads(1)
    np.set_printoptions(precision=3, suppress=True)
    system.reproduce(seed)
    pid = os.getpid()

    # assumptions
    assert v['obj'] in ['fkl', 'rkl', 'js', 'emd', 'maxentirl']
    assert v['IS'] == False

    # logs: logs/<env>/exp-<n_expert>/pagar_<obj>/<timestamp>
    exp_id = f"logs/{env_name}/exp-{num_expert_trajs}/pagar_{v['obj']}"
    os.makedirs(exp_id, exist_ok=True)

    now = datetime.datetime.now(dateutil.tz.tzlocal())
    log_folder = exp_id + '/' + now.strftime('%Y_%m_%d_%H_%M_%S')
    logger.configure(dir=log_folder)
    print(f"Logging to directory: {log_folder}")
    # Snapshot the training script and the config next to the logs. This is best effort,
    # as with the previous `cp` calls: a failed copy is reported but does not stop the run.
    for src, dst in (('firl/pagar_samples.py', log_folder),
                     (sys.argv[1], os.path.join(log_folder, f'variant_{pid}.yml'))):
        try:
            shutil.copy(src, dst)
        except OSError as err:
            print(f"warning: could not copy {src} -> {dst}: {err}")
    with open(os.path.join(logger.get_dir(), 'variant.json'), 'w') as f:
        json.dump(v, f, indent=2, sort_keys=True)
    print('pid', pid)
    os.makedirs(os.path.join(log_folder, 'plt'))
    os.makedirs(os.path.join(log_folder, 'model'))

    # environment
    env_fn = lambda: gym.make(env_name)
    gym_env = env_fn()
    state_size = gym_env.observation_space.shape[0]
    action_size = gym_env.action_space.shape[0]
    if state_indices == 'all':
        state_indices = list(range(state_size))

    # load expert samples from trained policy
    expert_trajs = torch.load(f'expert_data/states/{env_name}.pt').numpy()[:, :, state_indices]
    expert_trajs = expert_trajs[:num_expert_trajs, :, :]  # select first expert_episodes
    expert_samples = expert_trajs.copy().reshape(-1, len(state_indices))
    print(expert_trajs.shape, expert_samples.shape)  # ignored starting state

    # Initialize reward as a neural network
    reward_func = MLPReward(len(state_indices), **v['reward'], device=device).to(device)
    reward_optimizer = torch.optim.Adam(reward_func.parameters(), lr=v['reward']['lr'],
        weight_decay=v['reward']['weight_decay'], betas=(v['reward']['momentum'], 0.999))

    # Initialize critic (emd) or discriminator (fkl / rkl / js); maxentirl needs neither.
    if v['obj'] in ["emd"]:
        critic = Critic(len(state_indices), **v['critic'], device=device)
    elif v['obj'] != 'maxentirl':
        disc = Disc(len(state_indices), **v['disc'], device=device)

    def make_sac_agent(buffer):
        """Build a SAC agent on ``buffer`` with the hyper-parameters from the config."""
        return SAC(env_fn, buffer,
            steps_per_epoch=v['env']['T'],
            update_after=v['env']['T'] * v['sac']['random_explore_episodes'],
            max_ep_len=v['env']['T'],
            seed=seed,
            start_steps=v['env']['T'] * v['sac']['random_explore_episodes'],
            reward_state_indices=state_indices,
            device=device,
            **v['sac']
        )

    max_real_return_det, max_real_return_sto = -np.inf, -np.inf
    for itr in range(v['irl']['n_itrs']):

        if v['sac']['reinitialize'] or itr == 0:
            # Reset SAC agents with old policy, new environment, and new replay buffer
            print("Reinitializing sac")
            replay_buffer = ReplayBuffer(
                state_size,
                action_size,
                device=device,
                size=v['sac']['buffer_size'])
            # NOTE(review): both agents are built on the same replay buffer; confirm that is intended.
            protagonist_sac_agent = make_sac_agent(replay_buffer)
            antagonist_sac_agent = make_sac_agent(replay_buffer)

        # only need to change reward in sac
        protagonist_sac_agent.reward_function = antagonist_sac_agent.reward_function = reward_func.get_scalar_reward
        print("Protagonist:")
        protagonist_sac_info = protagonist_sac_agent.learn_mujoco(print_out=True)
        print("Antagonist:")
        antagonist_sac_info = antagonist_sac_agent.learn_mujoco(print_out=True)

        start = time.time()
        protagonist_samples = collect.collect_trajectories_policy_single(gym_env, protagonist_sac_agent,
                        n=v['irl']['training_trajs'], state_indices=state_indices)
        antagonist_samples = collect.collect_trajectories_policy_single(gym_env, antagonist_sac_agent,
                        n=v['irl']['training_trajs'], state_indices=state_indices)
        print(f'collect trajs {time.time() - start:.0f}s', flush=True)

        # Antagonist states flattened to (n*T, state_dim) for fitting the density model
        antagonist_agent_emp_states = antagonist_samples[0].copy()
        antagonist_agent_emp_states = antagonist_agent_emp_states.reshape(-1, antagonist_agent_emp_states.shape[2])

        start = time.time()
        if v['obj'] in ["emd"]:
            critic_loss = critic.learn(expert_samples.copy(), antagonist_agent_emp_states, iter=v['critic']['iter'])
        elif v['obj'] != 'maxentirl':  # learn log density ratio
            disc_loss = disc.learn(expert_samples.copy(), antagonist_agent_emp_states, iter=v['disc']['iter'])
        print(f'train disc {time.time() - start:.0f}s', flush=True)

        # optimization w.r.t. reward
        reward_losses = []
        for _ in range(v['reward']['gradient_step']):
            if v['irl']['resample_episodes'] > v['irl']['expert_episodes']:
                expert_res_indices = np.random.choice(expert_trajs.shape[0], v['irl']['resample_episodes'], replace=True)
                expert_trajs_train = expert_trajs[expert_res_indices].copy()  # resampling the expert trajectories
            elif v['irl']['resample_episodes'] > 0:
                expert_res_indices = np.random.choice(expert_trajs.shape[0], v['irl']['resample_episodes'], replace=False)
                expert_trajs_train = expert_trajs[expert_res_indices].copy()
            else:
                expert_trajs_train = None  # not use expert trajs

            # All objectives use the antagonist rollouts, matching the disc/critic training above.
            # (`samples` was previously undefined here for maxentirl/emd.)
            if v['obj'] in ['fkl', 'rkl', 'js']:
                loss, _ = f_div_disc_loss(v['obj'], v['IS'], antagonist_samples, disc, reward_func, device, expert_trajs=expert_trajs_train)
            elif v['obj'] == 'maxentirl':
                loss = maxentirl_loss(v['obj'], antagonist_samples, expert_samples, reward_func, device)
            elif v['obj'] == 'emd':
                loss, _ = ipm_loss(v['obj'], v['IS'], antagonist_samples, critic.value, reward_func, device, expert_trajs=expert_trajs_train)

            pagar_loss = get_pagar_loss(protagonist_samples, antagonist_samples, protagonist_sac_agent, antagonist_sac_agent, reward_func, device)
            tot_loss = loss + pagar_loss * 1e-3

            reward_losses.append(tot_loss.item())
            print(f"{v['obj']}_loss: {loss}, pagar_{v['obj']}_loss: {pagar_loss}, total_loss: {tot_loss}")
            reward_optimizer.zero_grad()
            # Back-propagate the combined objective; backing up `loss` alone ignored the PAGAR term.
            tot_loss.backward()
            reward_optimizer.step()

        # evaluating the learned reward
        real_return_det, real_return_sto = try_evaluate(itr, "Running", antagonist_sac_info)
        if real_return_det > max_real_return_det and real_return_sto > max_real_return_sto:
            max_real_return_det, max_real_return_sto = real_return_det, real_return_sto
            torch.save(reward_func.state_dict(), os.path.join(logger.get_dir(),
                    f"model/reward_model_itr{itr}_det{max_real_return_det:.0f}_sto{max_real_return_sto:.0f}.pkl"))

        # Log key spelling kept as-is in case downstream tooling reads this column.
        logger.record_tabular("Itration", itr)
        logger.record_tabular("Reward Loss", loss.item())
        logger.record_tabular("PAGAR Loss", pagar_loss.item())
        if v['sac']['automatic_alpha_tuning']:
            logger.record_tabular("protagonist_alpha", protagonist_sac_agent.alpha.item())
            logger.record_tabular("antagonist_alpha", antagonist_sac_agent.alpha.item())

        logger.dump_tabular()