from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.normal import Normal
import numpy as np
import os
import gym
from gym.envs.registration import register
import sys
# NOTE(review): presumably lets gym import the `inverted_pendulum` module
# (the entry point registered below) from the parent directory — confirm.
sys.path.append('..')

import argparse
# All four hyperparameters are required on the command line.
#   --lam  : weight of the variance penalty in the policy objective
#   --lr   : policy learning rate (the J/V estimates use 100x this value)
#   --b    : variance threshold; only variance above b is penalised
#   --seed : RNG seed for numpy, torch and the environments
parser = argparse.ArgumentParser(description='Tamar mean-variance policy gradient (lam, lr, b, seed)')
parser.add_argument('--lam', type=float, help='lambda', required=True)
parser.add_argument('--lr', type=float, help='learning rate', required=True)
parser.add_argument('--b', type=float, help='b', required=True)
parser.add_argument('--seed', type=int, help='seed', required=True)
args = parser.parse_args()


######### register env #########
# Episode step cap: passed to gym as max_episode_steps and also enforced by
# the manual step counters in eval_model / play_episode.
Pendulum_LEN = 500
register(
    id="InvPos-v0",
    entry_point="inverted_pendulum:InvertedPendulumEnv",
    max_episode_steps=Pendulum_LEN,
    reward_threshold=None,
    nondeterministic=False,
)

# States are moved to `device` before every forward pass, but the agent's
# network is never moved off the CPU, and the agent keeps CPU-side tensors
# (the accumulated episode return) and calls `.numpy()` on sampled actions.
# Choosing "cuda" when available therefore crashed the first forward pass
# with a device-mismatch error on GPU machines; keep everything on the CPU.
device = torch.device("cpu")

