import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

# Hyperparameters (module-level constants read by PPO_RNN below).
learning_rate = 0.0005  # step size passed to the Adam optimizer
gamma = 0.9  # discount factor applied to future rewards
lmbda = 0.95  # presumably the GAE lambda; not referenced in this file - TODO confirm
eps_clip = 0.1  # half-width of the PPO probability-ratio clipping range
K_epoch = 10  # number of gradient updates performed per collected batch
T_horizon = 25  # presumably the rollout length used by the caller - TODO confirm
Update_episode = 20  # presumably episodes between updates - TODO confirm
save_freq = 10000  # presumably the checkpoint interval used by the caller - TODO confirm


class PPO_RNN(nn.Module):
    """Recurrent PPO agent with separate LSTM-based actor and critic networks.

    Transitions are buffered with ``put_data`` and consumed by ``train_net``,
    which runs ``K_epoch`` clipped-surrogate updates using normalized
    Monte Carlo returns as value targets.

    Args:
        state_dim: size of the observation vector fed to both networks.
        action_dim: number of discrete actions (size of the policy output).
        device: torch device the module and all training tensors live on.
    """

    def __init__(self, state_dim, action_dim, device):
        super(PPO_RNN, self).__init__()
        self.data = []  # buffered transitions, drained by make_batch()
        self.device = device

        # NOTE: keep the layer creation order below unchanged. It fixes the
        # parameter order relied on by saved optimizer checkpoints and the
        # order of the random draws used for weight initialization.

        # actor network
        self.fc_a1 = nn.Linear(state_dim, 64)
        self.lstm_a = nn.LSTM(64, 32, batch_first=True)
        self.fc_a2 = nn.Linear(32, 32)
        self.fc_a3 = nn.Linear(32, action_dim)

        # critic network
        self.fc_v1 = nn.Linear(state_dim, 64)
        self.fc_v2 = nn.Linear(32, 32)
        self.lstm_v = nn.LSTM(64, 32, batch_first=True)
        self.fc_v3 = nn.Linear(32, 1)

        # Move the parameters first, then build the optimizer over them.
        self.to(device)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    @staticmethod
    def _as_sequence(x):
        """Return ``x`` as a 3-D tensor suitable for a ``batch_first`` LSTM.

        A 1-D feature vector becomes a single batch entry with one time step;
        each row of a 2-D tensor becomes its own length-1 sequence; a 3-D
        tensor is passed through untouched.
        """
        if x.dim() == 1:
            return x.view(1, 1, -1)
        if x.dim() == 2:
            return x.unsqueeze(1)
        return x

    def pi(self, x, hidden):
        """Compute action probabilities with the actor network.

        Args:
            x: state tensor (1-D, 2-D or 3-D, see ``_as_sequence``).
            hidden: ``(h, c)`` LSTM state tuple for the actor LSTM.

        Returns:
            ``(prob, lstm_hidden)``: a 3-D tensor of probabilities (softmax
            over the last axis) and the updated LSTM state.
        """
        x = self._as_sequence(F.relu(self.fc_a1(x)))
        x, lstm_hidden = self.lstm_a(x, hidden)
        x = F.relu(self.fc_a2(x))
        # The tensor is always 3-D here, so dim=2 is the action axis.
        prob = F.softmax(self.fc_a3(x), dim=2)

        return prob, lstm_hidden

    def v(self, x, hidden):
        """Compute state values with the critic network.

        Args:
            x: state tensor (1-D, 2-D or 3-D, see ``_as_sequence``).
            hidden: ``(h, c)`` LSTM state tuple for the critic LSTM.

        Returns:
            A 3-D tensor of value estimates with a trailing axis of size 1.
            The updated LSTM state is discarded.
        """
        x = self._as_sequence(F.relu(self.fc_v1(x)))
        x, _ = self.lstm_v(x, hidden)
        x = F.relu(self.fc_v2(x))

        return self.fc_v3(x)

    def put_data(self, transition):
        """Append one transition to the rollout buffer.

        ``transition`` is the tuple
        ``(s, a, r, s_prime, prob_a, h_in, h_out, done)`` unpacked by
        ``make_batch``.
        """
        self.data.append(transition)

    def make_batch(self):
        """Convert the buffered transitions to tensors and clear the buffer.

        Returns:
            ``(s, a, r, s_prime, done_mask, prob_a, h_in, h_out)`` where
            ``done_mask`` is 0 for terminal transitions and 1 otherwise, and
            ``h_in`` / ``h_out`` are ``(h, c)`` tuples whose per-transition
            states have been concatenated along dim 1 (the LSTM batch axis).
        """
        s_lst, a_lst, r_lst, s_prime_lst = [], [], [], []
        prob_a_lst, h_in_lst, h_out_lst, done_lst = [], [], [], []
        for s, a, r, s_prime, prob_a, h_in, h_out, done in self.data:
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            prob_a_lst.append([prob_a])
            h_in_lst.append(h_in)  # h_in: tuple (h, c)
            h_out_lst.append(h_out)  # h_out: tuple (h, c)
            done_lst.append([0 if done else 1])

        # convert lists to tensors
        s = torch.tensor(s_lst, dtype=torch.float).to(self.device)
        a = torch.tensor(a_lst).to(self.device)
        r = torch.tensor(r_lst).to(self.device)
        s_prime = torch.tensor(s_prime_lst, dtype=torch.float).to(self.device)
        done_mask = torch.tensor(done_lst, dtype=torch.float).to(self.device)
        prob_a = torch.tensor(prob_a_lst).to(self.device)

        # concatenate hidden states along the LSTM batch axis
        h_in = (
            torch.cat([h[0] for h in h_in_lst], dim=1),
            torch.cat([h[1] for h in h_in_lst], dim=1),
        )
        h_out = (
            torch.cat([h[0] for h in h_out_lst], dim=1),
            torch.cat([h[1] for h in h_out_lst], dim=1),
        )
        self.data = []

        return s, a, r, s_prime, done_mask, prob_a, h_in, h_out

    def _discounted_returns(self, r, done_mask):
        """Compute discounted Monte Carlo returns for a batch of rewards.

        Walks the batch backwards, restarting the running return at every
        transition whose ``done_mask`` is 0 (episode end). The result has the
        same shape as ``r`` and is a float32 tensor on ``r``'s device.
        Assumes ``r`` and ``done_mask`` hold one scalar per transition.
        """
        rewards = r.detach().cpu().reshape(-1).tolist()
        masks = done_mask.detach().cpu().reshape(-1).tolist()

        returns = []
        running = 0.0
        for reward, mask in zip(reversed(rewards), reversed(masks)):
            if mask == 0:  # terminal transition: nothing carries over
                running = 0.0
            running = reward + gamma * running
            returns.append(running)
        returns.reverse()  # restore chronological order

        out = torch.tensor(returns, dtype=torch.float32, device=r.device)
        return out.reshape(r.shape)

    # with Monte Carlo returns to estimate advantage
    def train_net(self):
        """Run ``K_epoch`` PPO updates on the buffered transitions.

        The buffer is drained via ``make_batch``. Value targets are the
        normalized discounted Monte Carlo returns; the advantage is the
        target minus the (detached) critic estimate.

        Returns:
            ``(loss_pi, loss_v)``: the policy and value loss tensors of the
            last epoch.
        """
        # get a batch of data; s_prime and the output hidden states are not
        # needed because the targets are Monte Carlo returns, not TD targets
        s, a, r, _, done_mask, prob_a, (h1_in, h2_in), _ = self.make_batch()
        first_hidden = (
            h1_in.detach().to(self.device),
            h2_in.detach().to(self.device),
        )

        # Everything below up to the loop does not depend on the network
        # parameters, so it is computed once rather than once per epoch.
        returns = self._discounted_returns(r, done_mask)

        # normalize returns; the sample std of a single element is NaN, which
        # would poison every parameter, so a one-transition batch is only
        # centered (giving 0, the same value a zero-variance batch yields)
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-7)
        else:
            returns = returns - returns.mean()

        # gather() needs an index with the same rank as the 3-D policy output
        if a.dim() == 2:
            a = a.unsqueeze(1)

        for _ in range(K_epoch):
            # calculate value function
            v_s = self.v(s, first_hidden).squeeze(1)

            # calculate policy and action probabilities
            pi, _ = self.pi(s, first_hidden)
            pi_a = pi.gather(-1, a).squeeze(-1)

            # calculate the ratio of probabilities
            ratio = pi_a / prob_a

            # calculate surrogate loss
            advantage = returns - v_s.detach()
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantage
            loss_pi = -torch.min(surr1, surr2).mean()

            # calculate value loss
            loss_v = F.mse_loss(v_s, returns)

            # calculate entropy for exploration
            entropy = -torch.sum(pi * torch.log(pi + 1e-8), dim=-1).mean()

            # final loss
            loss = loss_pi + 0.5 * loss_v - 0.05 * entropy

            # update the network
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        return loss_pi, loss_v

    def save_model(self, path):
        """Save the model and optimizer state dicts to ``path``."""
        checkpoint = {
            "model_state_dict": self.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }
        torch.save(checkpoint, path)
        print(f"Model saved at {path}")

    def load_model(self, path):
        """Restore model and optimizer state from ``path``.

        Prints a message and returns without changes when ``path`` does not
        exist.
        """
        if not os.path.exists(path):
            print(f"No checkpoint found at {path}, skipping load.")
            return

        checkpoint = torch.load(path, map_location=self.device)
        self.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

