import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import gym
import numpy as np
import os

import sys
sys.path.append('..')
from lunar_lander_risk import LunarLander

import argparse
parser = argparse.ArgumentParser(description='seed lambda gamma')
# All three run settings are mandatory; argparse exits with a usage message if any is missing.
parser.add_argument('--seed', required=True, type=int, help='seed')
parser.add_argument('--lam', required=True, type=float, help='lambda')
parser.add_argument('--gamma', required=True, type=float, help='gamma')
args = parser.parse_args()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class ActorNet(nn.Module):
    """Policy network: two ReLU hidden layers and a softmax head.

    Maps a state whose last dimension is ``state_size`` to a categorical
    distribution over ``action_size`` discrete actions.
    """

    def __init__(self, state_size, action_size, hidden_size):
        super().__init__()
        # Attribute names become state_dict keys used by saved checkpoints;
        # creation order fixes the RNG draws for initialisation. Keep both stable.
        self.dense_layer_1 = nn.Linear(state_size, hidden_size)
        self.dense_layer_2 = nn.Linear(hidden_size, hidden_size)
        self.output = nn.Linear(hidden_size, action_size)

    def forward(self, x):
        hidden = F.relu(self.dense_layer_1(x))
        hidden = F.relu(self.dense_layer_2(hidden))
        logits = self.output(hidden)
        # Normalise over the last axis so batched and single-state inputs both work.
        return F.softmax(logits, dim=-1)
    
class ValueFunctionNet(nn.Module):
    """State-value network: two ReLU hidden layers and a single linear output."""

    def __init__(self, state_size, hidden_size):
        super().__init__()
        # Attribute names are state_dict keys for saved checkpoints; keep them stable.
        self.dense_layer_1 = nn.Linear(state_size, hidden_size)
        self.dense_layer_2 = nn.Linear(hidden_size, hidden_size)
        self.output = nn.Linear(hidden_size, 1)

    def forward(self, x):
        hidden = F.relu(self.dense_layer_1(x))
        hidden = F.relu(self.dense_layer_2(hidden))
        # One scalar value estimate per input state (trailing dimension of size 1).
        return self.output(hidden)