class Policy_Network(nn.Module):
    """Gaussian policy head with a state-dependent mean and a learned global std.

    The mean is an MLP (two ReLU hidden layers of width ``hidden_dims``)
    squashed by ``tanh`` and scaled by ``max_action``. The standard deviation
    is ``exp(logstd)``, where ``logstd`` is a single learned row shared by all
    states.
    """

    def __init__(self, obs_space_dims: int, action_space_dims: int, max_action: float, hidden_dims: int):
        super().__init__()
        self.max_action = max_action

        # The layer order fixes the state_dict keys (mean_net.0/2/4.*) that
        # saved checkpoints rely on, so it must stay exactly like this.
        layers = [
            nn.Linear(obs_space_dims, hidden_dims), nn.ReLU(),
            nn.Linear(hidden_dims, hidden_dims), nn.ReLU(),
            nn.Linear(hidden_dims, action_space_dims), nn.Tanh(),
        ]
        self.mean_net = nn.Sequential(*layers)
        # Shape (1, action_space_dims); zero init means std starts at 1.
        self.logstd = nn.Parameter(torch.zeros(1, action_space_dims))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(means, stddevs)`` for the batch ``x`` (cast to float32 first)."""
        squashed = self.mean_net(x.float())
        means = squashed * self.max_action
        stddevs = self.logstd.exp()
        return means, stddevs

class TamarNet(nn.Module):
    """Policy network plus the two learnable scalars used by the Tamar agent.

    ``J`` and ``V`` (both initialised to 0) are trained by ``TamarAgent`` as
    estimates of the mean and variance of the episode return. They are exposed
    separately (``JV_params``) from the actor's weights (``pi_params``) so each
    group can have its own optimizer.
    """

    def __init__(self, state_dim, action_dim, max_action, net_width):
        super(TamarNet, self).__init__()
        self.actor = Policy_Network(state_dim, action_dim, max_action, net_width)

        self.J = nn.Parameter(torch.zeros(1))
        self.V = nn.Parameter(torch.zeros(1))

        self.JV_params = [self.J, self.V]
        self.pi_params = list(self.actor.parameters())
        self.eps = 1e-6

    def forward(self, state):
        """Sample an action for the first row of ``state``.

        Returns a dict with:
            'a': the sampled action as a NumPy array (first batch row only).
            'log_pi_a': its per-dimension log-probability, still attached to
                the actor's autograd graph so policy gradients flow through it.
        """
        action_means, action_stddevs = self.actor(state)
        # NOTE(review): eps shifts the *mean*, which adds no numerical
        # stability (the common pattern adds it to the std). Kept as-is so
        # sampling is unchanged — confirm intent.
        distrib = Normal(action_means[0] + self.eps, action_stddevs[0])
        action = distrib.sample()  # sample() does not track gradients
        log_prob = distrib.log_prob(action)

        # .cpu() is a no-op for CPU tensors but is required before .numpy()
        # if this module is ever moved to a GPU.
        return {'a': action.cpu().numpy(), 'log_pi_a': log_prob}

class TamarAgent(object):
    """Trajectory-based mean-variance policy-gradient agent (Tamar et al., 2012).

    One episode is collected at a time: ``pre_step`` samples an action and
    accumulates its log-probability, ``post_step`` accumulates the reward, and
    ``train`` performs one update from the episode return ``R`` and then clears
    both buffers.

    The policy ascends ``J - lam * g(Var - b)`` with ``g(x) = max(0, x)**2``,
    where ``network.J`` and ``network.V`` are learned estimates of the mean and
    variance of ``R``.
    """

    def __init__(
            self,
            state_dim,
            action_dim,
            max_action,
            net_width,
            lr,
            b,
            lam
    ):
        self.network = TamarNet(state_dim, action_dim, max_action, net_width)
        self.policy_optimizer = optim.RMSprop(self.network.pi_params, lr=lr)
        # J/V estimates use plain SGD with a 100x larger step than the policy.
        # (Attribute spelling kept unchanged for compatibility.)
        self.JV_optimizier = optim.SGD(self.network.JV_params, lr=lr * 100)

        self.b = b  # variance threshold: only variance above b is penalised
        self.lam = lam  # weight of the variance penalty
        self.reset()

    def save_best(self, save_path, risk=True):
        """Save weights as the best risk-aware (``risk``) or best-mean checkpoint.

        ``save_path`` is used as a string prefix, so it should end with '/'.
        """
        if risk:
            torch.save(self.network.state_dict(), save_path + 'net_var_best.th')
        else:
            torch.save(self.network.state_dict(), save_path + 'net_mean_best.th')

    def save_model(self, save_path, ep):
        """Save weights tagged with episode number ``ep`` (``save_path`` is a prefix)."""
        torch.save(self.network.state_dict(), save_path + 'net_ep_'+str(ep)+'.th')

    def reset(self):
        """Clear the per-episode return and summed log-probability."""
        self.total_rewards = torch.tensor(0, dtype=torch.float)
        self.log_pi_a = 0

    def pre_step(self, state):
        """Sample an action, record its log-probability, and return it as a float.

        ``.item()`` requires the sampled action to have exactly one element,
        i.e. a one-dimensional action space.
        """
        prediction = self.network(state)
        self.log_pi_a = self.log_pi_a + prediction['log_pi_a']
        return prediction['a'].item()

    def post_step(self, reward):
        """Add ``reward`` to the running (undiscounted) episode return."""
        self.total_rewards += reward

    def evaluate(self, state):
        """Sample an action without gradient tracking or log-prob bookkeeping."""
        with torch.no_grad():
            action_means, action_stddevs = self.network.actor(state)
            distrib = Normal(action_means[0], action_stddevs[0])
            # .cpu() is needed before .numpy() if the actor lives on a GPU.
            action = distrib.sample().cpu().numpy()
        return action

    def train(self):
        """Do one update of J, V and the policy from the finished episode, then reset."""
        R = self.total_rewards
        J, V = self.network.J, self.network.V

        # Tracking losses: J -> E[R], V -> E[R^2] - J^2.
        J_loss = (R - J).pow(2).mul(0.5)
        V_loss = (R.pow(2) - J.detach().pow(2) - V).pow(2).mul(0.5)

        # g'(V - b) for the penalty g(x) = max(0, x)^2.
        if V.item() < self.b:
            penalty_grad = 0
        else:
            penalty_grad = 2 * (V.detach() - self.b)

        # Likelihood-ratio estimate of grad(J - lam * g(Var - b)).
        # grad Var = grad E[R^2] - 2 J grad E[R], whose sample estimate is
        # (R^2 - 2 J R) * grad log pi. The original code used (R^2 - 2 J),
        # dropping the factor R.
        J_hat = J.detach()
        policy_loss = - (R - self.lam * penalty_grad * (R.pow(2) - 2. * J_hat * R)) * self.log_pi_a

        self.JV_optimizier.zero_grad()
        self.policy_optimizer.zero_grad()

        (J_loss + V_loss + policy_loss).backward()

        self.JV_optimizier.step()
        self.policy_optimizer.step()

        self.reset()

########## functional ###########
def eval_model(env, agent, n_episodes=20):
    """Evaluate ``agent`` on ``env`` for ``n_episodes`` episodes.

    Actions come from ``agent.evaluate`` (gradient-free sampling). An episode
    ends when the env reports ``done`` or after ``Pendulum_LEN`` steps.

    Returns:
        Three NumPy arrays of length ``n_episodes``: the summed reward, the
        number of steps taken, and the number of steps at which
        ``info['x_position']`` exceeded 0.01.
    """
    step_cap = Pendulum_LEN
    returns, lengths, positive_steps = [], [], []

    for _ in range(n_episodes):
        obs = env.reset()
        ep_return = 0
        n_steps = 0
        n_positive = 0
        done = False
        while not done:
            batch = torch.from_numpy(np.array([obs])).float().to(device)
            action = agent.evaluate(batch)
            obs, reward, done, info = env.step(action)

            if info['x_position'] > 0.01:
                n_positive += 1
            ep_return += reward
            n_steps += 1

            # Hard step cap on top of whatever termination the env reports.
            if n_steps == step_cap:
                done = True

        returns.append(ep_return)
        lengths.append(n_steps)
        positive_steps.append(n_positive)

    return np.array(returns), np.array(lengths), np.array(positive_steps)

def play_episode(env, agent):
    """Roll out one training episode, routing every step through the agent.

    Each observation goes through ``agent.pre_step`` (which samples the action
    and accumulates its log-probability), and each reward goes through
    ``agent.post_step``. The episode ends on ``done`` or after
    ``Pendulum_LEN`` steps. Nothing is returned; the agent holds the episode
    data.
    """
    step_cap = Pendulum_LEN
    obs = env.reset()
    n_steps = 0
    finished = False
    while not finished:
        batch = torch.from_numpy(np.array([obs])).float().to(device)
        action = agent.pre_step(batch)
        obs, reward, finished, _ = env.step(action)
        n_steps += 1
        agent.post_step(reward)

        # Hard step cap on top of whatever termination the env reports.
        if n_steps == step_cap:
            finished = True

def get_save_dir(lam, b, seed, lr_policy):
    """Return the run's output directory (with a trailing '/') built from its hyperparameters."""
    segments = (f'lam_{lam}', f'b_{b}', f'seed_{seed}', f'lr_p={lr_policy}')
    return './save/' + '/'.join(segments) + '/'

