from __future__ import annotations
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.normal import Normal
import gym
from gym.envs.registration import register
import sys
# Make modules located in the parent directory importable.
# NOTE(review): a relative entry like '..' is resolved against the current working
# directory at import time, not against this file's location, so this only finds
# the parent of the script's folder when the script is launched from its own
# directory — confirm that is the intended usage.
sys.path.append('..')

import argparse
# Command-line hyperparameters. The discount factor is hard-coded inside REINFORCE,
# so it is intentionally not exposed here.
parser = argparse.ArgumentParser(
    description='REINFORCE with a Gini-deviation penalty (lambda, learning rate, seed)')
parser.add_argument('--lam', type=float, required=True,
                    help='lambda: weight of the Gini-deviation penalty')
parser.add_argument('--lr', type=float, required=True,
                    help='policy learning rate (the critic uses 10x this value)')
parser.add_argument('--seed', type=int, required=True, help='random seed')
args = parser.parse_args()

########### register env ############
# Per-episode step limit: passed to gym as max_episode_steps and also read by
# eval_model / play_episode as an explicit step cap.
Pendulum_LEN = 500
register(
    id="InvPos-v0",
    entry_point="inverted_pendulum:InvertedPendulumEnv",  # "<module>:<class>" of the env
    nondeterministic=False,
    reward_threshold=None,
    max_episode_steps=Pendulum_LEN,
)
#####################################
class Policy_Network(nn.Module):
    """Diagonal-Gaussian policy.

    A two-hidden-layer ReLU MLP produces the action mean, squashed by ``tanh`` and
    multiplied by ``max_action``. The standard deviation is a learned,
    state-independent parameter kept in log space.
    """

    def __init__(self, obs_space_dims: int, action_space_dims: int, max_action: float, hidden_dims: int):
        super().__init__()
        self.max_action = max_action

        # Linear layers are created input -> hidden -> hidden -> output and land at
        # Sequential indices 0, 2, 4 (activations at 1, 3, 5), giving state_dict
        # keys mean_net.{0,2,4}.{weight,bias}.
        layers = []
        fan_in = obs_space_dims
        for _ in range(2):
            layers.extend((nn.Linear(fan_in, hidden_dims), nn.ReLU()))
            fan_in = hidden_dims
        layers.extend((nn.Linear(hidden_dims, action_space_dims), nn.Tanh()))
        self.mean_net = nn.Sequential(*layers)

        # One log-std per action dimension, shape (1, action_space_dims);
        # initialised to zero, i.e. std = 1.
        self.logstd = nn.Parameter(torch.zeros(1, action_space_dims))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mean, std)``.

        ``mean`` has one row per input row; ``std`` has shape
        (1, action_space_dims) and does not depend on ``x``.
        """
        mean = self.max_action * self.mean_net(x.float())
        std = self.logstd.exp()
        return mean, std

class Value_Network(nn.Module):
    """State-value critic V(s): a two-hidden-layer ReLU MLP with one output unit."""

    def __init__(self, obs_space_dims: int, hidden_dims: int):
        super().__init__()
        widths = [obs_space_dims, hidden_dims, hidden_dims]
        modules = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(d_in, d_out), nn.ReLU()]
        modules.append(nn.Linear(hidden_dims, 1))  # linear output head, no activation
        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return V(x) as a float32 tensor with a trailing dimension of size 1."""
        return self.net(x.float())

