import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import torch.nn.functional as F
import os
""" The Actor Network"""
class Actor(nn.Module):
    """Deterministic policy network: state -> bounded action.

    Two ReLU hidden layers followed by a tanh output that is scaled by
    ``max_action``; the output layer starts with small uniform weights.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, max_action, init_w=3e-3):
        super().__init__()
        self.linear1 = nn.Linear(state_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, action_dim)
        self.max_action = max_action
        # Re-initialise only the output layer (weight, then bias) in [-init_w, init_w].
        for tensor in (self.linear3.weight, self.linear3.bias):
            nn.init.uniform_(tensor, -init_w, init_w)

    def forward(self, x):
        h = x
        for layer in (self.linear1, self.linear2):
            h = F.relu(layer(h))
        # tanh bounds each component to [-1, 1]; scaling maps it onto [-max_action, max_action].
        return torch.tanh(self.linear3(h)) * self.max_action
"""The Critic Network"""
class Critic(nn.Module):
    """Q-network: maps a (state, action) pair to a single scalar value.

    State and action are concatenated along dim 1, passed through two ReLU
    hidden layers, and projected to one output unit with no activation.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, init_w=3e-3):
        super().__init__()
        in_features = state_dim + action_dim
        self.linear1 = nn.Linear(in_features, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)
        # Re-initialise only the output layer (weight, then bias) in [-init_w, init_w].
        for tensor in (self.linear3.weight, self.linear3.bias):
            nn.init.uniform_(tensor, -init_w, init_w)

    def forward(self, state, action):
        h = torch.cat((state, action), dim=1)
        for layer in (self.linear1, self.linear2):
            h = F.relu(layer(h))
        return self.linear3(h)
"""The BPN network"""
class BatteryPredictionNetwork(nn.Module):
    """Three-layer MLP regressor: input_dim -> hidden_dim -> hidden_dim -> 1.

    ReLU follows every linear layer except the last, so the output is an
    unbounded scalar per input row.
    """

    def __init__(self, input_dim=5, hidden_dim=256):
        super().__init__()
        widths = (input_dim, hidden_dim, hidden_dim, 1)
        layers = []
        for idx, (d_in, d_out) in enumerate(zip(widths[:-1], widths[1:])):
            if idx:
                # Activation sits between consecutive linear layers only.
                layers.append(nn.ReLU())
            layers.append(nn.Linear(d_in, d_out))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        prediction = self.model(x)
        return prediction

"""Operation of Equation (9)"""
def normalize_adjacency_matrix(A):
    """Symmetrically normalize an adjacency matrix with self-loops (Equation (9)).

    Computes ``D^{-1/2} (A + I) D^{-1/2}``, where ``D`` is the diagonal matrix
    of row sums of ``A + I``.

    Args:
        A: square adjacency tensor of shape ``(N, N)`` or a batch of them with
            shape ``(..., N, N)``.

    Returns:
        Floating-point tensor with the same shape as ``A``.
    """
    num_nodes = A.shape[-1]
    A_hat = A + torch.eye(num_nodes, device=A.device)
    # Row sums over the last axis so both batched and unbatched inputs work.
    # Assumes non-negative edge weights, so every degree is >= 1 thanks to the
    # self-loop and the inverse square root stays finite.
    deg_inv_sqrt = 1.0 / torch.sqrt(torch.sum(A_hat, dim=-1))
    # Scaling rows and columns by broadcasting equals multiplying by the
    # diagonal matrices on each side, without building them (O(N^2) vs O(N^3)).
    return deg_inv_sqrt.unsqueeze(-1) * A_hat * deg_inv_sqrt.unsqueeze(-2)
"""The GCN layer"""
class GCNLayer(nn.Module):
    """Single graph-convolution layer computing ``fc(A_norm @ X)``.

    ``A_norm`` is the self-loop, symmetrically normalized adjacency returned
    by ``normalize_adjacency_matrix``. ``X`` is the node-feature tensor.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, node_features, adjacency_matrix):
        propagation = normalize_adjacency_matrix(adjacency_matrix)
        aggregated = torch.matmul(propagation, node_features)
        return self.fc(aggregated)
"""The GNN network"""
class GNN(nn.Module):
    """Two stacked GCN layers (ReLU between them) followed by mean pooling.

    The pooled result is the mean of the second layer's output over dim 1.
    This is presumably the node axis for ``(batch, nodes, features)`` inputs;
    confirm against callers.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = GCNLayer(input_dim, hidden_dim)
        self.layer2 = GCNLayer(hidden_dim, output_dim)

    def forward(self, node_features, adjacency_matrix):
        hidden = torch.relu(self.layer1(node_features, adjacency_matrix))
        node_embeddings = self.layer2(hidden, adjacency_matrix)
        return node_embeddings.mean(dim=1)