class PGAgent():
    """REINFORCE-with-baseline agent with a Gini-deviation penalty weighted by ``lam``.

    Episodes are appended to internal buffers via ``put_data``. ``train_n`` then
    performs several updates on that batch, re-weighting each trajectory by its
    importance-sampling ratio against the policy that generated it, and finally
    clears the buffers.

    Args:
        state_size: Dimension of the observation vector.
        action_size: Number of discrete actions.
        hidden_size: Width of the hidden layers of both networks.
        actor_lr: Adam learning rate for the policy network.
        vf_lr: Adam learning rate for the value network.
        discount: Discount factor used when turning rewards into returns.
        n_episodes: Number of buffered episodes consumed by each ``train`` call.
        lam: Weight of the Gini-deviation term in the policy loss.
    """

    def __init__(self, state_size, action_size, hidden_size, actor_lr, vf_lr, discount, n_episodes, lam):
        self.action_size = action_size
        self.actor_net = ActorNet(state_size, action_size, hidden_size).to(device)
        self.vf_net = ValueFunctionNet(state_size, hidden_size).to(device)
        self.actor_optimizer = optim.Adam(self.actor_net.parameters(), lr=actor_lr)
        self.vf_optimizer = optim.Adam(self.vf_net.parameters(), lr=vf_lr)
        self.discount = discount
        self.n_episodes = n_episodes
        self.lam = lam

        # Per-episode buffers, filled by put_data and cleared by train_n.
        self.state_buf, self.action_buf, self.pi_a_buf, self.reward_buf = [], [], [], []

    def save_best(self, save_path, risk=True):
        """Save actor/value weights as the best mean-Gini (risk=True) or best mean model.

        ``save_path`` is concatenated directly with the file name, so it should
        end with a path separator.
        """
        if risk:
            torch.save(self.actor_net.state_dict(), save_path + 'actor_mean_gini.th')
            torch.save(self.vf_net.state_dict(), save_path + 'value_mean_gini.th')
        else:
            torch.save(self.actor_net.state_dict(), save_path + 'actor_mean.th')
            torch.save(self.vf_net.state_dict(), save_path + 'value_mean.th')

    def save_model(self, save_path, ep):
        """Save actor/value weights tagged with epoch number ``ep``."""
        torch.save(self.actor_net.state_dict(), save_path + 'actor_ep_'+str(ep)+'.th')
        torch.save(self.vf_net.state_dict(), save_path + 'value_ep_'+str(ep)+'.th')

    def select_action(self, state):
        """Sample an action from the current policy.

        Expects a single (unbatched) state: ``np.random.choice`` needs a 1-D
        probability vector. Uses the global numpy RNG.

        Returns:
            ``(action, action_prob)``: the sampled action index and the
            probability the policy assigned to it (kept for importance sampling).
        """
        with torch.no_grad():
            input_state = torch.FloatTensor(state).to(device)
            action_probs = self.actor_net(input_state).cpu().numpy()
            action = np.random.choice(np.arange(self.action_size), p=action_probs)
            action_prob = action_probs[action]
        return action, action_prob

    def put_data(self, state_lst, action_lst, action_prob_lst, reward_lst):
        """Append one episode's per-step states, actions, behaviour probs and rewards."""
        self.state_buf.append(state_lst)
        self.action_buf.append(action_lst)
        self.pi_a_buf.append(action_prob_lst)
        self.reward_buf.append(reward_lst)

    def compute_gini_loss(self, ret, sum_log_pi, is_ratio):
        """Per-trajectory surrogate loss for the Gini-deviation gradient.

        Trajectories are sorted by return. With sorted returns R_1 <= ... <= R_n
        (1-indexed), trajectory i < n gets the coefficient
        ``coef_i = sum_{j=i}^{n-1} (2j/n - 1) (R_{j+1} - R_j)``, and its loss is
        ``-coef_i * sum_log_pi_i * is_ratio_i``. The highest-return trajectory is
        dropped (its coefficient would be an empty sum, i.e. 0).

        Args:
            ret: 1-D tensor of trajectory returns, shape [n].
            sum_log_pi: 1-D tensor of sum_t log pi(a_t|s_t) per trajectory
                (carries the policy gradient), shape [n].
            is_ratio: 1-D tensor of importance-sampling ratios, shape [n].

        Returns:
            1-D tensor of shape [n - 1]; empty when n == 1.
        """
        sort_ret, indices = torch.sort(ret, descending=False)
        sort_sum_log_pi = sum_log_pi[indices]
        sort_is_ratio = is_ratio[indices]
        sample_size = sort_ret.shape[0]

        # Weights j/n for j = 1..n-1. Built on the returns' device/dtype: a CPU
        # tensor here would fail when multiplied with CUDA returns.
        x = torch.linspace(1., sample_size - 1, sample_size - 1,
                           device=sort_ret.device, dtype=sort_ret.dtype) / sample_size
        diff = (sort_ret[1:] - sort_ret[:-1]) * x
        # Reverse cumulative sum (up to rounding): cumsum_diff[i] = sum_{j >= i} diff[j].
        cumsum_diff = diff + torch.sum(diff) - torch.cumsum(diff, dim=-1)
        # sort_ret[:-1] - sort_ret[-1] = -sum_{j >= i} (R_{j+1} - R_j), giving the coef above.
        coef = 2. * cumsum_diff + sort_ret[:-1] - sort_ret[-1]

        return -1 * sort_sum_log_pi[:-1] * coef * sort_is_ratio[:-1]

    def train_n(self, n):
        """Run up to ``n`` updates on the buffered batch, then clear the buffers.

        Stops early as soon as ``train`` returns True.
        """
        for i in range(n):
            stop = self.train()
            if stop:
                print('early stop', i+1)
                break
        # The batch is only reused within this call; drop it afterwards.
        self.state_buf, self.action_buf, self.pi_a_buf, self.reward_buf = [], [], [], []

    def train(self):
        """One policy update on the first ``n_episodes`` buffered episodes.

        For every episode this computes the discounted returns, the
        REINFORCE-with-baseline loss, sum_t log pi(a_t|s_t), and the trajectory
        importance ratio prod_t pi(a_t|s_t) / pi_old(a_t|s_t). It also takes one
        value-network step per episode. Only trajectories whose ratio lies in
        [0.5, 1.5] enter the policy loss.

        Returns:
            True if fewer than 30% of trajectories pass the ratio filter (no
            policy step is taken), or if fewer than 60% pass (after the policy
            step). False otherwise.
        """
        ret_list = []             # total discounted return of each traj
        is_ratio_list = []        # importance ratio of each traj
        sum_log_pi_list = []      # sum_t log pi(a_t|s_t) of each traj
        reinforce_loss_list = []  # sum_t -log pi(a_t|s_t) A_t of each traj
        loss_fn = nn.MSELoss()

        for i_e in range(self.n_episodes):
            state_list = self.state_buf[i_e]
            action_list = self.action_buf[i_e]
            pi_old_list = self.pi_a_buf[i_e]
            reward_list = self.reward_buf[i_e]

            # Discounted reward-to-go for every step; the last value is the return from t=0.
            trajectory_len = len(reward_list)
            return_array = np.zeros((trajectory_len,))
            g_return = 0.
            for i in range(trajectory_len - 1, -1, -1):
                g_return = reward_list[i] + self.discount * g_return
                return_array[i] = g_return
            ret_list.append(g_return)

            state_t = torch.FloatTensor(np.array(state_list)).to(device)
            action_t = torch.LongTensor(action_list).to(device).view(-1, 1)
            return_t = torch.FloatTensor(return_array).to(device).view(-1, 1)
            pi_old_t = torch.FloatTensor(pi_old_list).to(device).view(-1, 1)  # shape [T, 1]

            # Baseline: advantages must not backpropagate into the value network.
            vf_t = self.vf_net(state_t)
            with torch.no_grad():
                advantage_t = return_t - vf_t

            log_pi_t = torch.log(self.actor_net(state_t).gather(1, action_t))  # shape [T, 1]
            reinforce_loss_list.append(torch.sum(-log_pi_t * advantage_t))
            sum_log_pi_list.append(log_pi_t.sum())
            # Trajectory-level ratio pi_theta / pi_old, computed in log space.
            log_ratio = log_pi_t.detach() - torch.log(pi_old_t)
            is_ratio_list.append(torch.exp(log_ratio.sum()).item())

            # The value network is updated once per episode, inside this loop.
            vf_loss = loss_fn(vf_t, return_t)
            self.vf_optimizer.zero_grad()
            vf_loss.backward()
            self.vf_optimizer.step()

        # Keep only trajectories that are close enough to on-policy.
        is_ratio = np.array(is_ratio_list)
        keep_idx = np.flatnonzero((is_ratio <= 1.5) & (is_ratio >= 0.5))
        choose_size = len(keep_idx)

        if choose_size < self.n_episodes * 0.3:
            return True

        is_idx_t = torch.LongTensor(keep_idx).to(device)
        is_ratio_choose = torch.FloatTensor(is_ratio[keep_idx]).to(device)
        ret_choose = torch.FloatTensor(ret_list).to(device).gather(0, is_idx_t)
        reinforce_loss_choose = torch.stack(reinforce_loss_list).gather(0, is_idx_t)  # has grad
        sum_log_pi_choose = torch.stack(sum_log_pi_list).gather(0, is_idx_t)

        mean_loss = reinforce_loss_choose * is_ratio_choose
        gini_loss = self.compute_gini_loss(ret_choose, sum_log_pi_choose, is_ratio_choose)
        # With a single surviving trajectory gini_loss is empty and .mean() would be
        # NaN, corrupting the parameters; .sum() of an empty tensor is a clean 0.
        gini_term = gini_loss.mean() if gini_loss.numel() > 0 else gini_loss.sum()
        policy_loss = mean_loss.mean() + self.lam * gini_term
        self.actor_optimizer.zero_grad()
        policy_loss.backward()
        self.actor_optimizer.step()

        # Too few usable trajectories: the batch is too off-policy for another pass.
        return choose_size < self.n_episodes * 0.6

