import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import torch.nn.functional as F
import os
"""Replay Buffer Design"""
class ReplayBuffer(object):
    """Bounded store of transitions with uniform random sampling.

    Transitions are kept as ``(state, action, reward_pri, reward_sec,
    next_state, done)`` tuples in ``self.buffer``. The list grows until it
    holds ``capacity`` items; after that ``push`` overwrites slots in
    ring-buffer order, tracked by ``self.position``.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0  # slot that the next push writes into

    def push(self, state, action, reward_pri, reward_sec, next_state, done):
        """Record one transition, overwriting the oldest slot once full."""
        record = (state, action, reward_pri, reward_sec, next_state, done)
        slot = self.position
        if len(self.buffer) < self.capacity:
            # Still filling up: reserve a fresh slot at the end of the list.
            self.buffer.append(None)
        self.buffer[slot] = record
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions uniformly without replacement.

        Returns six tuples (states, actions, primary rewards, secondary
        rewards, next states, dones), each of length ``batch_size``.
        Raises ValueError when ``batch_size`` is larger than the buffer,
        negative, or zero (an empty sample has nothing to unpack).
        """
        minibatch = random.sample(self.buffer, batch_size)
        (states, actions, rewards_pri, rewards_sec,
         next_states, dones) = zip(*minibatch)
        return states, actions, rewards_pri, rewards_sec, next_states, dones

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)
" The PAN Network"
class PointArrayFeatureExtractor(nn.Module):
    """Encode a set of points into one feature vector per batch element.

    A shared per-point MLP (``input_dim -> 128 -> 128``, ReLU after each
    layer) is applied to every point. The per-point features are max-pooled
    over the point axis (dim 1), which makes the result independent of point
    order. A final linear layer projects the pooled feature to ``output_dim``.
    """

    def __init__(self, input_dim=5, output_dim=4):
        super(PointArrayFeatureExtractor, self).__init__()
        width = 128
        self.feature_extractor = nn.Sequential(
            nn.Linear(input_dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
        )
        self.output_layer = nn.Linear(width, output_dim)

    def forward(self, point_cloud):
        """Return a ``(batch, output_dim)`` tensor for a 3-D ``point_cloud``.

        The last dimension of ``point_cloud`` must equal ``input_dim``.
        """
        # Unpacking the shape rejects any input that is not 3-D (ValueError).
        _batch, _n_points, _channels = point_cloud.shape
        per_point = self.feature_extractor(point_cloud)
        pooled = per_point.max(dim=1).values
        return self.output_layer(pooled)
""" The Actor Network"""
class Actor(nn.Module):
    """Policy network with two independent trunks and heads.

    * Primary head: three ReLU layers feed ``mu_layer`` (tanh-squashed mean)
      and ``std_layer`` (sigmoid-squashed std plus 1e-6) for the first
      ``action_dim - 3`` action components.
    * Secondary head: a separate three-layer ReLU trunk feeds
      ``softmax_layer``, which yields a probability vector over 3 entries.

    The three output layers are re-initialised from ``U(-init_w, init_w)``.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, init_w=3e-3):
        super(Actor, self).__init__()
        n_primary = action_dim - 3
        # Primary-task trunk and Gaussian parameter heads.
        self.linear1_1 = nn.Linear(state_dim, hidden_dim)
        self.linear2_1 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3_1 = nn.Linear(hidden_dim, hidden_dim)
        self.mu_layer = nn.Linear(hidden_dim, n_primary)
        self.std_layer = nn.Linear(hidden_dim, n_primary)
        # Secondary-task trunk and categorical head.
        self.linear1_2 = nn.Linear(state_dim, hidden_dim)
        self.linear2_2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3_2 = nn.Linear(hidden_dim, hidden_dim)
        self.softmax_layer = nn.Linear(hidden_dim, 3)

        # Small uniform init for every output head (weight, then bias, per head).
        for head in (self.mu_layer, self.std_layer, self.softmax_layer):
            head.weight.data.uniform_(-init_w, init_w)
            head.bias.data.uniform_(-init_w, init_w)

    def forward(self, x):
        """Return ``(mu, std, probs)`` for the primary and secondary heads."""
        hidden_pri = x
        for layer in (self.linear1_1, self.linear2_1, self.linear3_1):
            hidden_pri = F.relu(layer(hidden_pri))
        mu = torch.tanh(self.mu_layer(hidden_pri))
        # The offset keeps std strictly positive for the Normal distribution.
        std = torch.sigmoid(self.std_layer(hidden_pri)) + 1e-6

        hidden_sec = x
        for layer in (self.linear1_2, self.linear2_2, self.linear3_2):
            hidden_sec = F.relu(layer(hidden_sec))
        probs = F.softmax(self.softmax_layer(hidden_sec), dim=-1)

        return mu, std, probs
""" The Primary Critic"""
class Critic_Pri(nn.Module):
    """Q-network for the primary action component.

    Concatenates ``state`` with ``pri_action`` (``action_dim - 3`` entries)
    along dim 1 and scores the pair with a two-hidden-layer ReLU MLP that
    outputs a single value per row.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, init_w=3e-3):
        super(Critic_Pri, self).__init__()
        in_features = state_dim + (action_dim - 3)
        self.linear1 = nn.Linear(in_features, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)
        # Small uniform init for the value head (weight first, then bias).
        for tensor in (self.linear3.weight, self.linear3.bias):
            tensor.data.uniform_(-init_w, init_w)

    def forward(self, state, pri_action):
        """Return the Q-value estimate for each ``(state, pri_action)`` row."""
        h = torch.cat([state, pri_action], dim=1)
        for hidden in (self.linear1, self.linear2):
            h = F.relu(hidden(h))
        return self.linear3(h)