##########################################
class REINFORCE:
    """REINFORCE with a learned value baseline and a Gini-deviation penalty.

    Episodes are buffered with :meth:`put_data`. :meth:`update` then takes one
    critic step per buffered episode and one policy step on
    ``mean_episodes(baseline REINFORCE loss) + lam * mean(gini_loss)``.
    """

    def __init__(self, obs_space_dims: int, action_space_dims: int, max_action:float, hidden_dims: int, lr: float, lam: float):
        """Build the policy/critic networks and their AdamW optimizers.

        Args:
            obs_space_dims: observation size fed to both networks.
            action_space_dims: number of action dimensions.
            max_action: scale applied to the policy's tanh output.
            hidden_dims: hidden width of both MLPs.
            lr: policy learning rate; the critic uses ``10 * lr``.
            lam: weight of the Gini-deviation term in the policy loss.
        """
        # Hyperparameters
        self.learning_rate = lr  # learning rate for policy optimization
        self.lam = lam           # weight of the Gini-deviation penalty
        self.gamma = 0.999       # discount factor
        self.eps = 1e-6          # small number for numerical stability

        # Per-episode buffers; each put_data() call appends one entry to each list.
        self.ep_probs = []       # per-step log-probabilities of the sampled actions
        self.ep_rewards = []     # per-step rewards
        self.ep_states = []      # per-step observations

        # Construction order is kept fixed so seeded initialisation is reproducible.
        self.net = Policy_Network(obs_space_dims, action_space_dims, max_action, hidden_dims)
        self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=self.learning_rate)
        self.vf_net = Value_Network(obs_space_dims, hidden_dims)
        self.vf_optimizer = torch.optim.AdamW(self.vf_net.parameters(), lr=self.learning_rate*10)

    def sample_action(self, state: np.ndarray) -> tuple[np.ndarray, torch.Tensor]:
        """Sample an action for a single observation.

        Returns:
            ``(action, log_prob)``. ``action`` is a numpy array with one entry
            per action dimension. ``log_prob`` is the per-dimension Normal
            log-density of that action; it keeps the autograd graph back to the
            policy parameters.
        """
        state_t = torch.tensor(np.array([state]))  # add a batch dimension of 1
        action_means, action_stddevs = self.net(state_t)

        # eps on the scale keeps it strictly positive even if logstd drifts very negative.
        distrib = Normal(action_means[0] + self.eps, action_stddevs[0] + self.eps)
        action = distrib.sample()
        log_prob = distrib.log_prob(action)

        return action.numpy(), log_prob

    def put_data(self, probs, rewards, states):
        """Buffer one episode: per-step log-probs, rewards and observations."""
        self.ep_probs.append(probs)
        self.ep_rewards.append(rewards)
        self.ep_states.append(states)

    def save_best_actor(self, save_path, risk=False):
        """Save the policy weights.

        Writes ``save_path + 'best_actor_risk.pth'`` when ``risk`` is true (the
        risk-aware checkpoint), else ``save_path + 'best_actor.pth'``.
        ``save_path`` is concatenated as-is, so it must end with a separator.
        """
        filename = 'best_actor_risk.pth' if risk else 'best_actor.pth'
        torch.save(self.net.state_dict(), save_path + filename)

    def compute_gini_loss(self, ret, sum_log_prob):
        """Per-episode Gini-deviation surrogate loss terms.

        Episodes are sorted by return, ascending: r_(0) <= ... <= r_(n-1).
        ``sum_log_prob`` is permuted the same way. For i = 0 .. n-2 the code computes::

            coef_i = 2 * sum_{j=i}^{n-2} (j+1)/n * (r_(j+1) - r_(j)) + r_(i) - r_(n-1)
            loss_i = -logp_(i) * coef_i

        Args:
            ret: 1-D tensor of per-episode returns, shape [n].
            sum_log_prob: 1-D tensor of per-episode summed log-probs, shape [n].

        Returns:
            Tensor of shape [n - 1]; it is empty when n == 1. The highest-return
            episode contributes no term.
        """
        sort_ret, indices = torch.sort(ret, descending=False)
        sort_sum_log_prob = sum_log_prob[indices]
        sample_size = sort_ret.shape[0]

        # Gaps between consecutive sorted returns, weighted by (j+1)/n.
        diff = sort_ret[1:] - sort_ret[:-1]
        x = torch.linspace(1., sample_size-1, sample_size-1)
        x /= sample_size
        diff = diff * x
        # Suffix sums: cumsum_diff[i] = sum_{j >= i} diff[j].
        cumsum_diff = diff + torch.sum(diff) - torch.cumsum(diff, dim=-1)
        coef = 2. * cumsum_diff + sort_ret[:-1] - sort_ret[-1]
        gini_loss = -1. * sort_sum_log_prob[:-1] * coef

        return gini_loss

    def update(self):
        """Train on all buffered episodes, then clear the buffers.

        For each episode this:

        * computes returns-to-go G_t discounted by ``self.gamma``;
        * builds the baseline REINFORCE loss ``sum_t -log pi_t * (G_t - V(s_t))``,
          with the advantage detached from the critic;
        * regresses the critic onto G_t with MSE, taking one optimizer step.

        The policy then takes one step on the mean per-episode loss plus
        ``self.lam * mean(compute_gini_loss(G_0, sum_t log pi_t))``.
        """
        value_loss_fn = nn.MSELoss()  # stateless, so one instance serves every episode

        ret_lst = []             # discounted return G_0 of each episode
        reinforce_loss_lst = []  # per-episode baseline REINFORCE loss
        sum_log_prob_lst = []    # per-episode summed log-probs (keeps the policy graph)

        for probs, rewards, states in zip(self.ep_probs, self.ep_rewards, self.ep_states):
            # Returns-to-go: accumulate back-to-front, then restore time order.
            # append + reverse avoids the quadratic list.insert(0, ...).
            running_g = 0
            gs = []
            for R in rewards[::-1]:
                running_g = R + self.gamma * running_g
                gs.append(running_g)
            gs.reverse()
            ret_lst.append(running_g)

            deltas = torch.tensor(gs, dtype=torch.float)                 # [ep_len] returns-to-go
            state_t = torch.tensor(np.array(states), dtype=torch.float)  # [ep_len, obs_dim]
            log_prob_t = torch.stack(probs)                              # [ep_len]

            # The advantage is computed without grad, so the actor loss does not backprop into V.
            vf_t = self.vf_net(state_t).squeeze(1)
            with torch.no_grad():
                advantage_t = deltas - vf_t                              # [ep_len]

            reinforce_loss_lst.append(torch.sum(-log_prob_t * advantage_t))
            sum_log_prob_lst.append(log_prob_t.sum())

            # Critic: regress V(s_t) onto the Monte-Carlo returns.
            vf_loss = value_loss_fn(vf_t, deltas)
            self.vf_optimizer.zero_grad()
            vf_loss.backward()
            self.vf_optimizer.step()

        ret_t = torch.tensor(ret_lst, dtype=torch.float)           # [n_ep]
        sum_log_prob_t = torch.stack(sum_log_prob_lst)             # [n_ep]
        gini_loss = self.compute_gini_loss(ret_t, sum_log_prob_t)  # [n_ep - 1]

        policy_loss = torch.stack(reinforce_loss_lst).mean()
        # With a single episode the Gini term is empty. Its mean() would be NaN and
        # would corrupt every policy parameter after the optimizer step, so skip it.
        if gini_loss.numel() > 0:
            policy_loss = policy_loss + self.lam * gini_loss.mean()
        self.optimizer.zero_grad()
        policy_loss.backward()
        self.optimizer.step()

        # Empty the on-policy buffers.
        self.ep_probs = []
        self.ep_rewards = []
        self.ep_states = []