############# functional #############
def eval_model(env, agent, n_episodes=5):
    """Evaluate ``agent`` on ``env`` for ``n_episodes`` episodes.

    Each episode runs until the environment reports ``done`` or 1000 steps have
    elapsed. Actions are sampled from the agent's stochastic policy.

    Args:
        env: Environment with ``reset()`` and ``step(a) -> (s, r, done, info)``.
        agent: Object exposing ``select_action(state) -> (action, prob)``.
        n_episodes: Number of evaluation episodes.

    Returns:
        Tuple ``(returns, gd, land_left)``:
        - ``returns``: numpy array of undiscounted per-episode returns;
        - ``gd``: Gini deviation of those returns,
          ``sum_{i,j} |R_i - R_j| / (2 n^2)`` (0.0 when ``n_episodes`` is 0);
        - ``land_left``: number of episodes whose final step gave reward 100
          with the lander in the left half of the world.
    """
    # World half-width from gym's LunarLander viewport constants (600 px wide,
    # 30 px per unit). It maps state[0] back to world x; this assumes state[0]
    # is the normalised x position as in gym's LunarLander (confirm for
    # lunar_lander_risk).
    viewport_w = 600
    scale = 30.0
    half_width = viewport_w / scale / 2

    ep_return = []
    land_left = 0

    for _episode in range(n_episodes):
        total_reward = 0
        episode_length = 0
        s = env.reset()
        while True:
            a, _ = agent.select_action(s)
            s, r, done, _info = env.step(a)
            total_reward += r
            episode_length += 1

            # Cap episodes at 1000 steps.
            if episode_length == 1000:
                done = True
            if done:
                x = s[0] * half_width + half_width
                # A final reward of exactly 100 is treated as a successful landing
                # (presumably the env's landing bonus; verify in lunar_lander_risk).
                if x <= half_width and r == 100:
                    land_left += 1
                break

        ep_return.append(total_reward)

    returns = np.array(ep_return)
    if n_episodes == 0:
        return returns, 0.0, land_left
    # Vectorised sum of |R_i - R_j| over all ordered pairs (i, j).
    pairwise = np.abs(returns[:, None] - returns[None, :])
    gd = float(pairwise.sum()) / (2 * n_episodes * n_episodes)
    return returns, gd, land_left