"""The GNN agent"""
class GNN_Agent:
    """TD3-style agent that also owns a GNN module and its optimizer.

    Holds a deterministic actor, twin critics, target copies of all three, a
    GNN with its own Adam optimizer, and a list-backed FIFO replay buffer.
    ``choose_action`` uses only the actor. ``update`` performs one TD3 step:
    clipped double-Q target, target-policy smoothing, and delayed actor and
    target updates. It then runs a GNN optimizer step.
    """

    def __init__(self, config):
        """Create networks, optimizers and an empty replay buffer.

        Args:
            config: object exposing ``device``, ``gamma``, ``max_movement``,
                ``hidden_dim``, ``action_dim``, ``state_dim``,
                ``node_feature_dim``, ``gnn_hidden_dim``, ``gnn_output_dim``,
                ``gnn_lr``, ``actor_lr``, ``critic_lr`` and ``memory_size``.
        """
        # --- Hyper-parameters ---
        self.device = config.device
        self.gamma = config.gamma
        self.tau = 0.005          # soft-update rate for the target networks
        self.policy_noise = 0.2   # std of the Gaussian action noise
        self.noise_clip = 0.5     # noise is clipped to [-noise_clip, noise_clip]
        self.policy_delay = 2     # actor/targets update once every N training steps
        self.batch_size = 128     # minibatch size; update() is a no-op until this many transitions exist
        self.max_action = config.max_movement
        self.hidden_dim = config.hidden_dim
        self.action_dim = config.action_dim

        # --- Networks ---
        # Construction order is kept fixed so seeded runs reproduce the same initial weights.
        self.actor = Actor(config.state_dim, self.action_dim, self.hidden_dim, self.max_action).to(self.device)
        self.actor_target = Actor(config.state_dim, self.action_dim, self.hidden_dim, self.max_action).to(self.device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.gnn = GNN(config.node_feature_dim, config.gnn_hidden_dim, config.gnn_output_dim).to(self.device)
        self.critic_1 = Critic(config.state_dim, self.action_dim, self.hidden_dim).to(self.device)
        self.critic_2 = Critic(config.state_dim, self.action_dim, self.hidden_dim).to(self.device)
        self.critic_target_1 = Critic(config.state_dim, self.action_dim, self.hidden_dim).to(self.device)
        self.critic_target_2 = Critic(config.state_dim, self.action_dim, self.hidden_dim).to(self.device)
        self.critic_target_1.load_state_dict(self.critic_1.state_dict())
        self.critic_target_2.load_state_dict(self.critic_2.state_dict())

        # --- Optimizers ---
        self.gnn_optimizer = optim.Adam(self.gnn.parameters(), lr=config.gnn_lr)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=config.actor_lr)
        self.critic_optimizer_1 = optim.Adam(self.critic_1.parameters(), lr=config.critic_lr)
        self.critic_optimizer_2 = optim.Adam(self.critic_2.parameters(), lr=config.critic_lr)

        # --- Replay memory ---
        self.memory_size = config.memory_size
        self.replay_buffer = []
        self.total_it = 0  # number of update() calls that actually trained

    def compute_gnn_loss(self, batch_state, batch_action, batch_reward, batch_done):
        """Return the GNN loss as written for Equation (10).

        Builds a bootstrapped target from the minimum of the two critics'
        Q(s, a) and returns the summed MSE of both critics against it.

        NOTE(review): this loss depends only on critic_1/critic_2, never on
        ``self.gnn``. The GNN parameters therefore receive no gradient, and
        ``gnn_optimizer.step()`` in ``update`` leaves them unchanged. The
        bootstrap term also uses the current-state Q and is not detached.
        Confirm against the paper's Equation (10).
        """
        current_q_1 = self.critic_1(batch_state, batch_action)
        current_q_2 = self.critic_2(batch_state, batch_action)
        current_q = batch_reward + self.gamma * (1 - batch_done) * torch.min(current_q_1, current_q_2)
        gnn_loss = F.mse_loss(current_q_1, current_q) + F.mse_loss(current_q_2, current_q)
        return gnn_loss

    def choose_action(self, state, explore=True):
        """Return the actor's action for a single ``state`` as a flat numpy array.

        With ``explore=True``, Gaussian noise (std ``policy_noise``, clipped to
        ``noise_clip``) is added. The result is then clipped to
        ``[-max_action, max_action]``.
        """
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        with torch.no_grad():
            action = self.actor(state).cpu().numpy().flatten()
        if explore:
            # NOTE(review): the noise std is not scaled by max_action; confirm this is intended.
            noise = np.random.normal(0, self.policy_noise, size=self.action_dim)
            noise = np.clip(noise, -self.noise_clip, self.noise_clip)
            action = np.clip(action + noise, -self.max_action, self.max_action)
        return action

    def store_transition(self, state, action, reward, next_state, done):
        """Append a transition, evicting the oldest once ``memory_size`` is exceeded."""
        self.replay_buffer.append((state, action, reward, next_state, done))
        if len(self.replay_buffer) > self.memory_size:
            self.replay_buffer.pop(0)

    def _sample_batch(self):
        """Sample ``batch_size`` transitions uniformly and return them as float tensors on ``self.device``."""
        batch = random.sample(self.replay_buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        # unsqueeze(1) turns per-transition scalars (assumed) into a (batch, 1) column,
        # matching the critics' single-unit output.
        return (
            torch.FloatTensor(np.array(states)).to(self.device),
            torch.FloatTensor(np.array(actions)).to(self.device),
            torch.FloatTensor(np.array(rewards)).unsqueeze(1).to(self.device),
            torch.FloatTensor(np.array(next_states)).to(self.device),
            torch.FloatTensor(np.array(dones)).unsqueeze(1).to(self.device),
        )

    @staticmethod
    def _soft_update(source, target, tau):
        """In place, set ``target = tau * source + (1 - tau) * target`` for each parameter pair."""
        for param, target_param in zip(source.parameters(), target.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    def update(self):
        """Run one training step on a random minibatch from the replay buffer.

        Does nothing (and does not advance ``total_it``) until at least
        ``batch_size`` transitions have been stored.
        """
        if len(self.replay_buffer) < self.batch_size:
            return

        self.total_it += 1
        states, actions, rewards, next_states, dones = self._sample_batch()

        with torch.no_grad():
            # Target-policy smoothing: clipped noise on the target actor's action.
            noise = (torch.randn_like(actions) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            next_action = (self.actor_target(next_states) + noise).clamp(-self.max_action, self.max_action)
            # Clipped double-Q: bootstrap from the smaller of the two target critics.
            target_q1 = self.critic_target_1(next_states, next_action)
            target_q2 = self.critic_target_2(next_states, next_action)
            target_q = rewards + self.gamma * (1 - dones) * torch.min(target_q1, target_q2)

        current_q1 = self.critic_1(states, actions)
        current_q2 = self.critic_2(states, actions)
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

        self.critic_optimizer_1.zero_grad()
        self.critic_optimizer_2.zero_grad()
        critic_loss.backward()
        self.critic_optimizer_1.step()
        self.critic_optimizer_2.step()

        # NOTE(review): compute_gnn_loss does not involve self.gnn. This backward only
        # writes gradients into the critics, which are cleared by their zero_grad on
        # the next training step.
        gnn_loss = self.compute_gnn_loss(states, actions, rewards, dones)
        self.gnn_optimizer.zero_grad()
        gnn_loss.backward()
        self.gnn_optimizer.step()

        if self.total_it % self.policy_delay == 0:
            # Delayed policy update: maximize Q1(s, actor(s)).
            actor_loss = -self.critic_1(states, self.actor(states)).mean()
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            self._soft_update(self.critic_1, self.critic_target_1, self.tau)
            self._soft_update(self.critic_2, self.critic_target_2, self.tau)
            self._soft_update(self.actor, self.actor_target, self.tau)

    def save(self, i, path, eps):
        """Save actor, critics and GNN weights to ``path/gnn_checkpoint_{i}_{eps}.pt``.

        Creates ``path`` if needed. Target networks and optimizer states are
        not saved.
        """
        os.makedirs(path, exist_ok=True)
        checkpoint = {
            'actor': self.actor.state_dict(),
            'critic_1': self.critic_1.state_dict(),
            'critic_2': self.critic_2.state_dict(),
            'gnn': self.gnn.state_dict()
        }
        torch.save(checkpoint, os.path.join(path, f'gnn_checkpoint_{i}_{eps}.pt'))

    def load(self, i, path, eps):
        """Restore weights written by ``save(i, path, eps)``.

        Tensors are mapped onto ``self.device``, so a checkpoint saved on a
        GPU also loads on a CPU-only machine. The target networks are then
        re-synced to the loaded weights.

        Raises:
            FileNotFoundError: if the checkpoint file does not exist.
        """
        filename = os.path.join(path, f'gnn_checkpoint_{i}_{eps}.pt')
        if not os.path.exists(filename):
            raise FileNotFoundError(f"Checkpoint file not found: {filename}")
        checkpoint = torch.load(filename, map_location=self.device)
        self.actor.load_state_dict(checkpoint['actor'])
        self.critic_1.load_state_dict(checkpoint['critic_1'])
        self.critic_2.load_state_dict(checkpoint['critic_2'])
        self.gnn.load_state_dict(checkpoint['gnn'])
        # Targets are not part of the checkpoint. Without this they would keep their
        # pre-load weights, and resumed training would bootstrap from them.
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target_1.load_state_dict(self.critic_1.state_dict())
        self.critic_target_2.load_state_dict(self.critic_2.state_dict())
