import sys, os, time
sys.path.append('./')
import numpy as np
import torch
import gym
from ruamel.yaml import YAML

from main.models.reward import MLPReward
from common.sac_UTILITY import ReplayBuffer, SAC

import envs
from utils import system, collect_shaping, logger, eval
from utils.plots.train_plot_high_dim import plot_disc
from utils.plots.train_plot import plot_disc as visual_disc

import datetime
import dateutil.tz
import json, copy

def ML_loss(env, policy, div: str, num_traj, agent_samples, reward_shaping, device, learned_reward, env_name=None):
    '''
        Surrogate loss that correlates the policy's advantage under the
        environment reward with its advantage under the shaped reward
        (environment reward + reward_shaping(state)).

        agent_samples is an 8-tuple
        (s_buffer, a_buffer, r_buffer, s_buffer_extra, a_buffer_extra,
         r_buffer_extra, Qpos_buffer, Qvel_buffer) where
            s_buffer / s_buffer_extra are numpy arrays of shape (N, T, dimension of state)
            a_buffer / a_buffer_extra are numpy arrays of shape (N, T, d of action) (unused here)
            r_buffer / r_buffer_extra are numpy arrays of shape (N, T)
            Qpos_buffer / Qvel_buffer are indexed [traj, t, :] and passed to
            env.set_state to restore the simulator at each visited state.

        env must provide reset(), set_state(qpos, qvel), state_vector() and a
        step() returning a 5-tuple; policy must provide get_action(obs).
        learned_reward is accepted for interface compatibility but not used.
        env_name selects how many leading state_vector() entries are dropped
        to build the observation (2 for 'Ant-v3'/'Ant-v4', otherwise 1); when
        None it is read from the module-level config `v`.

        Returns a scalar tensor: -mean(norm(A) * norm(A_shaping)), where each
        advantage is Q - V and is standardised over all (traj, t) entries.
    '''
    assert div in ['maxentirl']
    (s_buffer, _a_buffer, r_buffer, s_buffer_extra, _a_buffer_extra,
     r_buffer_extra, Qpos_buffer, Qvel_buffer) = agent_samples
    _, T, _ = s_buffer.shape
    s_buffer = torch.FloatTensor(s_buffer)
    r_buffer = torch.FloatTensor(r_buffer)
    s_buffer_extra = torch.FloatTensor(s_buffer_extra)
    r_buffer_extra = torch.FloatTensor(r_buffer_extra)

    if env_name is None:
        # Backward-compatible fallback: the script keeps its config in the global `v`.
        env_name = v['env']['env_name']
    # Number of leading state_vector() entries that are not part of the observation.
    obs_offset = 2 if env_name in ['Ant-v4', 'Ant-v3'] else 1

    def shaping_batch(states):
        # Flatten so a (T, 1) network output cannot broadcast against the (T,)
        # reward vector into a (T, T) matrix; bring the result back to the CPU,
        # where all buffers live. Both steps are no-ops for 1-D CPU outputs.
        return reward_shaping(states.to(device)).reshape(-1).cpu()

    def shaping_step(state):
        # Shaping value of a single state as a 0-dim CPU tensor.
        return reward_shaping(state.to(device)).reshape(()).cpu()

    q_value = torch.zeros(num_traj, T)
    q_value_shaping = torch.zeros(num_traj, T)
    v_value = torch.zeros(num_traj, T)
    v_value_shaping = torch.zeros(num_traj, T)

    num_traj_v = 4  # Monte-Carlo rollouts per state for the value estimate
    eps = 1e-8      # guards the standardisation against a zero std

    for traj_no in range(num_traj):
        # Start from the full return of the "extra" segment; walking backwards,
        # each step adds one reward of the main segment and drops the matching
        # trailing reward of the extra segment.
        backup = r_buffer_extra[traj_no].sum()
        backup_shaping = (r_buffer_extra[traj_no] + shaping_batch(s_buffer_extra[traj_no])).sum()
        for idx in reversed(range(T)):
            q_plain = r_buffer[traj_no, idx] + backup
            q_shaped = r_buffer[traj_no, idx] + shaping_step(s_buffer[traj_no, idx]) + backup_shaping
            q_value[traj_no, idx] = q_plain
            q_value_shaping[traj_no, idx] = q_shaped
            backup = q_plain - r_buffer_extra[traj_no, idx]
            backup_shaping = q_shaped - (r_buffer_extra[traj_no, idx]
                                         + shaping_step(s_buffer_extra[traj_no, idx]))

            # Here we compute value functions: roll the current policy out for
            # T steps from the stored simulator state, num_traj_v times.
            v_sum = torch.zeros(())
            v_sum_shaping = torch.zeros(())
            for _rollout in range(num_traj_v):
                env.reset()  # return value unused; the state is overwritten below
                env.set_state(Qpos_buffer[traj_no, idx, :], Qvel_buffer[traj_no, idx, :])
                s = env.state_vector()[obs_offset:]
                for _step in range(T):
                    a = policy.get_action(s)
                    s_nxt, r, _, _, _ = env.step(a)  # assign reward online
                    v_sum = v_sum + r
                    v_sum_shaping = v_sum_shaping + r + shaping_step(torch.FloatTensor(s))
                    s = s_nxt
            v_value[traj_no, idx] = v_sum / num_traj_v
            v_value_shaping[traj_no, idx] = v_sum_shaping / num_traj_v

    advantage = q_value - v_value
    advantage_shaping = q_value_shaping - v_value_shaping
    advantage = (advantage - advantage.mean()) / (advantage.std() + eps)
    advantage_shaping = (advantage_shaping - advantage_shaping.mean()) / (advantage_shaping.std() + eps)
    surrogate_loss = (-advantage * advantage_shaping).mean()

    return surrogate_loss # same scale