""" The Secondary Critic"""
class Critic_Sec(nn.Module):
    """Q-network for the secondary (3-entry) action component.

    Concatenates ``state`` with the 3-entry ``sec_action`` along dim 1 and
    scores the pair with a two-hidden-layer ReLU MLP that outputs a single
    value per row.
    """

    def __init__(self, state_dim, hidden_dim, init_w=3e-3):
        super(Critic_Sec, self).__init__()
        in_features = state_dim + 3
        self.linear1 = nn.Linear(in_features, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)
        # Small uniform init for the value head (weight first, then bias).
        for tensor in (self.linear3.weight, self.linear3.bias):
            tensor.data.uniform_(-init_w, init_w)

    def forward(self, state, sec_action):
        """Return the Q-value estimate for each ``(state, sec_action)`` row."""
        h = torch.cat([state, sec_action], dim=1)
        for hidden in (self.linear1, self.linear2):
            h = F.relu(hidden(h))
        return self.linear3(h)
""" The PAN Network """
class PAN_Agent:
    """PAN agent: actor-critic learner with a primary and a secondary action head.

    The actor emits a Gaussian over the first ``action_dim - 3`` action
    components and a 3-way probability vector for the last 3 components.
    Each head has its own pair of online critics and target critics. The
    bootstrap target uses the minimum of the two target critics.

    ``cfg`` is read for: device, target_entropy, batch_size, gamma, alpha,
    soft_tau, reward_scale, target_update, state_dim, action_dim,
    hidden_dim, critic_lr and memory_capacity.
    """

    def __init__(self, cfg):
        """Create hyper-parameters, networks, optimizers and replay memory."""
        # --- Parameters ---
        # Set first: ``alpha`` below is allocated on this device.
        self.device = cfg.device
        self.target_entropy = cfg.target_entropy
        self.batch_size = cfg.batch_size
        self.gamma = cfg.gamma
        # NOTE(review): alpha / alpha_optimizer / target_entropy are never read
        # by update(), so no entropy regularisation is applied. Confirm intent.
        self.alpha = torch.tensor(cfg.alpha, dtype=torch.float32, requires_grad=True, device=self.device)
        self.soft_tau = cfg.soft_tau
        self.alpha_optimizer = optim.Adam([self.alpha], lr=3e-4)
        # NOTE(review): reward_scale is stored but never applied to rewards.
        self.reward_scale = cfg.reward_scale
        self.total_it = 0
        # Target critics are soft-updated once every ``target_update`` calls to
        # update(); a missing or non-positive value means every call.
        self.target_update = cfg.target_update
        # NOTE(review): the softmax head is frozen below, but nothing in this
        # class unfreezes it after ``freeze_softmax_steps`` updates.
        self.freeze_softmax_steps = 10000

        # --- Networks ---
        self.actor = Actor(cfg.state_dim, cfg.action_dim, cfg.hidden_dim).to(self.device)
        # The actor must exist before its softmax head can be frozen.
        for param in self.actor.softmax_layer.parameters():
            param.requires_grad = False
        self.critic_pri_1 = Critic_Pri(cfg.state_dim, cfg.action_dim, cfg.hidden_dim).to(self.device)
        self.critic_pri_2 = Critic_Pri(cfg.state_dim, cfg.action_dim, cfg.hidden_dim).to(self.device)
        self.critic_sec_1 = Critic_Sec(cfg.state_dim, cfg.hidden_dim).to(self.device)
        self.critic_sec_2 = Critic_Sec(cfg.state_dim, cfg.hidden_dim).to(self.device)
        self.critic_pri_1_target = Critic_Pri(cfg.state_dim, cfg.action_dim, cfg.hidden_dim).to(self.device)
        self.critic_pri_2_target = Critic_Pri(cfg.state_dim, cfg.action_dim, cfg.hidden_dim).to(self.device)
        self.critic_sec_1_target = Critic_Sec(cfg.state_dim, cfg.hidden_dim).to(self.device)
        self.critic_sec_2_target = Critic_Sec(cfg.state_dim, cfg.hidden_dim).to(self.device)
        # Targets start as exact copies of their online critics.
        for target, source in self._target_pairs():
            self._hard_update(target, source)

        # --- Optimizers ---
        # Each actor head gets its own optimizer over its own trunk and head.
        self.actor_optimizer_pri = optim.Adam([
            {"params": self.actor.linear1_1.parameters()},
            {"params": self.actor.linear2_1.parameters()},
            {"params": self.actor.linear3_1.parameters()},
            {"params": self.actor.mu_layer.parameters()},
            {"params": self.actor.std_layer.parameters()},
        ], lr=1e-3)
        self.actor_optimizer_sec = optim.Adam([
            {"params": self.actor.linear1_2.parameters()},
            {"params": self.actor.linear2_2.parameters()},
            {"params": self.actor.linear3_2.parameters()},
            {"params": self.actor.softmax_layer.parameters()},
        ], lr=1e-3)
        self.critic_pri_1_optimizer = optim.Adam(self.critic_pri_1.parameters(), lr=cfg.critic_lr)
        self.critic_pri_2_optimizer = optim.Adam(self.critic_pri_2.parameters(), lr=cfg.critic_lr)
        self.critic_sec_1_optimizer = optim.Adam(self.critic_sec_1.parameters(), lr=cfg.critic_lr)
        self.critic_sec_2_optimizer = optim.Adam(self.critic_sec_2.parameters(), lr=cfg.critic_lr)

        # --- Memory ---
        self.memory = ReplayBuffer(cfg.memory_capacity)

    def _target_pairs(self):
        """Return the ``(target, online)`` critic pairs in a fixed order."""
        return (
            (self.critic_pri_1_target, self.critic_pri_1),
            (self.critic_pri_2_target, self.critic_pri_2),
            (self.critic_sec_1_target, self.critic_sec_1),
            (self.critic_sec_2_target, self.critic_sec_2),
        )

    @staticmethod
    def _hard_update(target, source):
        """Copy every parameter of ``source`` into ``target``."""
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(param.data)

    @staticmethod
    def _soft_update(target, source, tau):
        """Blend parameters in place: ``target <- tau * source + (1 - tau) * target``."""
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)

    def update(self):
        """Run one gradient step on all four critics and both actor heads.

        Does nothing until the replay memory holds at least ``batch_size``
        transitions. Afterwards, once every ``target_update`` calls, the
        target critics are moved towards the online critics by ``soft_tau``.
        """
        if len(self.memory) < self.batch_size:
            return
        self.total_it += 1
        states, actions, rewards_pri, rewards_sec, next_states, dones = self.memory.sample(self.batch_size)

        # --- Batch preparation ---
        batch_state = torch.tensor(np.array(states), dtype=torch.float32).to(self.device)
        batch_next_state = torch.tensor(np.array(next_states), dtype=torch.float32).to(self.device)
        batch_action = torch.tensor(np.array(actions), dtype=torch.float32).to(self.device)
        batch_reward_pri = torch.tensor(np.array(rewards_pri), dtype=torch.float32).unsqueeze(1).to(self.device)
        batch_reward_sec = torch.tensor(np.array(rewards_sec), dtype=torch.float32).unsqueeze(1).to(self.device)
        batch_done = torch.tensor(np.array(dones), dtype=torch.float32).unsqueeze(1).to(self.device)
        # The last 3 action entries belong to the secondary head.
        batch_pri_action = batch_action[:, :-3]
        batch_sec_action = batch_action[:, -3:]

        # --- Bootstrap targets ---
        with torch.no_grad():
            next_pri, next_std, next_sec = self.actor(batch_next_state)
            next_dist = torch.distributions.Normal(next_pri, next_std)
            next_sampled_action = torch.clamp(next_dist.rsample(), -1, 1)
            target_q_pri = torch.min(
                self.critic_pri_1_target(batch_next_state, next_sampled_action),
                self.critic_pri_2_target(batch_next_state, next_sampled_action)
            )
            # Critics already output one value per row, so no reduction is needed.
            target_q_sec = torch.min(
                self.critic_sec_1_target(batch_next_state, next_sec),
                self.critic_sec_2_target(batch_next_state, next_sec)
            )
            not_done = 1 - batch_done
            target_q_pri = batch_reward_pri + not_done * self.gamma * target_q_pri
            target_q_sec = batch_reward_sec + not_done * self.gamma * target_q_sec

        # --- Critic update ---
        current_q_pri1 = self.critic_pri_1(batch_state, batch_pri_action)
        current_q_pri2 = self.critic_pri_2(batch_state, batch_pri_action)
        q_pri_loss = F.mse_loss(current_q_pri1, target_q_pri) + F.mse_loss(current_q_pri2, target_q_pri)
        current_q_sec1 = self.critic_sec_1(batch_state, batch_sec_action)
        current_q_sec2 = self.critic_sec_2(batch_state, batch_sec_action)
        q_sec_loss = F.mse_loss(current_q_sec1, target_q_sec) + F.mse_loss(current_q_sec2, target_q_sec)
        self.critic_pri_1_optimizer.zero_grad()
        self.critic_pri_2_optimizer.zero_grad()
        q_pri_loss.backward()
        self.critic_pri_1_optimizer.step()
        self.critic_pri_2_optimizer.step()
        self.critic_sec_1_optimizer.zero_grad()
        self.critic_sec_2_optimizer.zero_grad()
        q_sec_loss.backward()
        self.critic_sec_1_optimizer.step()
        self.critic_sec_2_optimizer.step()

        # --- Actor update ---
        pri, std, sec = self.actor(batch_state)
        dist = torch.distributions.Normal(pri, std)
        sampled_action = torch.clamp(dist.rsample(), -1, 1)
        actor_pri_loss = -self.critic_pri_1(batch_state, sampled_action).mean()
        actor_sec_loss = -self.critic_sec_1(batch_state, sec).mean()
        # The heads use disjoint actor layers and different critics, so the two
        # losses have independent autograd graphs and need no retain_graph.
        self.actor_optimizer_pri.zero_grad()
        actor_pri_loss.backward()
        self.actor_optimizer_pri.step()
        self.actor_optimizer_sec.zero_grad()
        actor_sec_loss.backward()
        self.actor_optimizer_sec.step()

        # --- Target networks ---
        interval = max(int(self.target_update or 1), 1)
        if self.total_it % interval == 0:
            for target, source in self._target_pairs():
                self._soft_update(target, source, self.soft_tau)

    def choose_action(self, state):
        """Sample an action for one state (a batch dimension is added here).

        Returns a numpy array: the Gaussian sample for the primary components,
        clamped to [-1, 1], followed by the 3 softmax probabilities.
        """
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            pri, std, sec = self.actor(state)
            dist = torch.distributions.Normal(pri, std)
            sampled_action = torch.clamp(dist.sample(), -1, 1)
            action = torch.cat([sampled_action, sec], dim=-1)
        return action.squeeze(0).cpu().numpy()

    def save(self, i, path, eps):
        """Save actor and online-critic weights to ``path/pan_checkpoint_{i}_{eps}.pt``."""
        checkpoint = {
            'actor': self.actor.state_dict(),
            'critic_pri_1': self.critic_pri_1.state_dict(),
            'critic_pri_2': self.critic_pri_2.state_dict(),
            'critic_sec_1': self.critic_sec_1.state_dict(),
            'critic_sec_2': self.critic_sec_2.state_dict()
        }
        torch.save(checkpoint, os.path.join(path, f'pan_checkpoint_{i}_{eps}.pt'))

    def load(self, i, path, eps):
        """Restore weights written by ``save``.

        Tensors are mapped onto ``self.device`` so checkpoints saved on another
        device still load. Target critics are not part of the checkpoint, so
        they are re-synced from the loaded online critics.
        """
        checkpoint = torch.load(
            os.path.join(path, f'pan_checkpoint_{i}_{eps}.pt'),
            map_location=self.device,
        )
        self.actor.load_state_dict(checkpoint['actor'])
        self.critic_pri_1.load_state_dict(checkpoint['critic_pri_1'])
        self.critic_pri_2.load_state_dict(checkpoint['critic_pri_2'])
        self.critic_sec_1.load_state_dict(checkpoint['critic_sec_1'])
        self.critic_sec_2.load_state_dict(checkpoint['critic_sec_2'])
        for target, source in self._target_pairs():
            self._hard_update(target, source)
