import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
import os

import sys
sys.path.append('..')
from lunar_lander_risk import LunarLander

import argparse
def _build_arg_parser():
    """Build the command-line parser: required --lam (float), --b (int), --seed (int)."""
    cli = argparse.ArgumentParser(description='lambda, b, seed')
    option_specs = (
        ('--lam', float, 'lambda'),
        ('--b', int, 'threshold'),
        ('--seed', int, 'seed'),
    )
    for flag, kind, text in option_specs:
        cli.add_argument(flag, type=kind, help=text, required=True)
    return cli


parser = _build_arg_parser()
args = parser.parse_args()

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Actor(nn.Module):
    """Two-hidden-layer MLP policy head over a discrete action set.

    ``forward`` returns the hidden features after the second ReLU;
    ``pi`` applies the output layer ``l3`` and a softmax along ``softmax_dim``.
    """

    def __init__(self, state_dim, action_dim, net_width):
        super().__init__()
        # Layers are created in this order and under these attribute names so
        # that seeded initialisation and the state_dict keys (l1/l2/l3) stay
        # the same.
        self.l1 = nn.Linear(state_dim, net_width)
        self.l2 = nn.Linear(net_width, net_width)
        self.l3 = nn.Linear(net_width, action_dim)

    def forward(self, state):
        hidden = state
        for layer in (self.l1, self.l2):
            hidden = torch.relu(layer(hidden))
        return hidden

    def pi(self, state, softmax_dim=0):
        logits = self.l3(self.forward(state))
        return torch.softmax(logits, dim=softmax_dim)

class TamarNet(nn.Module):
    """Actor network plus the two learnable scalars ``J`` and ``V``.

    ``J`` and ``V`` start at zero. ``JV_params`` and ``pi_params`` hold them
    separately from the actor weights so each group can get its own optimizer.
    """

    def __init__(self, state_dim, action_dim, net_width):
        super().__init__()
        # The actor is built first so its (random) initialisation happens at
        # the same point in the RNG stream as before.
        self.actor = Actor(state_dim, action_dim, net_width)

        self.J = nn.Parameter(torch.zeros(1))
        self.V = nn.Parameter(torch.zeros(1))

        self.JV_params = [self.J, self.V]
        self.pi_params = list(self.actor.parameters())

    def forward(self, state):
        probs = self.actor.pi(state, softmax_dim=0)
        chosen = Categorical(probs).sample()
        # Log-probability of the sampled action, shaped (1,) so callers can
        # sum it across the steps of an episode.
        log_prob = torch.log(probs[chosen]).unsqueeze(0)
        return dict(a=chosen, log_pi_a=log_prob)

