import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
learning_rate = 5e-4      # Adam step size for PPO_MO's optimizer
gamma = 0.9               # per-step discount applied when computing returns
lmbda = 0.95              # NOTE(review): not referenced in this file section (presumably meant for GAE) — confirm
eps_clip = 0.1            # PPO ratio clipping range: ratio is clamped to [1 - eps_clip, 1 + eps_clip]
K_epoch = 10              # optimisation passes over each collected batch in train_net
T_horizon = 25            # NOTE(review): not referenced here; presumably the rollout length used by the caller
Update_episode = 20       # NOTE(review): not referenced here; presumably consumed by the training loop
save_freq = 10_000        # NOTE(review): not referenced here; presumably the checkpoint interval


class PPO_MO(nn.Module):
    """Recurrent PPO agent with a single actor and a multi-head critic.

    Actor and critic each run ``Linear(state_dim, 64) -> LSTM(64, 32)``.
    The actor then maps to a softmax over ``action_dim`` actions.
    The critic has ``n_rewards`` independent value heads, one per reward
    component. Transitions are buffered with ``put_data`` and consumed (and
    cleared) by ``train_net``.
    """

    def __init__(self, state_dim, action_dim, device, n_rewards=2):
        super(PPO_MO, self).__init__()
        self.data = []  # buffered transitions; emptied by make_batch()
        self.n_rewards = n_rewards  # number of reward functions / value heads
        self.device = device
        # actor network
        self.fc_a1 = nn.Linear(state_dim, 64)
        self.lstm_a = nn.LSTM(64, 32, batch_first=True)
        self.fc_a2 = nn.Linear(32, 32)
        self.fc_a3 = nn.Linear(32, action_dim)

        # critic network: shared trunk, then one small MLP head per reward component
        self.fc_v1 = nn.Linear(state_dim, 64)
        self.lstm_v = nn.LSTM(64, 32, batch_first=True)
        self.fc_v2_list = nn.ModuleList([nn.Linear(32, 32) for _ in range(n_rewards)])
        self.fc_v3_list = nn.ModuleList([nn.Linear(32, 1) for _ in range(n_rewards)])

        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        self.to(device)

    @staticmethod
    def _to_lstm_input(x):
        """Reshape 64-wide features into the (batch, seq, 64) layout of the batch_first LSTMs.

        A 1-D input becomes a single batch with a single step ``(1, 1, 64)``.
        A 2-D input ``(B, 64)`` becomes ``B`` single-step sequences ``(B, 1, 64)``.
        Anything else is returned unchanged.
        """
        if x.dim() == 1:
            return x.view(-1, 1, 64)
        if x.dim() == 2:
            return x.unsqueeze(1)
        return x

    def pi(self, x, hidden):
        """Return action probabilities and the actor LSTM's new hidden state.

        Args:
            x: state features. 1-D/2-D inputs are treated as one time step
               (see ``_to_lstm_input``); a 3-D input is passed through as-is.
            hidden: ``(h, c)`` tuple for ``self.lstm_a``.

        Returns:
            ``(prob, lstm_hidden)``. ``prob`` is a 3-D tensor softmaxed over
            its last (action) dimension.
        """
        x = self._to_lstm_input(F.relu(self.fc_a1(x)))
        x, lstm_hidden = self.lstm_a(x, hidden)
        x = F.relu(self.fc_a2(x))
        prob = F.softmax(self.fc_a3(x), dim=2)
        return prob, lstm_hidden

    def v_multi(self, x, hidden):
        """Return one value estimate per reward head.

        Args:
            x: state features, reshaped as in ``pi``.
            hidden: ``(h, c)`` tuple for ``self.lstm_v``.

        Returns:
            A list of ``n_rewards`` tensors, each with a trailing dimension of
            size 1 (``(B, 1, 1)`` for a 2-D input). The critic LSTM's new
            hidden state is discarded.
        """
        x = self._to_lstm_input(F.relu(self.fc_v1(x)))
        x, _ = self.lstm_v(x, hidden)
        return [
            fc_v3(F.relu(fc_v2(x)))
            for fc_v2, fc_v3 in zip(self.fc_v2_list, self.fc_v3_list)
        ]

    def put_data(self, transition):
        """Buffer one transition ``(s, a, r, s_prime, prob_a, h_in, h_out, done)``."""
        self.data.append(transition)

    def make_batch(self):
        """Stack the buffered transitions into tensors and clear the buffer.

        Returns:
            ``(s, a, r, s_prime, done_mask, prob_a, h_in, h_out)``, where:
            - ``done_mask`` is 0 for steps whose ``done`` flag was truthy, and 1 otherwise.
            - ``h_in``/``h_out`` are ``(h, c)`` tuples concatenated along dim 1.
            - ``r`` is float, with one row per step and one column per reward
              (``train_net`` indexes it as ``r[:, j]``).

        Raises:
            RuntimeError: if no transitions are buffered.
        """
        if not self.data:
            raise RuntimeError("make_batch() called with no stored transitions")

        s_lst, a_lst, r_lst, s_prime_lst, prob_a_lst, h_in_lst, h_out_lst, done_lst = [], [], [], [], [], [], [], []
        for s, a, r, s_prime, prob_a, h_in, h_out, done in self.data:
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append(r)
            s_prime_lst.append(s_prime)
            prob_a_lst.append([prob_a])
            h_in_lst.append(h_in)
            h_out_lst.append(h_out)
            done_lst.append([0 if done else 1])

        s = torch.tensor(s_lst, dtype=torch.float).to(self.device)
        a = torch.tensor(a_lst).to(self.device)
        # Explicit float dtype: integer rewards would otherwise produce an int64 tensor.
        r = torch.tensor(r_lst, dtype=torch.float).to(self.device)
        s_prime = torch.tensor(s_prime_lst, dtype=torch.float).to(self.device)
        done_mask = torch.tensor(done_lst, dtype=torch.float).to(self.device)
        prob_a = torch.tensor(prob_a_lst).to(self.device)

        h_in = (
            torch.cat([h[0] for h in h_in_lst], dim=1),
            torch.cat([h[1] for h in h_in_lst], dim=1),
        )
        h_out = (
            torch.cat([h[0] for h in h_out_lst], dim=1),
            torch.cat([h[1] for h in h_out_lst], dim=1),
        )

        self.data = []

        return s, a, r, s_prime, done_mask, prob_a, h_in, h_out

    @staticmethod
    def _discounted_returns(rewards, masks, discount):
        """Return per-step discounted returns in time order, restarting at episode ends.

        Args:
            rewards: per-step scalar rewards, in time order.
            masks: same length as ``rewards``. 0 marks a step whose episode
                ended there (the accumulator is reset before adding its reward).
            discount: per-step discount factor.
        """
        returns = []
        running = 0.0
        for reward, mask in zip(reversed(rewards), reversed(masks)):
            if mask == 0:
                running = 0.0
            running = reward + discount * running
            returns.append(running)
        # Appending then reversing avoids the O(n^2) cost of list.insert(0, ...).
        returns.reverse()
        return returns

    @staticmethod
    def _normalize_returns(returns, eps=1e-7):
        """Standardize a 1-D tensor to zero mean and unit (sample) std.

        With fewer than two elements the sample std is NaN (Bessel
        correction), so only the mean is subtracted in that case.
        """
        centered = returns - returns.mean()
        if returns.numel() < 2:
            return centered
        return centered / (returns.std() + eps)

    def train_net(self):
        """Run ``K_epoch`` PPO updates on the buffered batch.

        For each reward head, the critic regresses onto normalized
        Monte-Carlo returns. The actor maximizes the clipped surrogate
        objective, using as its advantage the sum of the normalized returns
        minus the sum of the (detached) head values. An entropy bonus is
        also added.

        Returns:
            ``(loss_pi, loss_v)`` from the last epoch, as Python floats.
        """
        s, a, r, _s_prime, done_mask, prob_a, (h1_in, h2_in), _h_out = self.make_batch()
        h_in = (h1_in.detach().to(self.device), h2_in.detach().to(self.device))

        # The return targets depend only on stored rewards and episode masks,
        # not on network parameters, so compute them once for all epochs.
        masks = done_mask.view(-1).tolist()
        returns_per_head = []
        for j in range(self.n_rewards):
            ret = self._discounted_returns(r[:, j].tolist(), masks, gamma)
            ret = torch.tensor(ret, dtype=torch.float32, device=self.device)
            returns_per_head.append(self._normalize_returns(ret).view(-1, 1))
        total_returns = sum(returns_per_head)

        # gather() needs the index tensor to have the same rank as pi (B, 1, action_dim).
        a_idx = a.unsqueeze(1) if a.dim() == 2 else a

        for _ in range(K_epoch):
            v_s_list = self.v_multi(s, h_in)  # per head: (B, 1, 1) for 2-D s

            loss_v = 0
            for v_s, target in zip(v_s_list, returns_per_head):
                loss_v += F.mse_loss(v_s.view(-1, 1), target)

            pi, _ = self.pi(s, h_in)
            pi_a = pi.gather(-1, a_idx).squeeze(-1)
            ratio = pi_a / prob_a
            advantage_mean = total_returns - sum(v.detach().squeeze(1) for v in v_s_list)
            surr1 = ratio * advantage_mean
            surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantage_mean
            loss_pi = -torch.min(surr1, surr2).mean()

            entropy = -torch.sum(pi * torch.log(pi + 1e-8), dim=-1).mean()
            loss = loss_pi + 0.5 * loss_v - 0.05 * entropy

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        return loss_pi.item(), loss_v.item()

    def save_model(self, path):
        """Save model and optimizer state dicts to ``path``."""
        checkpoint = {
            "model_state_dict": self.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }
        torch.save(checkpoint, path)
        print(f"Model saved at {path}")

    def load_model(self, path):
        """Restore model and optimizer state from ``path``; skip if the file is missing."""
        if not os.path.exists(path):
            print(f"No checkpoint found at {path}, skipping load.")
            return

        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(path, map_location=self.device)
        self.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