def play_episode(env, agent):
    """Roll out one episode with the agent's current policy and buffer it.

    The episode ends when the environment reports ``done`` or after 1000 steps,
    whichever comes first. The following per-step lists are passed to
    ``agent.put_data`` in this order:
    - the state each action was chosen in;
    - the chosen actions;
    - the probability the policy gave each action;
    - the rewards.
    """
    states, actions, probs, rewards = [], [], [], []
    obs = env.reset()
    for _step in range(1000):
        act, act_prob = agent.select_action(obs)
        next_obs, rew, done, _ = env.step(act)

        states.append(obs)
        actions.append(act)
        probs.append(act_prob)
        rewards.append(rew)

        if done:
            break
        obs = next_obs

    agent.put_data(states, actions, probs, rewards)

def get_save_dir(n_episodes, seed, lam, discount, noise, lr_policy, lr_value):
    """Build the output directory path for one run configuration.

    Each run setting becomes its own path component, so runs with different
    settings land in different directories. The path ends with '/' because
    file names are appended to it by plain string concatenation.
    """
    parts = [
        './save',
        f'bs_{n_episodes}',
        f'seed_{seed}',
        f'lam_{lam}',
        f'gamma_{discount}',
        f'noise_{noise}',
        f'lr_p={lr_policy}',
        f'lr_v={lr_value}',
    ]
    return '/'.join(parts) + '/'

