import random
import collections
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class ReplayBuffer:
    """Bounded store of transitions with uniform random sampling.

    Transitions are kept in a ``collections.deque`` created with
    ``maxlen=capacity``, so once the buffer is full each new transition
    evicts the oldest one.
    """

    def __init__(self, capacity):
        """Create an empty buffer that holds at most ``capacity`` transitions."""
        self.capacity = capacity
        self.memory = collections.deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` transition."""
        transition = (state, action, reward, next_state, done)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions and return them column-wise.

        Returns:
            A 5-tuple of NumPy arrays ``(states, actions, rewards,
            next_states, dones)``. Actions are ``int64``; every other
            column is ``float32``. Row ``i`` of each array comes from the
            same stored transition.

        Raises:
            ValueError: Propagated from ``random.sample`` when
                ``batch_size`` exceeds the number of stored transitions.
        """
        picked = random.sample(self.memory, batch_size)
        # Transpose the list of transitions into one sequence per field.
        states, actions, rewards, next_states, dones = zip(*picked)

        def as_float32(column):
            return np.array(column, dtype=np.float32)

        return (
            as_float32(states),
            np.array(actions, dtype=np.int64),
            as_float32(rewards),
            as_float32(next_states),
            as_float32(dones),
        )

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)

class QNetwork(nn.Module):
    """Three-layer fully connected network producing one value per action.

    Two ReLU-activated hidden layers of width ``hidden_dim`` are followed
    by a linear output layer of width ``action_dim``.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        """Build the layers.

        Args:
            state_dim: Size of the input feature vector.
            action_dim: Number of outputs (one per action).
            hidden_dim: Width of both hidden layers.
        """
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return the un-activated output of the final layer for input ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        # No activation on the output layer: values are unbounded.
        return self.fc3(hidden)

class DQNAgent:
    """Epsilon-greedy DQN agent with an online and a target Q-network.

    The online network is trained with Adam on a mean-squared TD error;
    the target network is a copy of the online network that is refreshed
    every ``target_update_freq`` learning steps. Epsilon is decayed
    linearly from ``epsilon_start`` to ``epsilon_end`` over
    ``epsilon_decay`` learning steps.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        replay_buffer,
        lr=1e-4,
        gamma=0.99,
        batch_size=64,
        epsilon_start=1.0,
        epsilon_end=0.1,
        epsilon_decay=10000,
        target_update_freq=1000,
    ):
        """Create the networks, optimizer and exploration schedule.

        Args:
            state_dim: Size of a single state vector.
            action_dim: Number of discrete actions.
            replay_buffer: Object supporting ``len()``; used by ``update``
                to decide whether enough experience has been collected.
            lr: Learning rate for the Adam optimizer.
            gamma: Discount factor applied to the bootstrapped next value.
            batch_size: Minimum buffer size before learning starts.
            epsilon_start: Initial exploration probability.
            epsilon_end: Floor for the exploration probability.
            epsilon_decay: Number of learning steps over which epsilon
                moves linearly from ``epsilon_start`` to ``epsilon_end``.
            target_update_freq: Learning steps between target-network syncs.

        Raises:
            ValueError: If ``epsilon_decay`` or ``target_update_freq`` is
                not positive (both are used as divisors in ``update``).
        """
        if epsilon_decay <= 0:
            raise ValueError("epsilon_decay must be positive")
        if target_update_freq <= 0:
            raise ValueError("target_update_freq must be positive")

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.replay_buffer = replay_buffer
        self.gamma = gamma
        self.batch_size = batch_size
        # Kept so the linear decay step can be derived from the real span
        # between the start and end values.
        self.epsilon_start = epsilon_start
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.target_update_freq = target_update_freq

        self.learn_step = 0
        self.q_network = QNetwork(state_dim, action_dim)
        self.target_q_network = QNetwork(state_dim, action_dim)
        # Start with identical weights in both networks.
        self.target_q_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)

    def select_action(self, state):
        """Choose an action for a single state using epsilon-greedy.

        With probability ``self.epsilon`` a uniformly random action index
        in ``[0, action_dim - 1]`` is returned; otherwise the index of the
        largest online-network output for ``state``.

        Args:
            state: A single state (sequence or array of numbers); a batch
                dimension is added before the forward pass.

        Returns:
            The chosen action index as a Python ``int``.
        """
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)

        state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            q_values = self.q_network(state_tensor)
        return torch.argmax(q_values, dim=1).item()

    def update(self, states, actions, rewards, next_states, dones):
        """Run one gradient step on a caller-supplied batch of transitions.

        The replay buffer is consulted only as a warm-up gate: nothing
        happens until it holds at least ``batch_size`` transitions. The
        batch itself is whatever the caller passes in.

        Args:
            states: Batch of states, one row per transition.
            actions: Action index taken in each transition (1-D).
            rewards: Reward received in each transition (1-D).
            next_states: Batch of successor states, one row per transition.
            dones: 1.0 where the transition ended the episode, else 0.0 (1-D).

        Returns:
            None.
        """
        if len(self.replay_buffer) < self.batch_size:
            return

        self.learn_step += 1

        states_v = torch.as_tensor(states, dtype=torch.float32)
        actions_v = torch.as_tensor(actions, dtype=torch.int64)
        rewards_v = torch.as_tensor(rewards, dtype=torch.float32)
        next_states_v = torch.as_tensor(next_states, dtype=torch.float32)
        dones_v = torch.as_tensor(dones, dtype=torch.float32)

        # Online-network value of the action actually taken in each row.
        q_values = self.q_network(states_v)
        q_action = q_values.gather(1, actions_v.unsqueeze(-1)).squeeze(-1)

        # Bootstrapped target from the target network; no gradient flows
        # through it, and terminal transitions drop the bootstrap term.
        with torch.no_grad():
            next_q_values = self.target_q_network(next_states_v)
            next_q_max = torch.max(next_q_values, dim=1)[0]
            q_target = rewards_v + (1 - dones_v) * self.gamma * next_q_max

        loss = F.mse_loss(q_action, q_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Periodically copy the online weights into the target network.
        if self.learn_step % self.target_update_freq == 0:
            self.target_q_network.load_state_dict(self.q_network.state_dict())

        # Linear epsilon decay. The step is derived from epsilon_start (not
        # a fixed 1.0) so the schedule spans exactly epsilon_decay steps for
        # any starting value; it is clamped at zero so a start below the
        # floor never pushes epsilon upwards past the floor.
        decay_step = (
            max(0.0, self.epsilon_start - self.epsilon_end) / self.epsilon_decay
        )
        self.epsilon = max(self.epsilon_end, self.epsilon - decay_step)

class TrainingState:
    """Tracks per-epoch training metrics and exposes them as a feature vector.

    NOTE(review): "mAP" is presumably mean average precision; the class
    itself only stores the score and its change between updates.
    """

    def __init__(self, max_epochs):
        """Start at epoch 0 with all metrics set to 0.0.

        Args:
            max_epochs: Total number of epochs, used to normalise the
                current epoch in ``get_state_vector``.
        """
        self.max_epochs = max_epochs
        self.current_epoch = 0
        self.recent_loss = 0.0
        self.recent_grad = 0.0
        self.current_mAP = 0.0
        self.previous_mAP = 0.0
        self.mAP_trend = 0.0

    def update(self, epoch, epoch_loss, grad_norm, mAP_score):
        """Record the metrics of the latest epoch.

        The previously stored mAP becomes ``previous_mAP`` and
        ``mAP_trend`` is set to the difference between the new and the
        previous score. On the first call the previous score is the
        initial 0.0, so the trend equals ``mAP_score``.

        Args:
            epoch: Index of the epoch that just finished.
            epoch_loss: Loss value for that epoch.
            grad_norm: Gradient norm for that epoch.
            mAP_score: mAP score for that epoch.
        """
        self.current_epoch = epoch
        # Shift the old score out before overwriting it.
        self.previous_mAP = self.current_mAP
        self.current_mAP = mAP_score

        self.recent_loss = epoch_loss
        self.recent_grad = grad_norm

        self.mAP_trend = self.current_mAP - self.previous_mAP

    def get_state_vector(self):
        """Return the tracked metrics as a flat list.

        Returns:
            ``[normalized_epoch, recent_loss, recent_grad, current_mAP,
            mAP_trend]`` where ``normalized_epoch`` is
            ``current_epoch / max_epochs``, or 0.0 when ``max_epochs`` is
            not positive.
        """
        total_epochs = float(self.max_epochs)
        # Guard the division: a zero-epoch configuration would otherwise
        # raise ZeroDivisionError here.
        if total_epochs > 0:
            normalized_epoch = float(self.current_epoch) / total_epochs
        else:
            normalized_epoch = 0.0
        return [
            normalized_epoch,
            self.recent_loss,
            self.recent_grad,
            self.current_mAP,
            self.mAP_trend,
        ]