import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import deque, namedtuple
import random
import numpy as np
import os

import sys
sys.path.append('..')
from lunar_lander_risk import LunarLander

import argparse
# Command-line interface. Both options are declared as required, so argparse
# aborts with a usage message when either one is missing.
parser = argparse.ArgumentParser(description='lambda, seed')
parser.add_argument(
    '--lam',
    required=True,
    type=float,
    help='lambda',
)
parser.add_argument(
    '--seed',
    required=True,
    type=int,
    help='seed',
)
args = parser.parse_args()

# Use the GPU when PyTorch reports one as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class QNetwork(nn.Module):
    """Fully connected network mapping a state to one Q-value per action.

    Two hidden layers of equal width, each followed by a ReLU, and a linear
    output layer.
    """

    def __init__(self, state_size, action_size, hidden_size):
        """Build the three linear layers.

        Args:
            state_size: number of input features per state.
            action_size: number of outputs (one Q-value per action).
            hidden_size: width of both hidden layers.
        """
        super().__init__()
        # The attribute names determine the keys of state_dict(); keep them.
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, action_size)

    def forward(self, state):
        """Return the Q-values for ``state``."""
        hidden = state
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.out(hidden)

class ReplayBuffer:
    """Bounded store of transitions with uniform random mini-batch sampling."""

    def __init__(self, action_size, buffer_size, batch_size):
        """Create an empty buffer.

        Args:
            action_size: number of actions; stored but not used by this class.
            buffer_size: maximum number of transitions kept; once full, the
                oldest transition is dropped for every new one.
            batch_size: number of transitions returned by ``sample``.
        """
        self.action_size = action_size
        self.batch_size = batch_size
        self.memory = deque(maxlen=buffer_size)
        self.experience = namedtuple(
            "Experience",
            field_names=["state", "action", "reward", "next_state", "done"],
        )

    def add(self, state, action, reward, next_state, done):
        """Append one transition to the buffer."""
        self.memory.append(self.experience(state, action, reward, next_state, done))

    def sample(self):
        """Draw ``batch_size`` stored transitions and batch them as tensors.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` of
            tensors moved to ``device``. ``actions`` is int64; the other four
            are float32. Each field is stacked with one row per transition.
        """
        drawn = random.sample(self.memory, k=self.batch_size)
        batch = [e for e in drawn if e is not None]

        def column(field):
            # Stack one field across the batch, one row per transition.
            return np.vstack([getattr(e, field) for e in batch])

        def as_float(array):
            return torch.from_numpy(array).float().to(device)

        states = as_float(column("state"))
        actions = torch.from_numpy(column("action")).long().to(device)
        rewards = as_float(column("reward"))
        next_states = as_float(column("next_state"))
        # Done flags are cast to uint8 first, then turned into 0.0 / 1.0.
        dones = as_float(column("done").astype(np.uint8))

        return (states, actions, rewards, next_states, dones)

    def __len__(self):
        """Return how many transitions are currently stored."""
        return len(self.memory)

