import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random


# Select the compute device once at import time: the first CUDA GPU when
# PyTorch can see one, otherwise the CPU. The choice is printed for the user.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    # Release cached-but-unused GPU memory held by this process's allocator.
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    device = torch.device('cpu')
    print("Device set to : cpu")


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise a layer in place and return it.

    The weight gets an orthogonal initialisation scaled by ``std`` (the
    ``gain`` argument of ``torch.nn.init.orthogonal_``). The bias, if the
    layer has one, is filled with ``bias_const``.

    Args:
        layer: Module with a ``weight`` parameter and optionally a ``bias``
            (e.g. ``nn.Linear``).
        std: Orthogonal-init gain; defaults to sqrt(2).
        bias_const: Constant written into every bias element.

    Returns:
        The same ``layer`` object, so calls can be nested inside
        ``nn.Sequential(...)``.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers built with bias=False register ``bias`` as None; constant_(None)
    # would raise, so only touch the bias when it exists.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        torch.nn.init.constant_(bias, bias_const)
    return layer


# Replay Buffer for storing transitions
class ReplayBuffer:
    """Fixed-capacity store of ``(state, action, reward, next_state, done)`` transitions.

    Transitions are kept in a Python list used as a ring buffer. Once
    ``capacity`` entries are stored, each new ``push`` overwrites the oldest
    one, at index ``position``.
    """

    def __init__(self, capacity):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of transitions kept; must be positive.

        Raises:
            ValueError: if ``capacity`` is not positive (such a buffer could
                never accept a transition).
        """
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.capacity = capacity
        self.buffer = []
        self.position = 0  # index the next push writes to

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest one when the buffer is full."""
        transition = (state, action, reward, next_state, done)
        if len(self.buffer) < self.capacity:
            # While filling, ``position`` always equals ``len(self.buffer)``.
            self.buffer.append(transition)
        else:
            self.buffer[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            Five arrays ``(state, action, reward, next_state, done)``, each
            built with ``np.stack`` so the sampled items lie along a new
            leading axis.

        Raises:
            ValueError: if ``batch_size`` exceeds ``len(self)`` (raised by
                ``random.sample``).
        """
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def sample_and_return_buffer_format(self, batch_size):
        """Sample like :meth:`sample`, but return a list of per-transition tuples.

        Each field passes through ``np.stack`` first, so items come back as
        numpy arrays/scalars rather than the objects originally pushed.
        """
        return list(zip(*self.sample(batch_size)))

    def clear(self):
        """Drop every stored transition and reset the write position."""
        self.buffer = []
        self.position = 0

    def __len__(self):
        return len(self.buffer)


# Q-Network
class QNetwork(nn.Module):
    """Multilayer perceptron mapping a state to one Q-value per action.

    Architecture: ``state_dim -> 256 -> 256 -> action_dim``. Both hidden
    layers use ReLU. Every linear layer goes through ``layer_init``; the output
    layer uses a small orthogonal gain of 0.1.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()

        width = 256
        # Construct layers strictly in forward order; each Linear is
        # initialised immediately after it is created.
        modules = [
            layer_init(nn.Linear(state_dim, width)),
            nn.ReLU(),
            layer_init(nn.Linear(width, width)),
            nn.ReLU(),
            layer_init(nn.Linear(width, action_dim), std=0.1),
        ]
        self.net = nn.Sequential(*modules)

    def forward(self, state):
        """Return the Q-values computed by ``self.net`` for ``state``."""
        q = self.net(state)
        return q


# Double DQN Agent
class DDQN:
    """Double DQN agent with a soft-updated target network.

    ``q_network`` is trained with Adam on a Huber (smooth-L1) loss. The
    bootstrap target uses the online network to choose the next action and
    ``target_network`` to evaluate it (Double DQN). After every optimisation
    step, the target network moves towards the online network by a fraction
    ``tau``, and epsilon decays multiplicatively.

    Args:
        state_dim: Length of the state vector given to ``QNetwork``.
        action_dim: Number of discrete actions.
        lr: Adam learning rate.
        gamma: Discount factor.
        epsilon_start: Initial exploration probability.
        epsilon_end: Lower bound for epsilon.
        epsilon_decay: Factor applied to epsilon on every :meth:`update`
            call (per gradient step, not per episode).
        tau: Soft-update coefficient for the target network.
    """

    def __init__(self, state_dim, action_dim, lr=1e-4, gamma=0.99, epsilon_start=0.2,
                 epsilon_end=0.01, epsilon_decay=0.9, tau=0.005):
        self.action_dim = action_dim
        self.gamma = gamma
        self.tau = tau
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay

        # Online network (trained) and target network (starts as an exact copy).
        self.q_network = QNetwork(state_dim, action_dim).to(device)
        self.target_network = QNetwork(state_dim, action_dim).to(device)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)

    def select_action(self, state):
        """Choose action(s) epsilon-greedily.

        Args:
            state: Array with a ``.shape``: 1-D for a single state, or 2-D
                ``(batch, state_dim)`` for a batch of states.

        Returns:
            A Python ``int`` for a single state; otherwise an ``np.int64``
            array with one action per row of ``state``.
        """
        if len(state.shape) == 1:
            if random.random() < self.epsilon:
                return random.randint(0, self.action_dim - 1)
            # Inference only: avoid building an autograd graph.
            with torch.no_grad():
                q_values = self.q_network(torch.FloatTensor(state).to(device))
            return q_values.argmax().item()

        n_states = state.shape[0]
        # True -> this row explores with a uniformly random action.
        random_mask = np.random.random(n_states) < self.epsilon
        actions = np.zeros(n_states, dtype=np.int64)
        actions[random_mask] = np.random.randint(0, self.action_dim, size=random_mask.sum())

        greedy_mask = ~random_mask
        # Skip the forward pass entirely when every row is exploring.
        if greedy_mask.any():
            with torch.no_grad():
                q_values = self.q_network(torch.FloatTensor(state).to(device))
            greedy_actions = q_values.argmax(dim=1).cpu().numpy()
            actions[greedy_mask] = greedy_actions[greedy_mask]

        return actions

    def update(self, replay_buffer, batch_size):
        """Run one gradient step on a minibatch sampled from ``replay_buffer``.

        Also soft-updates the target network and decays epsilon.

        Args:
            replay_buffer: Object whose ``sample(batch_size)`` returns
                ``(state, action, reward, next_state, done)`` arrays.
            batch_size: Number of transitions to sample.

        Returns:
            The Huber loss of this step as a Python float.
        """
        state, action, reward, next_state, done = replay_buffer.sample(batch_size)

        state = torch.FloatTensor(state).to(device)
        action = torch.LongTensor(action).to(device)
        reward = torch.FloatTensor(reward).to(device).unsqueeze(1)
        next_state = torch.FloatTensor(next_state).to(device)
        done = torch.FloatTensor(np.float32(done)).to(device).unsqueeze(1)

        # Q(s_t, a_t): evaluate all actions, then take the column of the action taken.
        q_values = self.q_network(state).gather(1, action.unsqueeze(1))

        # The bootstrap target is a constant for the regression, so build it
        # without tracking gradients.
        with torch.no_grad():
            # Double DQN: the online net selects a_{t+1}, the target net evaluates it.
            next_action = self.q_network(next_state).argmax(1, keepdim=True)
            next_q_values = self.target_network(next_state).gather(1, next_action)
            # (1 - done) removes the bootstrap term on terminal transitions.
            expected_q_values = reward + (1 - done) * self.gamma * next_q_values

        loss = F.smooth_l1_loss(q_values, expected_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Soft update, in place: target <- target + tau * (online - target).
        with torch.no_grad():
            for param, target_param in zip(self.q_network.parameters(),
                                           self.target_network.parameters()):
                target_param.lerp_(param, self.tau)

        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)

        return loss.item()

    def save(self, filename):
        """Save online-network and optimizer state.

        Files are written as ``<filename>_q_network`` and
        ``<filename>_optimizer``. The target network is not saved;
        :meth:`load` rebuilds it from the online network.
        """
        torch.save(self.q_network.state_dict(), filename + "_q_network")
        torch.save(self.optimizer.state_dict(), filename + "_optimizer")

    def load(self, filename):
        """Restore state written by :meth:`save` onto the current ``device``.

        ``map_location`` allows a checkpoint saved on a GPU to load on a
        CPU-only machine, and vice versa. The target network is reset to a
        copy of the loaded online network.
        """
        self.q_network.load_state_dict(torch.load(filename + "_q_network", map_location=device))
        self.optimizer.load_state_dict(torch.load(filename + "_optimizer", map_location=device))
        self.target_network.load_state_dict(self.q_network.state_dict())
