import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from collections import namedtuple

def fan_in_uniform_init(tensor, fan_in=None):
    """Fill ``tensor`` in place from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).

    Args:
        tensor: Tensor to initialise; it is modified in place.
        fan_in: Value used to derive the bound. When omitted, the size of
            the tensor's last dimension is used instead.
    """
    effective_fan_in = tensor.size(-1) if fan_in is None else fan_in
    bound = 1. / np.sqrt(effective_fan_in)
    nn.init.uniform_(tensor, -bound, bound)

def soft_update(target, source, tau):
    """Blend ``source`` parameters into ``target`` in place (Polyak averaging).

    Each target parameter becomes ``(1 - tau) * target + tau * source``.
    Only the tensors yielded by ``parameters()`` are touched, paired in
    iteration order.

    Args:
        target: Module whose parameters are overwritten.
        source: Module providing the new values; left unchanged.
        tau: Interpolation factor; 0 keeps the target, 1 copies the source.
    """
    keep = 1.0 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = dst.data * keep + src.data * tau
        dst.data.copy_(blended)

# Immutable transition record. Positional construction must follow the field
# order below: state, action, done, next_state, reward.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'done', 'next_state', 'reward'],
)

class SumTree:
    """Array-backed binary sum tree with a fixed number of leaves.

    ``nodes`` stores the tree in heap layout: node ``i`` has children
    ``2*i + 1`` and ``2*i + 2``, the last ``size`` entries are the leaves and
    every internal node holds the sum of its two children. ``data[k]`` is the
    payload attached to leaf ``k``. Leaves are written in ring-buffer order,
    so once ``size`` items were added the oldest slot is overwritten.

    Priorities are expected to be non-negative.
    """

    def __init__(self, size):
        """Create an empty tree with ``size`` leaves, all with priority 0."""
        self.nodes = [0] * (2 * size - 1)
        self.data = [None] * size

        self.size = size
        self.count = 0  # next leaf slot to write
        self.real_size = 0  # number of slots written so far, capped at size

    @property
    def total(self):
        """Sum of all leaf priorities (value stored in the root)."""
        return self.nodes[0]

    def update(self, data_idx, value):
        """Set the priority of leaf ``data_idx`` to ``value``.

        Every ancestor is recomputed from its two children instead of being
        adjusted by a delta. Delta propagation accumulates floating-point
        error over many updates, which can leave an internal node slightly
        different from the sum of the leaves below it.
        """
        idx = data_idx + self.size - 1  # position of the leaf in the tree array
        self.nodes[idx] = value

        while idx > 0:
            idx = (idx - 1) // 2
            self.nodes[idx] = self.nodes[2 * idx + 1] + self.nodes[2 * idx + 2]

    def add(self, value, data):
        """Store ``data`` with priority ``value`` in the next ring-buffer slot."""
        self.data[self.count] = data
        self.update(self.count, value)

        self.count = (self.count + 1) % self.size
        self.real_size = min(self.size, self.real_size + 1)

    def get(self, cumsum):
        """Find the leaf whose cumulative-priority interval contains ``cumsum``.

        Args:
            cumsum: Value in ``[0, total]``.

        Returns:
            Tuple ``(data_idx, priority, data)`` for the selected leaf.

        Raises:
            AssertionError: if ``cumsum`` is greater than ``total``.
        """
        assert cumsum <= self.total

        idx = 0
        while 2 * idx + 1 < len(self.nodes):
            left, right = 2 * idx + 1, 2 * idx + 2
            left_sum, right_sum = self.nodes[left], self.nodes[right]

            go_left = cumsum <= left_sum
            # Never descend into a subtree without mass while its sibling has
            # some: otherwise cumsum == 0 (or a rounding overshoot) can end on
            # a zero-priority leaf whose data slot was never written.
            if go_left and left_sum <= 0 < right_sum:
                go_left = False
            elif not go_left and right_sum <= 0 < left_sum:
                go_left = True

            if go_left:
                idx = left
            else:
                idx = right
                cumsum = cumsum - left_sum

        data_idx = idx - self.size + 1

        return data_idx, self.nodes[idx], self.data[data_idx]

    def __repr__(self):
        return f"SumTree(nodes={self.nodes!r}, data={self.data!r})"