########## setting ###########
# Training env plus a second, separately seeded env used only for evaluation.
env_id = "InvPos-v0"
env = gym.make(env_id)
eval_env = gym.make(env_id)
obs_space_dims = env.observation_space.shape[0]
action_space_dims = env.action_space.shape[0]
# Only the first upper bound is read; the policy scales its tanh output by it,
# so a symmetric action range is assumed — confirm for this env.
max_action = float(env.action_space.high[0])

# set seed
# The eval env gets a different seed so evaluation episodes differ from the
# training ones. Seeding happens before the agent is built, so network
# initialisation is reproducible too.
seed = args.seed
env.seed(seed)
eval_env.seed(2**31-1-seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# hyperparameters
net_width = 128
lr = args.lr
b = args.b
lam = args.lam
train_episodes = 4000 * 30  # 120,000 training episodes
test_intvl = 600  # run an evaluation round every 600 training episodes
save_intvl = 50  # NOTE(review): not read anywhere in the visible code


print('lam:', lam, 'b:', b, 'seed:', seed)
print('lr_p:', lr)

############# main ##############

# create agent
kwargs = {
    "state_dim": obs_space_dims,
    "action_dim": action_space_dims,
    "max_action": max_action,
    "net_width": net_width,
    "lr": lr,
    "b": b,
    "lam": lam,
}
agent = TamarAgent(**kwargs)

# create save dir
save_dir = get_save_dir(lam, b, seed, lr)
os.makedirs(save_dir, exist_ok=True)

# NOTE(review): neither value is read or updated below, and agent.save_best /
# agent.save_model are never called here, so no model checkpoints are written.
best_eval_return = -10000
best_eval_variance = -10000

# One entry per evaluation round; each entry holds one value per evaluation
# episode (eval_model's default n_episodes is 20).
eval_r_lst, eval_len_lst, eval_x_positive_lst = [],[],[]

# One agent update per episode, using the whole-episode return.
for ep in range(train_episodes):
    play_episode(env, agent)
    agent.train()

    if (ep+1) % test_intvl == 0:
        eval_r, eval_len, eval_x_positive = eval_model(eval_env, agent)
        eval_r_lst.append(eval_r)
        eval_len_lst.append(eval_len)
        eval_x_positive_lst.append(eval_x_positive)
        # Rewrite each .npy with the full history after every round, so the
        # results so far are on disk even if the run is interrupted.
        with open(save_dir + 'eval_r.npy', 'wb') as f:
            np.save(f, np.array(eval_r_lst))
        with open(save_dir + 'eval_len.npy', 'wb') as f:
            np.save(f, np.array(eval_len_lst))
        with open(save_dir + 'eval_x_positive.npy', 'wb') as f:
            np.save(f, np.array(eval_x_positive_lst))
    