##################### functional #######################
def eval_model(env, agent, n_episodes=20):
    """Roll out ``n_episodes`` evaluation episodes of ``agent`` in ``env``.

    Actions are sampled from the stochastic policy under ``torch.no_grad()``.
    An episode stops when the env reports ``done`` or after ``Pendulum_LEN`` steps.

    Returns:
        ``(returns, lengths, positive_x_steps, gd)``:

        * ``returns``: numpy array of undiscounted episode returns.
        * ``lengths``: numpy array of step counts.
        * ``positive_x_steps``: numpy array counting, per episode, the steps with
          ``info['x_position'] > 0.01``.
        * ``gd``: Gini deviation of the returns,
          ``sum_{i,j} |R_i - R_j| / (2 * n_episodes**2)``.
    """
    horizon = Pendulum_LEN

    returns, lengths, positive_counts = [], [], []
    for _ in range(n_episodes):
        obs = env.reset()
        ep_return, steps, n_positive = 0, 0, 0
        while True:
            with torch.no_grad():
                action, _ = agent.sample_action(obs)
            next_obs, reward, done, info = env.step(action)

            if info['x_position'] > 0.01:
                n_positive += 1
            ep_return += reward
            steps += 1

            if done or steps == horizon:
                break
            obs = next_obs

        returns.append(ep_return)
        lengths.append(steps)
        positive_counts.append(n_positive)

    # Sum of |R_i - R_j| over all ordered pairs (including i == j), then normalised.
    gd = sum(abs(r_i - r_j) for r_i in returns for r_j in returns)
    gd /= (2 * n_episodes * n_episodes)

    return np.array(returns), np.array(lengths), np.array(positive_counts), gd

def play_episode(env, agent):
    """Collect one training episode with ``agent`` in ``env`` and buffer it via ``agent.put_data``.

    The episode ends when the env reports ``done`` or after ``Pendulum_LEN`` steps.
    For each step, the observation the action was chosen from is stored along
    with the reward and the joint log-probability of the sampled action. The
    log-probability keeps its autograd graph for the later policy update.
    """
    max_episode_length = Pendulum_LEN

    probs = []
    rewards = []
    states = []

    obs, done = env.reset(), False
    n_step = 0
    while not done:
        states.append(obs)
        action, log_prob = agent.sample_action(obs)
        obs, reward, done, info = env.step(action)
        # sample_action returns one Normal log-density per action dimension.
        # Their sum is the joint log-prob of the action vector. For a 1-D action
        # space (e.g. the inverted pendulum) this equals log_prob[0].
        probs.append(log_prob.sum())
        rewards.append(reward)
        n_step += 1

        if n_step == max_episode_length:
            done = True

    agent.put_data(probs, rewards, states)