class PrioritizedReplayBuffer:
    """Ring-buffer replay memory with proportional prioritized sampling.

    Transitions live in preallocated CPU tensors; a ``SumTree`` keeps one
    priority per slot and maps sampled cumulative priorities back to slot
    indices. Section references in the comments point to the Prioritized
    Experience Replay paper quoted by the original implementation.

    Args:
        device: Device the sampled batch tensors are moved to.
        state_size: Length of a state vector.
        action_size: Accepted for interface compatibility; the action storage
            below always has a single column per transition.
        buffer_size: Maximum number of stored transitions.
        eps: Minimal priority, prevents zero probabilities.
        alpha: Prioritization exponent; 0 corresponds to uniform sampling.
        beta: Importance-sampling exponent; 1 fully compensates for the
            non-uniform probabilities.
    """

    def __init__(self, device, state_size, action_size, buffer_size, eps=1e-2, alpha=0.6, beta=0.4):
        self.tree = SumTree(size=buffer_size)

        self.device = device

        # PER params
        self.eps = eps  # minimal priority, prevents zero probabilities
        self.alpha = alpha  # determines how much prioritization is used, α = 0 corresponding to the uniform case
        self.beta = beta  # determines the amount of importance-sampling correction, b = 1 fully compensate for the non-uniform probabilities
        self.max_priority = eps  # priority for new samples, init as eps

        # transition: state, action, reward, next_state, done
        self.state = torch.empty(buffer_size, state_size, dtype=torch.float)
        # One column per transition regardless of action_size.
        self.action = torch.empty(buffer_size, 1, dtype=torch.float)
        self.reward = torch.empty(buffer_size, dtype=torch.float)
        self.next_state = torch.empty(buffer_size, state_size, dtype=torch.float)
        self.done = torch.empty(buffer_size, dtype=torch.int)

        self.count = 0  # next slot to write
        self.real_size = 0  # number of valid transitions, capped at size
        self.size = buffer_size

    def add(self, transition):
        """Store one ``(state, action, reward, next_state, done)`` transition.

        The transition gets the current maximum priority. Once the buffer is
        full the oldest slot is overwritten.
        """
        state, action, reward, next_state, done = transition

        # store transition index with maximum priority in sum tree
        self.tree.add(self.max_priority, self.count)

        # store transition in the buffer
        self.state[self.count] = torch.as_tensor(state)
        self.action[self.count] = torch.as_tensor(action)
        self.reward[self.count] = torch.as_tensor(reward)
        self.next_state[self.count] = torch.as_tensor(next_state)
        self.done[self.count] = torch.as_tensor(done)

        # update counters
        self.count = (self.count + 1) % self.size
        self.real_size = min(self.size, self.real_size + 1)

    def sample(self, batch_size):
        """Draw a prioritized minibatch.

        Returns:
            Tuple ``(batch, weights, tree_idxs)``. ``batch`` is
            ``(state, action, reward, next_state, done)`` moved to
            ``self.device``; ``weights`` is a ``(batch_size, 1)`` CPU tensor
            of importance-sampling weights scaled so the maximum is 1;
            ``tree_idxs`` are the indices to pass to ``update_priorities``.

        Raises:
            AssertionError: if fewer than ``batch_size`` transitions are stored.
        """
        assert self.real_size >= batch_size, "buffer contains less samples than batch size"

        sample_idxs, tree_idxs = [], []
        priorities = torch.empty(batch_size, 1, dtype=torch.float)

        # To sample a minibatch of size k, the range [0, p_total] is divided equally into k ranges.
        # Next, a value is uniformly sampled from each range. Finally the transitions that correspond
        # to each of these sampled values are retrieved from the tree. (Appendix B.2.1, Proportional prioritization)
        total = self.tree.total
        segment = total / batch_size
        for i in range(batch_size):
            a, b = segment * i, segment * (i + 1)

            # segment * batch_size can exceed the total by a rounding error;
            # clamp so the tree lookup never sees a value above the total.
            cumsum = min(random.uniform(a, b), total)
            # sample_idx is a sample index in buffer, needed further to sample actual transitions
            # tree_idx is a index of a sample in the tree, needed further to update priorities
            tree_idx, priority, sample_idx = self.tree.get(cumsum)

            priorities[i] = priority
            tree_idxs.append(tree_idx)
            sample_idxs.append(sample_idx)

        # Concretely, we define the probability of sampling transition i as P(i) = p_i^α / \sum_{k} p_k^α
        # where p_i > 0 is the priority of transition i. (Section 3.3)
        probs = priorities / total

        # The estimation of the expected value with stochastic updates relies on those updates corresponding
        # to the same distribution as its expectation. Prioritized replay introduces bias because it changes this
        # distribution in an uncontrolled fashion, and therefore changes the solution that the estimates will
        # converge to (even if the policy and state distribution are fixed). We can correct this bias by using
        # importance-sampling (IS) weights w_i = (1/N * 1/P(i))^β that fully compensates for the non-uniform
        # probabilities P(i) if β = 1. These weights can be folded into the Q-learning update by using w_i * δ_i
        # instead of δ_i (this is thus weighted IS, not ordinary IS, see e.g. Mahmood et al., 2014).
        # For stability reasons, we always normalize weights by 1/maxi wi so that they only scale the
        # update downwards (Section 3.4, first paragraph)
        weights = (self.real_size * probs) ** -self.beta

        # As mentioned in Section 3.4, whenever importance sampling is used, all weights w_i were scaled
        # so that max_i w_i = 1. We found that this worked better in practice as it kept all weights
        # within a reasonable range, avoiding the possibility of extremely large updates. (Appendix B.2.1, Proportional prioritization)
        weights = weights / weights.max()

        batch = (
            self.state[sample_idxs].to(self.device),
            self.action[sample_idxs].to(self.device),
            self.reward[sample_idxs].to(self.device),
            self.next_state[sample_idxs].to(self.device),
            self.done[sample_idxs].to(self.device)
        )
        return batch, weights, tree_idxs

    def update_priorities(self, data_idxs, priorities):
        """Assign new priorities to previously sampled transitions.

        Args:
            data_idxs: Indices returned as ``tree_idxs`` by ``sample``.
            priorities: TD errors, one per index, as a tensor, array, sequence
                or single value. Any shape is flattened, so a 0-d value (what
                ``squeeze()`` yields for a batch of one) is accepted. Only the
                magnitude is used.
        """
        if isinstance(priorities, torch.Tensor):
            priorities = priorities.detach().cpu().numpy()

        # Flatten to 1-D so iteration works for 0-d and (N, 1) inputs, and take
        # the magnitude so a signed error cannot produce NaN under ** alpha.
        magnitudes = np.abs(np.asarray(priorities, dtype=np.float64)).reshape(-1)

        for data_idx, magnitude in zip(data_idxs, magnitudes):
            # The first variant we consider is the direct, proportional prioritization where p_i = |δ_i| + eps,
            # where eps is a small positive constant that prevents the edge-case of transitions not being
            # revisited once their error is zero. (Section 3.3)
            priority = float((magnitude + self.eps) ** self.alpha)

            self.tree.update(data_idx, priority)
            self.max_priority = max(self.max_priority, priority)

    def __len__(self):
        """Number of transitions currently stored."""
        return self.real_size