class TamarAgent(object):
    """Variance-penalised policy-gradient agent (Tamar, Di Castro & Mannor).

    An episode is collected through ``pre_step``/``post_step``. These
    accumulate the summed log-probability of the taken actions and the total
    return R. ``train`` then performs one update:

    * ``J`` and ``V`` take SGD steps toward R and R**2 - J**2, so they track
      the mean and variance of R;
    * the policy follows the likelihood-ratio gradient of J - lam * g(V - b),
      with g(x) = max(0, x)**2. Each step's score is therefore weighted by
      R - lam * g'(V - b) * (R**2 - 2 * J * R).
    """

    def __init__(
            self,
            state_dim,
            action_dim,
            net_width,
            lr,
            b,
            lam
    ):
        # Place the network on the module-level `device`, which is where the
        # rollout code sends states; otherwise CUDA states meet CPU weights.
        # Module.to() moves parameters in place, so JV_params/pi_params stay valid.
        self.network = TamarNet(state_dim, action_dim, net_width).to(device)
        self.policy_optimizer = optim.RMSprop(self.network.pi_params, lr=lr)
        # J/V use a step size 100x the policy's. The attribute-name typo is kept
        # because code elsewhere may refer to it.
        self.JV_optimizier = optim.SGD(self.network.JV_params, lr=lr * 100)

        self.b = b      # variance threshold in the penalty g(V - b)
        self.lam = lam  # penalty weight
        self.reset()

    def save_best(self, save_path, risk=True):
        """Save the network state_dict as net_var_best.th (risk=True) or net_mean_best.th.

        ``save_path`` is concatenated directly with the filename, so it should
        end with a path separator.
        """
        if risk:
            torch.save(self.network.state_dict(), save_path + 'net_var_best.th')
        else:
            torch.save(self.network.state_dict(), save_path + 'net_mean_best.th')

    def save_model(self, save_path, ep):
        """Save the network state_dict as net_ep_<ep>.th under ``save_path``."""
        torch.save(self.network.state_dict(), save_path + 'net_ep_'+str(ep)+'.th')

    def reset(self):
        """Clear the per-episode accumulators (total return and summed log-prob)."""
        self.total_rewards = torch.tensor(0, dtype=torch.float)
        self.log_pi_a = 0

    def pre_step(self, state):
        """Sample an action for ``state``, record its log-prob, and return it as a Python number."""
        prediction = self.network(state)
        self.log_pi_a = self.log_pi_a + prediction['log_pi_a']
        return prediction['a'].item()

    def post_step(self, reward):
        """Add ``reward`` to the episode's total return."""
        self.total_rewards += reward

    def evaluate(self, state):
        """Sample an action without recording gradients or touching episode accumulators."""
        with torch.no_grad():
            pi = self.network.actor.pi(state, softmax_dim=0)
            dist = Categorical(pi)
            action = dist.sample()
            action = action.item()
        return action

    def train(self):
        """Apply one update from the episode collected since the last reset, then reset."""
        R = self.total_rewards
        J = self.network.J
        V = self.network.V
        J_hat = J.detach()

        # Squared-error losses whose SGD steps are J += a*(R - J) and
        # V += a*(R**2 - J**2 - V). J is detached inside V_loss so that the
        # variance target does not push gradients into J.
        J_loss = (R - J).pow(2).mul(0.5)
        V_loss = (R.pow(2) - J_hat.pow(2) - V).pow(2).mul(0.5)

        # g'(V - b) for the squared-hinge penalty g(x) = max(0, x)**2.
        if V.item() < self.b:
            grad = 0
        else:
            grad = 2 * (V.detach() - self.b)

        # grad Var[R] = grad E[R**2] - 2 J grad E[R]. The score-function
        # estimate therefore weights log pi by (R**2 - 2 * J * R); note the
        # factor R on the 2*J term.
        policy_loss = -(R - self.lam * grad * (R.pow(2) - 2. * J_hat * R)) * self.log_pi_a

        self.JV_optimizier.zero_grad()
        self.policy_optimizer.zero_grad()

        (J_loss + V_loss + policy_loss).backward()

        self.JV_optimizier.step()
        self.policy_optimizer.step()

        self.reset()

########## functional ###########
def eval_model(env, agent, n_episodes):
    """Run ``n_episodes`` evaluation episodes using ``agent.evaluate``.

    Each episode is capped at 1000 steps. An episode counts as a "left
    landing" when its final step yields reward exactly 100 and the final x
    position is at or left of the viewport centre. The x position is
    recovered from the normalised observation s[0] using LunarLander's
    viewport constants.

    Returns:
        (np.ndarray of per-episode undiscounted returns, int number of left landings)
    """
    # Viewport constants taken from LunarLander, used to undo the x normalisation.
    VIEWPORT_W = 600
    SCALE = 30.0
    half_width = VIEWPORT_W / SCALE / 2  # == 10.0, world-space x of the viewport centre

    ep_return = []
    land_left = 0

    for _ in range(n_episodes):
        total_reward = 0
        episode_length = 0
        s = env.reset()
        done = False
        while not done:
            a = agent.evaluate(torch.from_numpy(s).float().to(device))
            s, r, done, _ = env.step(a)
            total_reward += r
            episode_length += 1

            # end episode early
            if episode_length == 1000:
                done = True

        # s and r come from the final step of the episode.
        x = s[0] * half_width + half_width
        if x <= half_width and r == 100:
            land_left += 1

        ep_return.append(total_reward)

    return np.array(ep_return), land_left

def play_episode(env, agent):
    """Collect one training episode through the agent's step hooks.

    ``agent.pre_step`` chooses each action (and records its log-probability),
    and ``agent.post_step`` receives each reward. The episode ends when the
    environment reports done or after 1000 steps, whichever comes first.
    """
    obs = env.reset()
    for _ in range(1000):
        action = agent.pre_step(torch.from_numpy(obs).float().to(device))
        obs, reward, finished, _ = env.step(action)
        agent.post_step(reward)
        if finished:
            break

