import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import torch.nn.functional as F
import os
""" The Actor Network"""
class Actor(nn.Module):
    """Deterministic policy network producing actions in ``[-max_action, max_action]``.

    Two fully connected ReLU layers of width ``hidden_dim`` feed a linear
    output layer whose result is squashed with ``tanh`` and multiplied by
    ``max_action``.

    Args:
        state_dim: Number of features in each input state.
        action_dim: Number of action components produced.
        hidden_dim: Width of both hidden layers.
        max_action: Scale applied to every action component after ``tanh``.
        init_w: The output layer's weights and bias are drawn from
            ``U(-init_w, init_w)``; the hidden layers keep PyTorch's default
            initialisation.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, max_action, init_w=3e-3):
        super().__init__()
        self.max_action = max_action
        self.linear1 = nn.Linear(state_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, action_dim)
        # A tiny output-layer init keeps the network's initial outputs small.
        for tensor in (self.linear3.weight, self.linear3.bias):
            tensor.data.uniform_(-init_w, init_w)

    def forward(self, x):
        """Map states ``x`` to bounded actions."""
        hidden = F.relu(self.linear2(F.relu(self.linear1(x))))
        return torch.tanh(self.linear3(hidden)) * self.max_action
"""The Critic Network"""
class Critic(nn.Module):
    """Q-value network scoring a batch of (state, action) pairs.

    The state and action tensors are concatenated along dimension 1, passed
    through two ReLU hidden layers of width ``hidden_dim`` and projected to a
    single scalar per row.

    Args:
        state_dim: Number of state features.
        action_dim: Number of action features.
        hidden_dim: Width of both hidden layers.
        init_w: The output layer's weights and bias are drawn from
            ``U(-init_w, init_w)``.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, init_w=3e-3):
        super().__init__()
        joint_dim = state_dim + action_dim
        self.linear1 = nn.Linear(joint_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)
        for tensor in (self.linear3.weight, self.linear3.bias):
            tensor.data.uniform_(-init_w, init_w)

    def forward(self, state, action):
        """Return Q(state, action) with one output column per row."""
        features = torch.cat([state, action], dim=1)
        for hidden_layer in (self.linear1, self.linear2):
            features = F.relu(hidden_layer(features))
        return self.linear3(features)
"""The BPN network"""
class BatteryPredictionNetwork(nn.Module):
    """Three-layer MLP regressor producing one scalar per input row.

    Layout: ``Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim,
    hidden_dim) -> ReLU -> Linear(hidden_dim, 1)``, held in ``self.model``.

    Args:
        input_dim: Number of input features (last axis of ``x``).
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, input_dim=5, hidden_dim=256):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` and return its single-column prediction."""
        prediction = self.model(x)
        return prediction
""" The PAN Network """
class PointArrayFeatureExtractor(nn.Module):
    """Encode a batch of unordered point sets into fixed-size feature vectors.

    A shared MLP (``input_dim -> 128 -> 128``, ReLU after each layer) is
    applied to every point independently. The per-point features are
    max-pooled over the point axis, and a final linear layer maps the pooled
    128-d vector to ``output_dim``. Max-pooling is symmetric, so the result
    does not depend on the order of the points within a set.

    Args:
        input_dim: Features per point (last axis of the input).
        output_dim: Size of the returned global feature vector.
    """

    def __init__(self, input_dim=5, output_dim=4):
        super(PointArrayFeatureExtractor, self).__init__()
        self.feature_extractor = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU()
        )
        self.output_layer = nn.Linear(128, output_dim)

    def forward(self, point_cloud):
        """Return one global feature vector per point set.

        Args:
            point_cloud: Tensor of shape ``(batch, num_points, input_dim)``.

        Returns:
            Tensor of shape ``(batch, output_dim)``.

        Raises:
            ValueError: If ``point_cloud`` is not 3-dimensional.
        """
        if point_cloud.dim() != 3:
            raise ValueError(
                "point_cloud must have shape (batch, num_points, input_dim), "
                f"got {tuple(point_cloud.shape)}"
            )
        point_features = self.feature_extractor(point_cloud)
        # Pool over the point axis (dim=1); `.values` discards the argmax indices.
        global_feature = torch.max(point_features, dim=1).values
        return self.output_layer(global_feature)