################# setting ##################
# Passed to the LunarLander constructor (presumably the scale of the injected
# noise in lunar_lander_risk; confirm). Also recorded in the save directory.
noise_scale = 90
# Separate environment instances for training rollouts and for evaluation.
env = LunarLander(noise_scale)
eval_env = LunarLander(noise_scale)
action_size = env.action_space.n              # number of discrete actions
state_size = env.observation_space.shape[0]   # observation vector length

# set seed
# The evaluation env always gets a different seed (2**31-1-seed) than the training
# env. Global numpy/torch RNGs are seeded here, before the agent is constructed
# further down, so network initialisation and np.random-based action sampling
# depend only on --seed (GPU kernels may still be nondeterministic).
seed = args.seed
env.seed(seed)
eval_env.seed(2**31-1-seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# hyperparameters
epochs = 4000              # number of collect-then-train iterations
hidden_size = 128          # number of units in NN hidden layers
actor_lr = 0.0007          # Adam learning rate for the policy network
value_function_lr = 0.007  # Adam learning rate for the value network
discount = args.gamma      # discount factor gamma
n_episodes = 30            # episodes collected per epoch (and per PGAgent batch)
lam = args.lam             # weight of the Gini-deviation term
inner_update = 10          # max policy updates per batch (train_n may stop earlier)
eval_intvl = 20            # evaluate every this many epochs
save_model_intvl = 50      # checkpoint every this many epochs
eval_episodes = 10         # episodes per evaluation

print('lam:', lam, 'discount:', discount,'noise:', noise_scale, 'seed', seed)
print('lr_p:', actor_lr, 'lr_v:', value_function_lr)

################ main ###############

# create agent
agent = PGAgent(state_size, action_size, hidden_size, actor_lr, value_function_lr, discount, n_episodes, lam)

eval_rewards_list = []     # per-evaluation arrays of episode returns
eval_land_left_list = []   # per-evaluation count of left-side landings
# Start from -inf so the first evaluation always writes a checkpoint, whatever
# the scale of the returns (a finite sentinel could silently block saving).
best_eval_return = float('-inf')
best_eval_mean_gd = float('-inf')

# create save dir
save_dir = get_save_dir(n_episodes, seed, lam, discount, noise_scale, actor_lr, value_function_lr)
os.makedirs(save_dir, exist_ok=True)

for ep in range(epochs):
    print('epoch ', ep)
    # Collect a fresh batch with the current policy, then reuse it for several updates.
    for _ in range(n_episodes):
        play_episode(env, agent)
    agent.train_n(inner_update)

    if (ep + 1) % eval_intvl == 0:
        eval_r, eval_gd, land_left = eval_model(eval_env, agent, eval_episodes)
        eval_rewards_list.append(eval_r)
        eval_land_left_list.append(land_left)
        eval_r_mean = eval_r.mean()
        print('eval return', eval_r_mean)

        if eval_r_mean > best_eval_return:
            best_eval_return = eval_r_mean
            agent.save_best(save_dir, risk=False)
        # Risk-adjusted score: mean return minus lam times the Gini deviation.
        mean_gd = eval_r_mean - lam * eval_gd
        if mean_gd > best_eval_mean_gd:
            best_eval_mean_gd = mean_gd
            agent.save_best(save_dir, risk=True)

        # Rewrite the full evaluation history each time so it survives interruption.
        with open(save_dir + 'eval_r.npy', 'wb') as f:
            np.save(f, np.array(eval_rewards_list))
        with open(save_dir + 'eval_land_left.npy', 'wb') as f:
            np.save(f, np.array(eval_land_left_list))

    if (ep + 1) % save_model_intvl == 0:
        agent.save_model(save_dir, ep + 1)

        