def get_save_dir(lam, b, seed, lr_policy):
    """Return the output directory for one run, always ending in '/'.

    Layout: ./save/lam_<lam>/b_<b>/seed_<seed>/lr_p=<lr_policy>/ with each
    value rendered via str().
    """
    segments = (
        '.',
        'save',
        f'lam_{lam!s}',
        f'b_{b!s}',
        f'seed_{seed!s}',
        f'lr_p={lr_policy!s}',
    )
    return '/'.join(segments) + '/'

########## setting ###########
# Training and evaluation environments. noise_scale is forwarded to the
# LunarLander constructor from lunar_lander_risk; its meaning is defined
# there (NOTE(review): presumably the magnitude of injected noise — confirm).
noise_scale = 90
env = LunarLander(noise_scale)
eval_env = LunarLander(noise_scale)
action_size = env.action_space.n                 # number of discrete actions (policy output size)
state_size = env.observation_space.shape[0]      # observation vector length (policy input size)

# set seed
# The eval env gets seed 2**31-1-seed, which always differs from `seed`
# (presumably so evaluation episodes differ from training ones).
# The global RNGs are seeded after the envs are built but before the agent
# is created, so network initialisation depends on `seed`. Statement order
# here matters for reproducibility.
seed = args.seed
env.seed(seed)
eval_env.seed(2**31-1-seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    # Recent PyTorch's manual_seed already seeds CUDA; these calls are redundant but harmless.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# hyperparameters
net_width = 128          # hidden width of the actor MLP
lr = 7e-5                # policy (RMSprop) step size; J/V use lr * 100 inside TamarAgent
b = args.b               # variance threshold for the penalty
lam = args.lam           # penalty weight (also used in the eval mean - lam*var score)
train_episodes = 2000    # number of training episodes (one update per episode)
test_intvl = 20          # evaluate every test_intvl training episodes
save_intvl = 50          # checkpoint every save_intvl training episodes
eval_episodes = 10       # episodes per evaluation

print('lam:', lam, 'b:', b, 'seed:', seed)
print('lr_p:', lr)

############# main ##############

# create agent
kwargs = {
    "state_dim": state_size,
    "action_dim": action_size,
    "net_width": net_width,
    "lr": lr,
    "b": b,
    "lam": lam,
}
agent = TamarAgent(**kwargs)

# create save dir
save_dir = get_save_dir(lam, b, seed, lr)
os.makedirs(save_dir, exist_ok=True)

# Best evaluation scores so far. Start at -inf so the first evaluation always
# writes both "best" checkpoints. A finite sentinel such as -10000 can be
# unreachable for the mean - lam * variance score, because the variance of
# the returns can be very large.
best_eval_return = -np.inf
best_eval_variance = -np.inf  # best (mean - lam * variance) score, not a variance

# Evaluation histories; the full history is re-saved after every evaluation.
eval_rewards_list = []
land_left_list = []

for ep in range(train_episodes):
    # One on-policy episode followed by one update.
    play_episode(env, agent)
    agent.train()

    if (ep+1) % test_intvl == 0:
        eval_r, land_left = eval_model(eval_env, agent, eval_episodes)
        eval_rewards_list.append(eval_r)
        land_left_list.append(land_left)

        eval_r_mean = eval_r.mean()
        print('eval return', eval_r_mean)

        # Best checkpoint by mean evaluation return.
        if eval_r_mean > best_eval_return:
            best_eval_return = eval_r_mean
            agent.save_best(save_dir, risk=False)

        # Best checkpoint by the risk-adjusted score mean - lam * variance.
        mean_var = eval_r_mean - lam * eval_r.var()
        if mean_var > best_eval_variance:
            best_eval_variance = mean_var
            agent.save_best(save_dir, risk=True)

        # Overwrite the history files (one row per evaluation so far).
        with open(save_dir + 'eval_r.npy', 'wb') as f:
            np.save(f, np.array(eval_rewards_list))
        with open(save_dir + 'eval_land_left.npy', 'wb') as f:
            np.save(f, np.array(land_left_list))

    if (ep+1) % save_intvl == 0:
        agent.save_model(save_dir, ep+1)

    