def ML_sa_loss(div: str, agent_samples, reward_func, device, expert_samples=None):
    ''' NOTE: only for ML_sa: T * (E_q[r(tau)] - E_p[r(tau)]) w.r.t. r
        agent_samples is a sequence whose first two entries are the agent
            states and actions, numpy arrays of shape (N, T, d_s) and (N, T, d_a);
            any further entries are ignored
        expert_samples is numpy array of shape (N, T, d) or (N, d) with
            d = d_s + d_a (presumably state-action features; verify against caller)
        reward_func must expose r(batch) returning one reward per row

        Returns a scalar tensor. Minimising it lowers the mean reward of agent
        samples relative to the mean reward of expert samples.
        Raises ValueError when expert_samples is not supplied.
    '''
    #assert div in ['maxentirl']
    if expert_samples is None:
        # The expert term cannot be formed without data; fail with a clear message
        # instead of a NameError further down.
        raise ValueError("ML_sa_loss requires expert_samples to compute E_p[r(tau)]")

    # Index instead of 3-way unpacking so longer sample tuples are accepted too.
    sA, aA = agent_samples[0], agent_samples[1]
    sA = np.concatenate([sA, aA], 2)
    _, T, d = sA.shape

    sA_vec = torch.FloatTensor(sA).reshape(-1, d).to(device)
    sE_vec = torch.FloatTensor(expert_samples).reshape(-1, d).to(device)

    t1 = reward_func.r(sA_vec).view(-1) # E_q[r(tau)]
    t2 = reward_func.r(sE_vec).view(-1) # E_p[r(tau)]

    surrogate_objective = t1.mean() - t2.mean()
    return T * surrogate_objective # same scale