class ReplayMemory(object):
    """Fixed-capacity ring buffer of ``Transition`` records with uniform sampling."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0  # slot the next push writes to

    def push(self, *args):
        """Saves a transition, overwriting the oldest slot once the buffer is full."""
        slot = self.position
        still_growing = len(self.memory) < self.capacity
        if still_growing:
            # Reserve the slot first; it is filled right below.
            self.memory.append(None)
        self.memory[slot] = Transition(*args)
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored entries chosen uniformly at random."""
        chosen = random.sample(self.memory, batch_size)
        return chosen

    def __len__(self):
        """Number of entries currently stored."""
        return len(self.memory)

class Critic(nn.Module):
    """Q-network: batch-normalised state -> 512 -> 256 -> one value per action.

    Args:
        state_dim_critic: Number of input features per state.
        action_dim: Number of outputs of the final layer (one Q-value each).
    """

    def __init__(self, state_dim_critic, action_dim):
        super().__init__()

        self.bn1 = nn.BatchNorm1d(state_dim_critic)
        self.linear1 = nn.Linear(state_dim_critic, 512)
        self.linear2 = nn.Linear(512, 256)
        self.Q = nn.Linear(256, action_dim)

        # Weight Init
        for layer in (self.linear1, self.linear2):
            fan_in_uniform_init(layer.weight)
            # A bias is 1-D, so its last dimension is the layer's output
            # width, not its fan-in. Pass the fan-in explicitly so weight and
            # bias of one layer share the same bound.
            fan_in_uniform_init(layer.bias, fan_in=layer.in_features)
        # The output layer starts close to zero.
        nn.init.uniform_(self.Q.weight, -3e-3, 3e-3)
        nn.init.uniform_(self.Q.bias, -3e-4, 3e-4)

    def forward(self, state):
        """Return the Q-values for a batch of states.

        Args:
            state: Float tensor whose last dimension is ``state_dim_critic``.
                In training mode ``BatchNorm1d`` needs more than one sample
                per batch.

        Returns:
            Tensor with ``action_dim`` values per input row.
        """
        x = self.bn1(state)
        x = torch.tanh(self.linear1(x))
        x = torch.tanh(self.linear2(x))
        return self.Q(x)