"""The PAN agent"""
class PAN_Agent:
    """TD3-style agent: twin critics, target policy smoothing, delayed actor updates.

    The agent owns an actor, two critics, a target copy of each, one Adam
    optimizer per online network, and a bounded replay memory of
    ``(state, action, reward, next_state, done)`` tuples.

    Args:
        config: Object exposing ``device``, ``gamma``, ``max_movement``,
            ``hidden_dim``, ``action_dim``, ``state_dim``, ``actor_lr``,
            ``critic_lr`` and ``memory_size`` attributes.
    """

    def __init__(self, config):
        # --- Hyper-parameters -------------------------------------------
        self.device = config.device
        self.gamma = config.gamma
        self.tau = 0.005          # Polyak factor for target-network updates
        self.policy_noise = 0.2   # std of the Gaussian action noise
        self.noise_clip = 0.5     # noise is clipped to [-noise_clip, noise_clip]
        self.policy_delay = 2     # actor/targets update every N critic updates
        self.max_action = config.max_movement
        self.hidden_dim = config.hidden_dim
        self.action_dim = config.action_dim

        # --- Networks (targets start as exact copies of the online nets) ---
        self.actor = Actor(config.state_dim, self.action_dim, self.hidden_dim, self.max_action).to(self.device)
        self.actor_target = Actor(config.state_dim, self.action_dim, self.hidden_dim, self.max_action).to(self.device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_1 = Critic(config.state_dim, self.action_dim, self.hidden_dim).to(self.device)
        self.critic_2 = Critic(config.state_dim, self.action_dim, self.hidden_dim).to(self.device)
        self.critic_target_1 = Critic(config.state_dim, self.action_dim, self.hidden_dim).to(self.device)
        self.critic_target_2 = Critic(config.state_dim, self.action_dim, self.hidden_dim).to(self.device)
        self.critic_target_1.load_state_dict(self.critic_1.state_dict())
        self.critic_target_2.load_state_dict(self.critic_2.state_dict())

        # --- Optimizers ---------------------------------------------------
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=config.actor_lr)
        self.critic_optimizer_1 = optim.Adam(self.critic_1.parameters(), lr=config.critic_lr)
        self.critic_optimizer_2 = optim.Adam(self.critic_2.parameters(), lr=config.critic_lr)

        # --- Replay memory ------------------------------------------------
        self.memory_size = config.memory_size
        self.replay_buffer = []
        # Ring-buffer cursor: slot overwritten next once the buffer is full.
        self._buffer_pos = 0
        self.total_it = 0

    def choose_action(self, state, explore=True):
        """Return the actor's action for a single (unbatched) ``state``.

        Args:
            state: One state; a batch dimension is added internally.
            explore: When True, add Gaussian noise (std ``policy_noise``,
                clipped to ``±noise_clip``) and clip the result to
                ``[-max_action, max_action]``.

        Returns:
            A flat NumPy array with ``action_dim`` entries.
        """
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        with torch.no_grad():
            action = self.actor(state).cpu().numpy().flatten()
        if explore:
            noise = np.random.normal(0, self.policy_noise, size=self.action_dim)
            noise = np.clip(noise, -self.noise_clip, self.noise_clip)
            action = np.clip(action + noise, -self.max_action, self.max_action)
        return action

    def store_transition(self, state, action, reward, next_state, done):
        """Add one transition, overwriting the oldest slot once ``memory_size`` is reached.

        Uses an O(1) ring-buffer overwrite instead of ``list.pop(0)`` (O(n)).
        ``replay_buffer`` stays a plain list, so ``len``/``random.sample``
        keep working. Once the buffer wraps, list order is no longer
        chronological; nothing in this class depends on that order.
        """
        if self.memory_size <= 0:
            # Matches the previous behaviour: nothing is retained.
            return
        transition = (state, action, reward, next_state, done)
        if len(self.replay_buffer) < self.memory_size:
            self.replay_buffer.append(transition)
        else:
            # The modulo keeps the index valid even if memory_size was changed.
            self.replay_buffer[self._buffer_pos % self.memory_size] = transition
        self._buffer_pos = (self._buffer_pos + 1) % self.memory_size

    def _to_tensor(self, values):
        """Stack a sequence of per-transition values into a float32 tensor on ``self.device``."""
        return torch.as_tensor(np.array(values), dtype=torch.float32, device=self.device)

    @staticmethod
    def _soft_update(source, target, tau):
        """Polyak-average ``source`` parameters into ``target``: t <- tau*s + (1-tau)*t."""
        with torch.no_grad():
            for param, target_param in zip(source.parameters(), target.parameters()):
                target_param.copy_(tau * param + (1 - tau) * target_param)

    def update(self, batch_size=128):
        """Run one training step on a random minibatch from the replay memory.

        Does nothing until at least ``batch_size`` transitions are stored.
        Both critics are updated on every call. The actor and all three target
        networks are updated every ``policy_delay`` calls.

        Args:
            batch_size: Number of transitions sampled per step (default 128).

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if len(self.replay_buffer) < batch_size:
            return
        self.total_it += 1

        batch = random.sample(self.replay_buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states = self._to_tensor(states)
        actions = self._to_tensor(actions)
        next_states = self._to_tensor(next_states)
        # reshape(-1, 1) turns both scalar and 1-element-array rewards/dones
        # into (batch, 1). unsqueeze(1) would make (batch, 1, 1) in the latter
        # case and broadcast the TD target to (batch, batch, 1).
        rewards = self._to_tensor(rewards).reshape(-1, 1)
        dones = self._to_tensor(dones).reshape(-1, 1)

        with torch.no_grad():
            # Target policy smoothing: perturb the target action with clipped noise.
            noise = (torch.randn_like(actions) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            next_action = (self.actor_target(next_states) + noise).clamp(-self.max_action, self.max_action)
            # Clipped double-Q: bootstrap from the smaller target estimate.
            target_q1 = self.critic_target_1(next_states, next_action)
            target_q2 = self.critic_target_2(next_states, next_action)
            target_q = rewards + self.gamma * (1 - dones) * torch.min(target_q1, target_q2)

        current_q1 = self.critic_1(states, actions)
        current_q2 = self.critic_2(states, actions)
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
        # zero_grad here also clears gradients that the previous actor step
        # left on critic_1's parameters.
        self.critic_optimizer_1.zero_grad()
        self.critic_optimizer_2.zero_grad()
        critic_loss.backward()
        self.critic_optimizer_1.step()
        self.critic_optimizer_2.step()

        if self.total_it % self.policy_delay == 0:
            # Deterministic policy gradient through critic_1 only.
            actor_loss = -self.critic_1(states, self.actor(states)).mean()
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()
            self._soft_update(self.critic_1, self.critic_target_1, self.tau)
            self._soft_update(self.critic_2, self.critic_target_2, self.tau)
            self._soft_update(self.actor, self.actor_target, self.tau)

    @staticmethod
    def _checkpoint_path(i, path, eps):
        """Return the checkpoint filename shared by ``save`` and ``load``."""
        return os.path.join(path, f'pan_checkpoint_{i}_{eps}.pt')

    def save(self, i, path, eps):
        """Write a checkpoint to ``<path>/pan_checkpoint_{i}_{eps}.pt`` (creating ``path``).

        Besides the online actor/critics, the checkpoint stores the target
        networks, optimizer states and ``total_it`` so training can resume
        from where it stopped.
        """
        os.makedirs(path, exist_ok=True)
        checkpoint = {
            'actor': self.actor.state_dict(),
            'critic_1': self.critic_1.state_dict(),
            'critic_2': self.critic_2.state_dict(),
            'actor_target': self.actor_target.state_dict(),
            'critic_target_1': self.critic_target_1.state_dict(),
            'critic_target_2': self.critic_target_2.state_dict(),
            'actor_optimizer': self.actor_optimizer.state_dict(),
            'critic_optimizer_1': self.critic_optimizer_1.state_dict(),
            'critic_optimizer_2': self.critic_optimizer_2.state_dict(),
            'total_it': self.total_it,
        }
        torch.save(checkpoint, self._checkpoint_path(i, path, eps))

    def load(self, i, path, eps):
        """Restore the agent from ``<path>/pan_checkpoint_{i}_{eps}.pt``.

        Checkpoints that contain only the online networks (the older format)
        are still accepted. In that case the target networks are re-synced from
        the loaded online networks instead of keeping stale weights, and the
        optimizer states and ``total_it`` are left unchanged.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        filename = self._checkpoint_path(i, path, eps)
        if not os.path.exists(filename):
            raise FileNotFoundError(f"Checkpoint file not found: {filename}")
        # map_location allows a checkpoint written on one device (e.g. GPU)
        # to be restored on this agent's device (e.g. a CPU-only machine).
        checkpoint = torch.load(filename, map_location=self.device)
        self.actor.load_state_dict(checkpoint['actor'])
        self.critic_1.load_state_dict(checkpoint['critic_1'])
        self.critic_2.load_state_dict(checkpoint['critic_2'])

        for online, target, key in (
            (self.actor, self.actor_target, 'actor_target'),
            (self.critic_1, self.critic_target_1, 'critic_target_1'),
            (self.critic_2, self.critic_target_2, 'critic_target_2'),
        ):
            if key in checkpoint:
                target.load_state_dict(checkpoint[key])
            else:
                target.load_state_dict(online.state_dict())

        for optimizer, key in (
            (self.actor_optimizer, 'actor_optimizer'),
            (self.critic_optimizer_1, 'critic_optimizer_1'),
            (self.critic_optimizer_2, 'critic_optimizer_2'),
        ):
            if key in checkpoint:
                optimizer.load_state_dict(checkpoint[key])

        self.total_it = checkpoint.get('total_it', self.total_it)