def try_evaluate(samples, itr: int, policy_type: str, sac_info):
    """Evaluate the current SAC agent on the real reward and log the results.

    Relies on the module-level names `v` (config), `sac_agent`, `env_fn`,
    `eval` and `logger`.

    samples: collected agent samples; only the leading dimension of
        samples[0] is checked against v['irl']['training_trajs'].
    itr: current outer iteration, used to derive the logged update/step counts.
    policy_type: must be "Running"; used as a prefix for the logged counters.
    sac_info: accepted for interface compatibility; not used.

    Returns (det_mean, det_std, sto_mean, sto_std) as produced by
    eval.evaluate_real_return, called first with a final argument of True and
    then False (presumably deterministic vs. stochastic actions, judging by
    the metric names).
    """
    assert policy_type in ["Running"]
    update_time = itr * v['reward']['gradient_step']
    env_steps = itr * v['sac']['epochs'] * v['env']['T']
    agent_emp_states = samples[0].copy()
    assert agent_emp_states.shape[0] == v['irl']['training_trajs']

    metrics = {'evaluation': 1}
    outcome = []
    # eval real reward: (metric key, console label, final evaluate_real_return flag)
    evaluation_modes = (
        ("Real Det Return", "real det return avg", True),
        ("Real Sto Return", "real sto return avg", False),
    )
    for metric_key, console_label, flag in evaluation_modes:
        mean_return, std_return = eval.evaluate_real_return(
            sac_agent.get_action, env_fn(), v['irl']['eval_episodes'], v['env']['T'], flag)
        metrics[metric_key] = mean_return
        print(f"{console_label}: {mean_return:.2f}")
        logger.record_tabular(metric_key, round(mean_return, 2))
        outcome.extend((mean_return, std_return))

    logger.record_tabular(f"{policy_type} Update Time", update_time)
    logger.record_tabular(f"{policy_type} Env Steps", env_steps)

    return tuple(outcome)