class MVPIAgent:
    """Q-learning agent with a target network and a transformed reward.

    Transitions are stored in a replay buffer. Every ``update_every`` calls to
    ``step`` (once the buffer holds more than ``warm_up`` transitions) a
    mini-batch is sampled and the online network is fitted to a one-step
    target computed from the target network. Before the target is formed,
    each sampled reward ``r`` is replaced by ``r - lam * r**2 + 2 * lam * r * y``
    where ``y`` is the mean of the most recent rewards seen by ``step``.
    """

    def __init__(self, state_size, action_size, hidden_size, batch_size, lr, gamma, lam, buffer_size, update_every, target_tau):
        """Build networks, optimizer and buffers.

        Args:
            state_size: number of features in a state.
            action_size: number of discrete actions.
            hidden_size: hidden-layer width of the Q-networks.
            batch_size: mini-batch size drawn from the replay buffer.
            lr: Adam learning rate for the online network.
            gamma: discount factor used in the bootstrap target.
            lam: coefficient of the reward transformation (see class docstring).
            buffer_size: capacity of the replay buffer.
            update_every: number of ``step`` calls between learning attempts.
            target_tau: interpolation factor for the target-network soft update.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.hidden_size = hidden_size
        self.update_every = update_every
        self.target_tau = target_tau

        self.gamma = gamma
        self.lam = lam

        # Learning starts only once the buffer holds more transitions than this.
        self.warm_up = 1000

        # Online network first, then target network (creation order is kept so
        # that seeded initialisation stays the same); the target starts as an
        # exact copy of the online network.
        self.qnet = QNetwork(state_size, action_size, hidden_size).to(device)
        self.qnet_target = QNetwork(state_size, action_size, hidden_size).to(device)
        self.hard_update(self.qnet, self.qnet_target)
        self.optimizer = optim.Adam(self.qnet.parameters(), lr=lr)

        self.memory = ReplayBuffer(action_size, buffer_size, batch_size)
        # Most recent raw rewards (at most 10000); their mean is ``y`` in ``learn``.
        self.online_reward = deque(maxlen=10000)
        # Counter modulo ``update_every``; learning is attempted when it wraps to 0.
        self.t_step = 0

    def save_best(self, save_path, risk):
        """Save the online network's weights as a "best" checkpoint.

        Args:
            save_path: prefix the file name is appended to.
            risk: if truthy write ``net_var_best.th``, else ``net_mean_best.th``.
        """
        filename = 'net_var_best.th' if risk else 'net_mean_best.th'
        torch.save(self.qnet.state_dict(), save_path + filename)

    def save_model(self, save_path, ep):
        """Save the online network's weights as ``net_ep_<ep>.th`` under ``save_path``."""
        torch.save(self.qnet.state_dict(), save_path + 'net_ep_' + str(ep) + '.th')

    def step(self, state, action, reward, next_state, done):
        """Record one transition and, on schedule, run one learning update."""
        self.memory.add(state, action, reward, next_state, done)
        self.online_reward.append(reward)

        self.t_step = (self.t_step + 1) % self.update_every
        if self.t_step != 0:
            return
        # Only learn once the buffer has grown past the warm-up threshold.
        if len(self.memory) > self.warm_up:
            experiences = self.memory.sample()
            self.learn(experiences, self.gamma, self.lam)

    def act(self, state, eps):
        """Choose an action epsilon-greedily.

        Args:
            state: NumPy array holding a single state.
            eps: probability-like threshold; a uniform draw above it selects
                the greedy action, otherwise a uniformly random action.

        Returns:
            The chosen action index as a NumPy integer.
        """
        # One uniform draw decides the branch. The network is evaluated only
        # on the greedy branch, so exploratory steps skip the tensor
        # conversion and forward pass. The Python RNG is consumed in the same
        # order as when the forward pass was unconditional.
        if random.random() > eps:
            state_t = torch.from_numpy(state).float().unsqueeze(0).to(device)
            with torch.no_grad():
                qvals = self.qnet(state_t)
            return np.argmax(qvals.cpu().numpy())
        return random.choice(np.arange(self.action_size))

    def learn(self, experiences, gamma, lam):
        """Run one gradient step on a sampled batch, then soft-update the target.

        Args:
            experiences: tuple ``(states, actions, rewards, next_states, dones)``
                of tensors as returned by ``ReplayBuffer.sample``.
            gamma: discount factor.
            lam: coefficient of the reward transformation.
        """
        states, actions, rewards, next_states, dones = experiences

        # Reward transformation: r - lam * r^2 + 2 * lam * r * y, with y the
        # mean of the recent raw rewards.
        y = np.mean(self.online_reward)
        rewards = rewards - lam * rewards.pow(2) + 2 * lam * rewards * y

        # Bootstrap from the target network; no gradient flows through it.
        with torch.no_grad():
            next_q_max = self.qnet_target(next_states).max(1)[0].unsqueeze(1)
        # (1 - dones) removes the bootstrap term for terminal transitions.
        q_target = rewards + gamma * next_q_max * (1 - dones)
        # Q-value the online network assigns to the action actually taken.
        q_expected = self.qnet(states).gather(1, actions)

        loss = F.mse_loss(q_expected, q_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.soft_update(self.qnet, self.qnet_target, self.target_tau)

    def hard_update(self, local_model, target_model):
        """Copy every parameter of ``local_model`` into ``target_model``."""
        with torch.no_grad():
            for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
                target_param.copy_(local_param)

    def soft_update(self, local_model, target_model, tau):
        """Blend parameters: ``target = tau * local + (1 - tau) * target``."""
        with torch.no_grad():
            for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
                target_param.copy_(tau * local_param + (1.0 - tau) * target_param)

########## functional ########
def eval_model(env, agent, n_episodes):
    """Roll out ``n_episodes`` evaluation episodes and summarise them.

    Actions come from ``agent.act`` with a fixed exploration rate of 0.01.
    An episode ends when the environment reports ``done`` or after 1000
    steps, whichever comes first.

    Args:
        env: environment exposing ``reset()`` and a 4-tuple ``step(action)``.
        agent: object exposing ``act(state, eps)``.
        n_episodes: number of episodes to run.

    Returns:
        Tuple ``(returns, land_left, var_R)``: an array with the summed reward
        of each episode, the number of episodes whose final step had reward
        exactly 100 and an x position of at most 10, and the variance of all
        per-step rewards collected over every episode.
    """
    # Constants used to convert the first state component into an x position
    # (per the original comment, taken from the LunarLander environment).
    VIEWPORT_W = 600
    SCALE = 30.0
    half_width = VIEWPORT_W / SCALE / 2

    max_steps = 1000
    all_rewards = []
    returns = []
    left_landings = 0

    for _ in range(n_episodes):
        obs = env.reset()
        episode_return = 0
        for t in range(1, max_steps + 1):
            obs, r, done, _info = env.step(agent.act(obs, 0.01))
            all_rewards.append(r)
            episode_return += r

            # Stop on termination or when the step cap is reached.
            if done or t == max_steps:
                x = obs[0] * half_width + half_width
                # Only the x coordinate and the final reward are checked.
                if x <= 10 and r == 100:
                    left_landings += 1
                break

        returns.append(episode_return)

    reward_variance = np.array(all_rewards).var()
    return np.array(returns), left_landings, reward_variance

def get_save_dir(lam, noise, seed, lr_q):
    """Return the output directory path (with trailing slash) for one run.

    The path nests the string forms of ``lam``, ``noise``, ``seed`` and
    ``lr_q`` under ``./save``.
    """
    template = './save/lam_{!s}/noise_{!s}/seed_{!s}/lr_q={!s}/'
    return template.format(lam, noise, seed, lr_q)

########## setting ##########
# Two separate environment instances, one for training and one for
# evaluation, both constructed with the same noise_scale argument.
# NOTE(review): the meaning of noise_scale is defined in lunar_lander_risk,
# which is not visible here; presumably it scales reward noise - confirm.
noise_scale = 90
env = LunarLander(noise_scale)
eval_env = LunarLander(noise_scale)
# Network output/input sizes are read from the environment's spaces.
action_size = env.action_space.n
state_size = env.observation_space.shape[0]

# Seed every random source from --seed: the training env, NumPy, Python's
# random module and PyTorch (CPU and, when available, CUDA).
seed = args.seed
env.seed(seed)
# The evaluation env gets a different seed derived from the same value.
eval_env.seed(2**31-1-seed)
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# Agent hyperparameters (passed to MVPIAgent below).
hidden_size = 128  # width of both hidden layers of the Q-network
batch_size = 64  # transitions drawn from the replay buffer per update
lr = 7e-4  # Adam learning rate for the online Q-network
gamma = 0.999  # discount factor in the Q-learning target
lam = args.lam  # coefficient of the reward transformation, from --lam
buffer_size = int(1e5)  # replay buffer capacity
update_every = 4  # a learning update is attempted every 4 agent.step calls
target_tau = 1e-3  # soft-update interpolation factor for the target network


# Training-loop settings.
train_episodes = 4000 * 30  # total number of training episodes
eps_start = 1.0  # initial exploration rate
eps_end = 0.02  # lower bound on the exploration rate
eps = eps_start  # current exploration rate; decayed after every episode
eps_decay = 0.95  # multiplicative decay applied to eps once per episode
update_intvl = 4  # NOTE(review): not referenced anywhere in the visible code
eval_intvl = 60  # run an evaluation every 60 training episodes
eval_episodes = 10  # number of episodes per evaluation


######### main ##########

# Build the agent. The keyword names match the parameters of
# MVPIAgent.__init__ one-to-one.
kwargs = dict(
    state_size=state_size,
    action_size=action_size,
    hidden_size=hidden_size,
    batch_size=batch_size,
    lr=lr,
    gamma=gamma,
    lam=lam,
    buffer_size=buffer_size,
    update_every=update_every,
    target_tau=target_tau,
)
agent = MVPIAgent(**kwargs)

# Output directory for checkpoints and evaluation logs; created if missing.
save_dir = get_save_dir(lam, noise_scale, seed, lr)
os.makedirs(save_dir, exist_ok=True)



# Best evaluation scores seen so far. Starting from -inf guarantees that the
# first evaluation always writes both "best" checkpoints, however negative
# its scores are (the mean-variance score is unbounded below for large lam).
best_eval_return = float('-inf')
best_eval_var = float('-inf')
# One entry per evaluation round; rewritten to disk after every round.
eval_rewards_list = []
eval_var_R_list = []
land_left_list = []

for n_epi in range(train_episodes):

    state = env.reset()
    # Each training episode is capped at 1000 environment steps.
    for step in range(1000):
        action = agent.act(state, eps)
        next_state, reward, done, _ = env.step(action)
        agent.step(state, action, reward, next_state, done)
        # Advance to the new observation. Without this the agent would keep
        # acting on the reset observation and every stored transition would
        # carry that same initial state.
        state = next_state

        if done:
            break

    # Decay the exploration rate once per episode, down to eps_end.
    eps = max(eps_end, eps_decay * eps)

    # Periodic evaluation on the separate evaluation environment.
    if (n_epi+1) % eval_intvl == 0:
        eval_r, land_left, var_R = eval_model(eval_env, agent, eval_episodes)
        eval_rewards_list.append(eval_r)
        land_left_list.append(land_left)
        eval_var_R_list.append(var_R)

        eval_r_mean = eval_r.mean()
        print('eval_r:', eval_r_mean)
        # Checkpoint the network with the best mean episode return.
        if eval_r_mean > best_eval_return:
            best_eval_return = eval_r_mean
            agent.save_best(save_dir, risk=False)

        # Checkpoint the network with the best mean-variance score: mean
        # episode return minus lam times the variance of per-step rewards.
        mean_var = eval_r_mean - lam * var_R
        if mean_var > best_eval_var:
            best_eval_var = mean_var
            agent.save_best(save_dir, risk=True)

        # Rewrite the full evaluation history after every round so partial
        # results survive an interrupted run.
        with open(save_dir + 'eval_r.npy', 'wb') as f:
            np.save(f, np.array(eval_rewards_list))
        with open(save_dir + 'eval_land_left.npy', 'wb') as f:
            np.save(f, np.array(land_left_list))
        with open(save_dir + 'eval_var_R.npy', 'wb') as f:
            np.save(f, np.array(eval_var_R_list))
