from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.normal import Normal
import numpy as np
import os
import gym
from gym.envs.registration import register
import sys
sys.path.append('..')

import argparse
# Command-line interface: --lam, --lr and --seed are all mandatory.
parser = argparse.ArgumentParser(description='seed lambda gamma')
for _flag, _caster, _blurb in (
    ('--lam', float, 'lambda'),
    ('--lr', float, 'learning rate'),
    ('--seed', int, 'seed'),
):
    parser.add_argument(_flag, type=_caster, help=_blurb, required=True)
args = parser.parse_args()

######### register env #########
# Step cap per episode; handed to gym as max_episode_steps and also read by
# the rollout helpers further down in this file.
Pendulum_LEN = 500
register(
    id="InvPos-v0",
    entry_point="inverted_pendulum:InvertedPendulumEnv",
    max_episode_steps=Pendulum_LEN,
    reward_threshold=None,
    nondeterministic=False,
)

# Prefer the GPU whenever torch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class Policy_Network(nn.Module):
    """Gaussian policy parameterisation.

    The mean comes from a two-hidden-layer ReLU MLP whose tanh output is
    scaled by ``max_action``; the standard deviation is ``exp(logstd)`` where
    ``logstd`` is a learnable ``(1, action_space_dims)`` parameter that does
    not depend on the input and starts at zero.
    """

    def __init__(self, obs_space_dims: int, action_space_dims: int, max_action: float, hidden_dims: int):
        """Build the mean MLP and the log-std parameter.

        Args:
            obs_space_dims: Size of the input feature vector.
            action_space_dims: Number of action components produced.
            max_action: Factor applied to the tanh output of the mean MLP.
            hidden_dims: Width of both hidden layers.
        """
        super().__init__()
        self.max_action = max_action

        # Hidden trunk: two Linear+ReLU stages, both emitting hidden_dims.
        trunk = []
        for fan_in in (obs_space_dims, hidden_dims):
            trunk += [nn.Linear(fan_in, hidden_dims), nn.ReLU()]
        # Output head: tanh keeps the raw mean inside [-1, 1].
        head = [nn.Linear(hidden_dims, action_space_dims), nn.Tanh()]
        self.mean_net = nn.Sequential(*trunk, *head)

        self.logstd = nn.Parameter(torch.zeros((1, action_space_dims)))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(action_means, action_stddevs)`` for input ``x``.

        ``x`` is cast to float32 before entering the MLP. The means keep the
        leading shape of ``x``; the standard deviations always have shape
        ``(1, action_space_dims)``.
        """
        squashed = self.mean_net(x.float())
        return squashed * self.max_action, self.logstd.exp()

class MVPNet(nn.Module):
    """Stochastic policy (``actor``) bundled with a learnable scalar ``y``.

    ``actor`` is a ``Policy_Network``; ``y`` is a one-element parameter that
    starts at zero and is exposed so an outer training loop can optimise it
    separately from the actor.
    """

    def __init__(self, state_dim, action_dim, max_action, net_width):
        """Create the actor and the auxiliary parameter.

        Args:
            state_dim: Observation size forwarded to ``Policy_Network``.
            action_dim: Action size forwarded to ``Policy_Network``.
            max_action: Mean scaling forwarded to ``Policy_Network``.
            net_width: Hidden width forwarded to ``Policy_Network``.
        """
        super(MVPNet, self).__init__()
        self.actor = Policy_Network(state_dim, action_dim, max_action, net_width)
        self.y = nn.Parameter(torch.zeros(1))
        # Small constant added to the mean before sampling.
        self.eps = 1e-6

    def forward(self, state):
        """Sample one action for a single state.

        Args:
            state: Either an unbatched ``(state_dim,)`` tensor or a batch of
                one with shape ``(1, state_dim)``. Only the first row of a
                batch is used.

        Returns:
            dict with ``'a'`` (the sampled action, one entry per action
            dimension) and ``'log_pi_a'`` (the per-dimension log-probability
            of that sample under the policy).
        """
        # Promote an unbatched state to a batch of one. Without this the
        # ``[0]`` below would index the *action* axis of the actor output and
        # every action dimension would be sampled around the first mean.
        if state.dim() == 1:
            state = state.unsqueeze(0)
        action_means, action_stddevs = self.actor(state)
        dist = Normal(action_means[0] + self.eps, action_stddevs[0])
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return {'a': action, 'log_pi_a': log_prob}


class MVPAgent(object):
    """Episodic policy-gradient agent with an auxiliary scalar ``y``.

    One episode is collected through ``pre_step`` / ``post_step``; ``train``
    then takes one RMSprop step on ``y`` followed by one on the policy and
    clears the per-episode accumulators.

    With ``R`` the undiscounted sum of rewards of the episode, the quantities
    being maximised are::

        y:      (2 * R + 1 / lam) * y - y ** 2
        policy: (2 * y * R - R ** 2) * sum_t log pi(a_t | s_t)
    """

    def __init__(
            self,
            state_dim,
            action_dim,
            max_action,
            net_width,
            lr,
            lam
    ):
        """Build the network and one optimizer per parameter group.

        Args:
            state_dim: Observation size, forwarded to ``MVPNet``.
            action_dim: Action size, forwarded to ``MVPNet``.
            max_action: Mean scaling, forwarded to ``MVPNet``.
            net_width: Hidden width, forwarded to ``MVPNet``.
            lr: Learning rate used by both RMSprop optimizers.
            lam: Coefficient whose reciprocal enters the ``y`` objective;
                must be non-zero or ``compute_y_grad`` divides by zero.
        """
        self.network = MVPNet(state_dim, action_dim, max_action, net_width)
        self.y_optimizer = optim.RMSprop([self.network.y], lr=lr)
        self.policy_optimizer = optim.RMSprop(self.network.actor.parameters(), lr=lr)

        self.lam = lam

        self.reset()

    def _network_device(self):
        """Return the device on which the network parameters live."""
        return next(self.network.parameters()).device

    def save_best(self, save_path, risk=True):
        """Save the network weights as a "best" checkpoint.

        ``save_path`` is used as a raw string prefix, so it has to end with a
        path separator. ``risk`` selects the file name: ``net_var_best.th``
        when true, ``net_mean_best.th`` otherwise.
        """
        if risk:
            torch.save(self.network.state_dict(), save_path + 'net_var_best.th')
        else:
            torch.save(self.network.state_dict(), save_path + 'net_mean_best.th')

    def save_model(self, save_path, ep):
        """Save the network weights to ``<save_path>net_ep_<ep>.th``."""
        torch.save(self.network.state_dict(), save_path + 'net_ep_'+str(ep)+'.th')

    def reset(self):
        """Clear the per-episode accumulators."""
        self.total_rewards = torch.tensor(0, dtype=torch.float)
        self.log_pi_a = 0
        # NOTE: reset here but never accumulated or read by this class.
        self.entropy = 0

    def pre_step(self, state):
        """Sample an action for ``state`` and accumulate its log-probability.

        The state is moved to the network's device first, so callers may hand
        in a tensor living on a different device than the parameters.

        Returns:
            The sampled action as a Python float (the sample must therefore
            contain exactly one element).
        """
        prediction = self.network(state.to(self._network_device()))
        self.log_pi_a = self.log_pi_a + prediction['log_pi_a']
        return prediction['a'].item()

    def post_step(self, reward):
        """Add ``reward`` to the running episode return."""
        self.total_rewards += reward

    def evaluate(self, state):
        """Sample an action without tracking gradients.

        Args:
            state: Batch-of-one tensor; only row 0 of the actor output is
                used. It is moved to the network's device before the forward
                pass.

        Returns:
            The sampled action as a NumPy array.
        """
        with torch.no_grad():
            action_means, action_stddevs = self.network.actor(
                state.to(self._network_device())
            )
            distrib = Normal(action_means[0], action_stddevs[0])
            action = distrib.sample()
        # .cpu() is a no-op on CPU tensors and is required before .numpy()
        # whenever the network lives on an accelerator.
        return action.cpu().numpy()

    def compute_y_grad(self):
        """Populate ``y.grad`` for the current episode return.

        The objective is negated because the optimizer minimises.
        """
        y = self.network.y
        y_loss = (2* self.total_rewards + 1.0 / self.lam) * y - y.pow(2)
        y_loss = - y_loss
        self.y_optimizer.zero_grad()
        y_loss.backward()

    def compute_pi_grad(self):
        """Populate the actor gradients for the current episode.

        ``y`` is detached so this pass only touches the actor parameters.
        """
        R = self.total_rewards
        y = self.network.y
        policy_loss = -(2 * y.detach() * R - R.pow(2)) * self.log_pi_a
        # Reduce to a scalar so backward() is well defined whatever the shape
        # of the accumulated log-probability.
        policy_loss = policy_loss.sum()
        self.policy_optimizer.zero_grad()
        policy_loss.backward()

    def train(self):
        """Run one update at the end of an episode, then reset.

        ``y`` is stepped first, so the policy gradient is computed with the
        already-updated ``y``.
        """
        self.compute_y_grad()
        self.y_optimizer.step()

        self.compute_pi_grad()
        self.policy_optimizer.step()

        self.reset()

########## functional ###########
def eval_model(env, agent, n_episodes=20):
    """Run ``n_episodes`` evaluation rollouts and collect per-episode stats.

    Actions come from ``agent.evaluate``; each rollout stops when the
    environment reports ``done`` or after ``Pendulum_LEN`` steps, whichever
    happens first.

    Returns:
        Three NumPy arrays with one entry per episode: the summed reward, the
        number of steps taken, and the number of steps whose
        ``info['x_position']`` exceeded 0.01.
    """
    step_cap = Pendulum_LEN

    returns, lengths, positive_counts = [], [], []

    for _episode in range(n_episodes):
        obs = env.reset()
        episode_return = 0
        steps_taken = 0
        positive_steps = 0

        finished = False
        while not finished:
            # The agent expects a batch of one observation.
            batch = torch.from_numpy(np.array([obs])).float().to(device)
            action = agent.evaluate(batch)
            obs, reward, env_done, info = env.step(action)

            if info['x_position'] > 0.01:
                positive_steps += 1
            episode_return += reward
            steps_taken += 1

            # Stop on environment termination or on reaching the step cap.
            finished = env_done or steps_taken == step_cap

        returns.append(episode_return)
        lengths.append(steps_taken)
        positive_counts.append(positive_steps)

    return np.array(returns), np.array(lengths), np.array(positive_counts)


def play_episode(env, agent):
    """Roll out one training episode, reporting every step to the agent.

    For each step the agent picks an action through ``agent.pre_step`` and is
    then told the reward through ``agent.post_step``. The rollout stops when
    the environment reports ``done`` or after ``Pendulum_LEN`` steps.

    Args:
        env: Environment exposing ``reset()`` and the 4-tuple ``step()`` API.
        agent: Object exposing ``pre_step(state_tensor)`` and
            ``post_step(reward)``.
    """
    max_episode_length = Pendulum_LEN
    s = env.reset()
    episode_length = 0
    while True:
        # Feed a batch of one observation, the same layout eval_model uses.
        # The batched array used to be built and then ignored in favour of the
        # raw 1-D observation.
        state = np.array([s])
        a = agent.pre_step(torch.from_numpy(state).float().to(device))
        next_s, r, done, _ = env.step(a)
        episode_length += 1
        agent.post_step(r)

        if done or episode_length == max_episode_length:
            break

        s = next_s

def get_save_dir(lam, lr_policy, seed):
    """Return the output directory for one (lam, lr, seed) run.

    The result has the form ``./save/lam_<lam>/lr_p=<lr_policy>/seed_<seed>/``
    and always ends with a slash.
    """
    segments = (
        '.',
        'save',
        f'lam_{lam!s}',
        f'lr_p={lr_policy!s}',
        f'seed_{seed!s}',
        '',  # empty last segment yields the trailing '/'
    )
    return '/'.join(segments)

########## setting ###########
env_id = "InvPos-v0"
# Separate environment instances for training and for evaluation.
env, eval_env = gym.make(env_id), gym.make(env_id)
obs_space_dims = env.observation_space.shape[0]
action_space_dims = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

# set seed (done before the agent is built, so weight init is seeded too)
seed = args.seed
env.seed(seed)
eval_env.seed((2 ** 31 - 1) - seed)  # evaluation uses a different seed
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# hyperparameters
lr, lam = args.lr, args.lam
net_width = 128
train_episodes = 4000 * 30
test_intvl = 600  # evaluate every this many training episodes
save_intvl = 50   # not read anywhere in the visible part of this script

print('lam:', lam, 'seed:', seed)
print('lr_p:', lr)

# create agent
kwargs = dict(
    state_dim=obs_space_dims,
    action_dim=action_space_dims,
    max_action=max_action,
    net_width=net_width,
    lr=lr,
    lam=lam,
)
agent = MVPAgent(**kwargs)

# create save dir
save_dir = get_save_dir(lam, lr, seed)
os.makedirs(save_dir, exist_ok=True)

# Placeholders; not updated anywhere in the visible part of this script.
best_eval_return = -10000
best_eval_variance = -10000


eval_r_lst, eval_len_lst, eval_x_positive_lst = [], [], []

for ep in range(train_episodes):
    # One episode of experience followed by one update.
    play_episode(env, agent)
    agent.train()

    # Only evaluate on every test_intvl-th episode.
    if (ep + 1) % test_intvl:
        continue

    eval_r, eval_len, eval_x_positive = eval_model(eval_env, agent)
    eval_r_lst.append(eval_r)
    eval_len_lst.append(eval_len)
    eval_x_positive_lst.append(eval_x_positive)

    eval_r_mean = eval_r.mean()
    print('eval return', eval_r_mean)

    # Rewrite the full evaluation history each time.
    for _fname, _history in (
        ('eval_r.npy', eval_r_lst),
        ('eval_len.npy', eval_len_lst),
        ('eval_x_positive.npy', eval_x_positive_lst),
    ):
        with open(save_dir + _fname, 'wb') as f:
            np.save(f, np.array(_history))


    