if __name__ == "__main__":
    # Experiment configuration. `v` is deliberately a module-level name:
    # try_evaluate and ML_loss read it as a global.
    yaml = YAML()
    with open(sys.argv[1]) as config_file:
        v = yaml.load(config_file)

    # common parameters
    env_name = v['env']['env_name']
    state_indices = v['env']['state_indices']
    seed = v['seed']
    delayed = v['env']['delayed']
    cumulative_reward_det_mean = []
    cumulative_reward_det_std = []
    cumulative_reward_sto_mean = []
    cumulative_reward_sto_std = []
    # system: device, threads, seed, pid
    device = torch.device(f"cuda:{v['cuda']}" if torch.cuda.is_available() and v['cuda'] >= 0 else "cpu")
    torch.set_num_threads(1)
    np.set_printoptions(precision=3, suppress=True)
    system.reproduce(seed)
    pid = os.getpid()

    # assumptions
    assert v['obj'] in ['maxentirl','maxentirl_sa']
    assert v['IS'] == False

    # logs
    exp_id = f"logs/{env_name}/UTILITY" # task/obj/date structure
    os.makedirs(exp_id, exist_ok=True)

    now = datetime.datetime.now(dateutil.tz.tzlocal())
    log_folder = exp_id + '/' + now.strftime('%Y_%m_%d_%H_%M_%S')
    logger.configure(dir=log_folder)
    print(f"Logging to directory: {log_folder}")
    os.makedirs(os.path.join(log_folder, 'model'), exist_ok=True)

    # Create the results directory up front so a missing folder cannot discard
    # the collected returns after the whole training run has finished.
    results_dir = f"results/{'delayed' if delayed else 'original'}_{env_name}/UTILITY"
    os.makedirs(results_dir, exist_ok=True)

    # environment
    env_fn = lambda: gym.make(env_name)
    gym_env = env_fn()
    state_size = gym_env.observation_space.shape[0]
    action_size = gym_env.action_space.shape[0]
    if state_indices == 'all':
        state_indices = list(range(state_size))

    # Initialize the shaping reward as a neural network
    reward_shaping = MLPReward(len(state_indices), **v['reward'], device=device).to(device)
    sa = False
    if v['obj'] == 'maxentirl_sa':
        sa = True
        # State-action variant: the network input also includes the action.
        reward_shaping = MLPReward(len(state_indices)+action_size, **v['reward'], device=device).to(device)
    reward_optimizer = torch.optim.Adam(reward_shaping.parameters(), lr=v['reward']['lr'], weight_decay=v['reward']['weight_decay'], betas=(v['reward']['momentum'], 0.999))

    max_real_return_det, max_real_return_sto = -np.inf, -np.inf
    learned_reward = MLPReward(len(state_indices), **v['reward'], device=device).to(device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source. map_location lets a GPU-saved checkpoint load on this device.
    if delayed:
        learned_reward.load_state_dict(torch.load(f'delayed_{env_name}_learned_reward_model.pkl', map_location=device))
    else:
        learned_reward.load_state_dict(torch.load(f'original_{env_name}_learned_reward_model.pkl', map_location=device))

    for itr in range(v['irl']['n_itrs']):
        if v['sac']['reinitialize'] or itr == 0:
            # Reset SAC agent with old policy, new environment, and new replay buffer
            print("Reinitializing sac")
            replay_buffer = ReplayBuffer(
                state_size,
                action_size,
                device=device,
                size=v['sac']['buffer_size'])

            sac_agent = SAC(env_fn, replay_buffer,
                steps_per_epoch=v['env']['T'],
                update_after=v['env']['T'] * v['sac']['random_explore_episodes'],
                max_ep_len=v['env']['T'],
                seed=seed,
                start_steps=v['env']['T'] * v['sac']['random_explore_episodes'],
                reward_state_indices=state_indices,
                device=device,
                delayed=v['env']['delayed'],
                sa=sa,
                learned_reward=learned_reward,
                **v['sac']
            )

        sac_agent.reward_function = reward_shaping.get_scalar_reward # only need to change reward in sac
        sac_info = sac_agent.learn_mujoco(print_out=True)

        start = time.time()

        samples = collect_shaping.collect_trajectories_policy_single(gym_env, sac_agent, n = v['irl']['training_trajs'], state_indices=state_indices)
        # optimization w.r.t. reward
        reward_losses = []
        for _ in range(v['reward']['gradient_step']):
            if v['obj'] == 'maxentirl':
                loss = ML_loss(gym_env, sac_agent, v['obj'], v['irl']['training_trajs'], samples, reward_shaping, device, learned_reward)
            elif v['obj'] == 'maxentirl_sa':
                loss = ML_sa_loss(v['obj'], samples, reward_shaping, device)

            reward_losses.append(loss.item())
            print(f"{v['obj']} loss: {loss}")
            reward_optimizer.zero_grad()
            loss.backward()
            reward_optimizer.step()

        # evaluating the learned reward
        real_return_det_mean, real_return_det_std, real_return_sto_mean, real_return_sto_std = try_evaluate(samples, itr, "Running", sac_info)
        cumulative_reward_det_mean.append(real_return_det_mean)
        cumulative_reward_det_std.append(real_return_det_std)
        cumulative_reward_sto_mean.append(real_return_sto_mean)
        cumulative_reward_sto_std.append(real_return_sto_std)
        # Checkpoint only when both evaluation modes improve on the best so far.
        if real_return_det_mean > max_real_return_det and real_return_sto_mean > max_real_return_sto:
            max_real_return_det, max_real_return_sto = real_return_det_mean, real_return_sto_mean
            torch.save(reward_shaping.state_dict(), os.path.join(logger.get_dir(),
                    f"model/reward_model_itr{itr}_det{max_real_return_det:.0f}_sto{max_real_return_sto:.0f}.pkl"))

        # The misspelled key is kept on purpose so existing log columns stay comparable.
        logger.record_tabular("Itration", itr)
        # Last reward loss of this iteration; NaN when no gradient step was configured.
        logger.record_tabular("Reward Loss", reward_losses[-1] if reward_losses else float('nan'))
        if v['sac']['automatic_alpha_tuning']:
            logger.record_tabular("alpha", sac_agent.alpha.item())

        if itr == v['irl']['n_itrs']-1:
            torch.save(reward_shaping.state_dict(), os.path.join(logger.get_dir(), f"learned_reward_UTILITY.pkl"))

        logger.dump_tabular()

    # Persist the per-iteration evaluation curves, one text file per statistic.
    evaluation_curves = {
        "cumulative_reward_det_mean": cumulative_reward_det_mean,
        "cumulative_reward_det_std": cumulative_reward_det_std,
        "cumulative_reward_sto_mean": cumulative_reward_sto_mean,
        "cumulative_reward_sto_std": cumulative_reward_sto_std,
    }
    for curve_name, curve_values in evaluation_curves.items():
        np.savetxt(f"{results_dir}/{curve_name}.txt", np.array(curve_values), delimiter=',')
        