class DQN:
    """DQN agent with a target network and a prioritized replay buffer.

    Args:
        device: Device both networks and the training batches live on.
        state_dim: Number of features per state.
        action_dim: Number of discrete actions.
        lr: Adam learning rate.
        gamma: Discount factor.
        tau: Soft-update factor; stored but not used by ``update`` (see the
            note there).
        memory_capacity: Capacity of the replay buffer.
        batch_size: Minibatch size drawn per optimisation step.
        K_epochs: Number of optimisation steps per ``update`` call.
    """

    def __init__(self, device, state_dim, action_dim, lr, gamma, tau, memory_capacity, batch_size, K_epochs):

        self.device = device
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.K_epochs = K_epochs
        self.num_actions = action_dim

        self.memory = PrioritizedReplayBuffer(device, state_dim, action_dim, memory_capacity)

        # Define the Q function and its target
        self.Q = Critic(state_dim, action_dim).to(self.device)
        self.Q_target = Critic(state_dim, action_dim).to(self.device)

        # Copy weights from Q to Q_target
        self.Q_target.load_state_dict(self.Q.state_dict())
        self.Q_target.eval()

        # Optimizer
        self.Q_optimizer = torch.optim.Adam(self.Q.parameters(), lr=lr, weight_decay=0.01)

        # The online network stays in eval mode outside of update() so that
        # single-state inference works with BatchNorm.
        self.Q.eval()

    def select_action(self, state, epsilon=0.1):
        """
        Selects an action using epsilon-greedy strategy.

        Args:
            state: A single state (array-like of ``state_dim`` values).
            epsilon: Probability of returning a uniformly random action.

        Returns:
            The chosen action index as an int.
        """
        if np.random.rand() < epsilon:
            # Exploration: random action
            return np.random.randint(self.num_actions)

        with torch.no_grad():
            state = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
            q_values = self.Q(state)
            action = q_values.argmax(dim=1)

        return action.item()

    def update(self):
        """
        Updates the Q-network using a prioritized replay buffer.
        Each of the ``K_epochs`` steps:
            1. Samples a batch with priorities and importance weights
            2. Computes target Q-values using the target network
            3. Computes the TD error and applies an importance-weighted MSE loss
            4. Writes the new absolute TD errors back as priorities

        The target network is NOT modified here.

        Returns:
            The loss of the last optimisation step as a float, or ``None``
            when ``K_epochs`` is not positive and no step was run.
        """
        loss = None
        self.Q.train()

        try:
            for _ in range(self.K_epochs):
                # Sample batch from replay buffer (batch tensors arrive on self.device)
                batch, weights, tree_idxs = self.memory.sample(self.batch_size)
                state_batch, action_batch, reward_batch, next_state_batch, done_batch = batch

                reward_batch = reward_batch.view(-1, 1)
                done_batch = done_batch.view(-1, 1)
                # The importance weights are returned on the CPU.
                weights = weights.view(-1, 1).to(self.device)
                # gather() needs integer indices.
                action_batch = action_batch.long()

                # Compute target Q-values
                with torch.no_grad():
                    next_q_values = self.Q_target(next_state_batch)
                    max_next_q = next_q_values.max(dim=1, keepdim=True)[0]
                    target_q = reward_batch + (1.0 - done_batch) * self.gamma * max_next_q

                # Compute current Q-values
                current_q = self.Q(state_batch).gather(1, action_batch)

                # Compute TD error
                td_errors = current_q - target_q

                # Compute loss
                loss = (weights * td_errors.pow(2)).mean()

                # Optimize Q-network
                self.Q_optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.Q.parameters(), max_norm=1.0)
                self.Q_optimizer.step()

                # Update priorities. reshape(-1) keeps the result 1-D for any
                # batch size; squeeze() would collapse a single sample to 0-d.
                new_priorities = td_errors.detach().abs().reshape(-1).cpu().numpy()
                self.memory.update_priorities(tree_idxs, new_priorities)
        finally:
            # Restore inference mode even if a step raised.
            self.Q.eval()

        # NOTE(review): the soft target update is disabled, so within this class
        # Q_target only changes in __init__ and load(). Confirm that the
        # training loop synchronises the target network, or re-enable:
        # soft_update(self.Q_target, self.Q, self.tau)

        return None if loss is None else loss.item()

    def save(self, checkpoint_path):
        """Write both networks' state dicts to ``<checkpoint_path>_Q.pth`` and ``<checkpoint_path>_Q_target.pth``."""
        torch.save(self.Q.state_dict(), checkpoint_path + "_Q.pth")
        torch.save(self.Q_target.state_dict(), checkpoint_path + "_Q_target.pth")

    def load(self, checkpoint_path):
        """Load both networks from the files written by ``save`` onto ``self.device``."""
        checkpoint_Q = torch.load(checkpoint_path + "_Q.pth", map_location=self.device, weights_only=True)
        checkpoint_Q_target = torch.load(checkpoint_path + "_Q_target.pth", map_location=self.device, weights_only=True)
        self.Q.load_state_dict(checkpoint_Q)
        self.Q_target.load_state_dict(checkpoint_Q_target)
        self.Q.to(self.device)
        self.Q_target.to(self.device)