def get_save_dir(env_id, lam, lr_policy, seed):
    """Return the per-run output directory, with a trailing slash.

    Layout: ``./save/<env_id>/lam_<lam>/lr_p=<lr_policy>/seed=<seed>/``
    """
    parts = ('./save', env_id, 'lam_' + str(lam), 'lr_p=' + str(lr_policy), 'seed=' + str(seed))
    return '/'.join(parts) + '/'
######################################################################

# Create the training and evaluation environments
env_id = "InvPos-v0"
# Two separate instances of the registered env: one collects training episodes,
# the other is used only for evaluation.
env = gym.make(env_id)
eval_env = gym.make(env_id)
# Only shape[0] is read. NOTE(review): assumes 1-D Box observation/action spaces.
obs_space_dims = env.observation_space.shape[0]
action_space_dims = env.action_space.shape[0]
# The policy multiplies its tanh output by this value.
# NOTE(review): uses only the first dimension's upper bound; assumes symmetric
# bounds that are identical across action dimensions — confirm for the env.
max_action = float(env.action_space.high[0])

total_num_epochs = 4000  # number of (collect episodes, update) iterations
episodes_per_epoch = 30  # episodes collected before each agent.update()
eval_intvl = 20  # run an evaluation every this many epochs
lr = args.lr
lam = args.lam
hidden_dims = 128  # hidden width of both the policy and value MLPs

# set seed
# Seed every RNG before the agent (and its networks) is constructed.
seed = args.seed
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
env.seed(seed)
# The evaluation env gets a seed different from the training env's.
# NOTE(review): Env.seed() and the 4-tuple env.step() used in this script belong to
# the older (pre-0.26) gym API — confirm the pinned gym version.
eval_env.seed(2**31-1-seed)

print('state_dim:', obs_space_dims, 'action_dim:', action_space_dims, 'max_action:', max_action)
print('lam:', lam, 'lr:', lr, 'seed:', seed)
################################################

# Build a fresh agent for this run (one seed per invocation).
agent = REINFORCE(obs_space_dims, action_space_dims, max_action, hidden_dims, lr, lam)

# create save dir
save_dir = get_save_dir(env_id, lam, lr, seed)
os.makedirs(save_dir, exist_ok=True)

eval_r_lst, eval_len_lst, eval_x_positive_lst = [], [], []
# Start from -inf rather than a finite sentinel, so the first evaluation always
# writes both checkpoints however negative the mean or mean - lam * GD score is.
best_mean, best_mean_gini = -np.inf, -np.inf
for epoch in range(total_num_epochs):
    # Collect an on-policy batch, then do one agent update.
    for _ in range(episodes_per_epoch):
        play_episode(env, agent)
    agent.update()

    if (epoch + 1) % eval_intvl != 0:
        continue

    eval_r, eval_len, eval_x_positive, eval_gd = eval_model(eval_env, agent)
    print('eval return:', eval_r.mean())
    eval_r_lst.append(eval_r)
    eval_len_lst.append(eval_len)
    eval_x_positive_lst.append(eval_x_positive)
    # Rewrite the full history after every evaluation, so an interrupted run keeps its curves.
    with open(save_dir + 'eval_r.npy', 'wb') as f:
        np.save(f, np.array(eval_r_lst))
    with open(save_dir + 'eval_len.npy', 'wb') as f:
        np.save(f, np.array(eval_len_lst))
    with open(save_dir + 'eval_x_positive.npy', 'wb') as f:
        np.save(f, np.array(eval_x_positive_lst))

    curr_mean = eval_r.mean()
    # Risk-adjusted score: mean evaluation return penalised by lam * Gini deviation.
    curr_mean_gini = curr_mean - lam * eval_gd

    if curr_mean > best_mean:
        best_mean = curr_mean
        agent.save_best_actor(save_dir, risk=False)

    if curr_mean_gini > best_mean_gini:
        best_mean_gini = curr_mean_gini
        agent.save_best_actor(save_dir, risk=True)
