import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from collections import namedtuple

def fan_in_uniform_init(tensor, fan_in=None):
    """Fill ``tensor`` in place from U(-1/sqrt(fan_in), +1/sqrt(fan_in)).

    Used to initialise the hidden layers of the actor and the critic.

    Args:
        tensor: Tensor to initialise; it is overwritten in place.
        fan_in: Value that scales the bound. When ``None``, the size of the
            tensor's last dimension is used.
    """
    effective_fan_in = tensor.size(-1) if fan_in is None else fan_in
    bound = 1. / np.sqrt(effective_fan_in)
    nn.init.uniform_(tensor, a=-bound, b=bound)

def soft_update(target, source, tau):
    """Blend ``source`` parameters into ``target`` in place (Polyak averaging).

    Each target parameter becomes ``(1 - tau) * target + tau * source``.
    Only tensors yielded by ``parameters()`` are touched; the two modules are
    walked in lockstep, so they are expected to share the same architecture.

    Args:
        target: Module whose parameters are overwritten.
        source: Module providing the new values (left unchanged).
        tau: Interpolation factor; ``1.0`` copies the source outright.
    """
    for tgt_p, src_p in zip(target.parameters(), source.parameters()):
        # Work on .data so the blend is not recorded by autograd.
        blended = tgt_p.data * (1.0 - tau) + src_p.data * tau
        tgt_p.data.copy_(blended)

# One stored environment step. The positional order is
# (state, action, done, next_state, reward); ReplayMemory.push forwards its
# positional arguments to this constructor in exactly that order.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'done', 'next_state', 'reward'],
)

class SumTree:
    """Array-backed binary sum tree over ``size`` leaves.

    ``nodes`` holds ``2 * size - 1`` entries: indices ``0 .. size - 2`` are
    internal nodes, indices ``size - 1 .. 2 * size - 2`` are the leaves. Every
    internal node stores the sum of its two children, so ``nodes[0]`` is the
    sum of all leaf values. ``data`` holds one payload per leaf and is written
    as a ring buffer by :meth:`add`.

    Attributes:
        nodes: Flat list of node values (internal nodes first, then leaves).
        data: Payload stored alongside each leaf (``None`` until written).
        size: Number of leaves.
        count: Leaf index that the next :meth:`add` will write.
        real_size: Number of leaves written so far, capped at ``size``.
    """

    def __init__(self, size):
        self.nodes = [0] * (2 * size - 1)
        self.data = [None] * size

        self.size = size
        self.count = 0
        self.real_size = 0

    @property
    def total(self):
        """Sum of all leaf values (the value held by the root)."""
        return self.nodes[0]

    def update(self, data_idx, value):
        """Set leaf ``data_idx`` to ``value`` and refresh every ancestor.

        Each ancestor is recomputed as the sum of its two children rather than
        being adjusted by a delta. Repeated delta updates let floating-point
        rounding accumulate, so internal nodes slowly drift away from the
        actual sum of their children; recomputing keeps every internal node
        equal to ``left + right`` at all times.
        """
        idx = data_idx + self.size - 1  # position of the leaf in ``nodes``
        self.nodes[idx] = value

        while idx > 0:
            idx = (idx - 1) // 2  # step up to the parent
            self.nodes[idx] = self.nodes[2 * idx + 1] + self.nodes[2 * idx + 2]

    def add(self, value, data):
        """Write ``data`` with priority ``value`` into the next ring-buffer slot.

        Once ``size`` entries have been written the oldest slot is overwritten.
        """
        self.data[self.count] = data
        self.update(self.count, value)

        self.count = (self.count + 1) % self.size
        self.real_size = min(self.size, self.real_size + 1)

    def get(self, cumsum):
        """Locate the leaf whose cumulative-sum interval contains ``cumsum``.

        Args:
            cumsum: Value in ``[0, total]``.

        Returns:
            Tuple ``(data_idx, value, data)``: the leaf's index, its stored
            value, and the payload saved with it.

        Raises:
            AssertionError: If ``cumsum`` exceeds :attr:`total`.
        """
        assert cumsum <= self.total

        idx = 0
        n_nodes = len(self.nodes)
        while 2 * idx + 1 < n_nodes:
            left = 2 * idx + 1
            right = left + 1
            left_mass = self.nodes[left]

            if cumsum <= left_mass:
                idx = left
            else:
                idx = right
                # The remainder can never legitimately exceed the right
                # subtree's mass; clamping absorbs rounding overshoot so the
                # walk cannot be pushed into a zero-priority (unfilled) leaf.
                cumsum = min(cumsum - left_mass, self.nodes[right])

        data_idx = idx - self.size + 1

        return data_idx, self.nodes[idx], self.data[data_idx]

    def __repr__(self):
        return f"SumTree(nodes={self.nodes!r}, data={self.data!r})"

class PrioritizedReplayBuffer:
    """Fixed-capacity replay buffer with proportional prioritized sampling.

    Transitions are kept in preallocated tensors that are written as a ring
    buffer. A :class:`SumTree` holds one priority per slot and is advanced in
    lockstep with the buffer, so the payload stored in each leaf is the index
    of the matching transition. Comment references to sections and appendices
    point to "Prioritized Experience Replay" (Schaul et al.).

    Args:
        device: Device that sampled batches are moved to.
        state_size: Length of one state vector.
        action_size: Length of one action vector.
        buffer_size: Maximum number of stored transitions.
        eps: Minimal priority; prevents zero sampling probabilities.
        alpha: Strength of prioritization (0 corresponds to uniform sampling).
        beta: Strength of the importance-sampling correction (1 fully
            compensates for the non-uniform probabilities).
    """

    def __init__(self, device, state_size, action_size, buffer_size, eps=1e-2, alpha=0.6, beta=0.4):
        self.tree = SumTree(size=buffer_size)

        self.device = device

        # PER params
        self.eps = eps  # minimal priority, prevents zero probabilities
        self.alpha = alpha  # how much prioritization is used; alpha = 0 is the uniform case
        self.beta = beta  # amount of importance-sampling correction; beta = 1 fully compensates
        self.max_priority = eps  # priority given to new samples, initialised to eps

        # transition: state, action, reward, next_state, done
        self.state = torch.empty(buffer_size, state_size, dtype=torch.float)
        self.action = torch.empty(buffer_size, action_size, dtype=torch.float)
        self.reward = torch.empty(buffer_size, dtype=torch.float)
        self.next_state = torch.empty(buffer_size, state_size, dtype=torch.float)
        self.done = torch.empty(buffer_size, dtype=torch.int)

        self.count = 0
        self.real_size = 0
        self.size = buffer_size

    def add(self, transition):
        """Store one transition, giving it the current maximum priority.

        Args:
            transition: Iterable unpacked as
                ``(state, action, reward, next_state, done)``.
        """
        state, action, reward, next_state, done = transition

        # Register the slot index in the sum tree with maximum priority so a
        # new transition is likely to be replayed at least once.
        self.tree.add(self.max_priority, self.count)

        # Store the transition itself in the buffer tensors.
        self.state[self.count] = torch.as_tensor(state)
        self.action[self.count] = torch.as_tensor(action)
        self.reward[self.count] = torch.as_tensor(reward)
        self.next_state[self.count] = torch.as_tensor(next_state)
        self.done[self.count] = torch.as_tensor(done)

        # Advance the ring-buffer counters.
        self.count = (self.count + 1) % self.size
        self.real_size = min(self.size, self.real_size + 1)

    def sample(self, batch_size):
        """Draw a prioritized minibatch.

        Args:
            batch_size: Number of transitions to draw.

        Returns:
            Tuple ``(batch, weights, tree_idxs)``:

            * ``batch`` - ``(state, action, reward, next_state, done)``
              tensors moved to ``self.device``.
            * ``weights`` - ``(batch_size, 1)`` importance-sampling weights,
              scaled so that the largest weight is 1.
            * ``tree_idxs`` - leaf indices to hand back to
              :meth:`update_priorities`.

        Raises:
            AssertionError: If fewer than ``batch_size`` transitions are stored.
        """
        assert self.real_size >= batch_size, "buffer contains less samples than batch size"

        sample_idxs, tree_idxs = [], []
        priorities = torch.empty(batch_size, 1, dtype=torch.float)

        # To sample a minibatch of size k, the range [0, p_total] is divided
        # equally into k ranges. A value is sampled uniformly from each range,
        # and the transitions corresponding to those values are retrieved from
        # the tree. (Appendix B.2.1, Proportional prioritization)
        total = self.tree.total
        segment = total / batch_size
        for i in range(batch_size):
            low, high = segment * i, segment * (i + 1)

            # Rounding in ``segment * (i + 1)`` or inside ``random.uniform``
            # can land a hair above the tree total, which would trip the
            # bounds assertion in ``SumTree.get``; clamp the draw.
            cumsum = min(random.uniform(low, high), total)

            # tree_idx is the leaf index (needed later to update priorities);
            # sample_idx is the payload stored in that leaf, i.e. the slot of
            # the transition in the buffer tensors.
            tree_idx, priority, sample_idx = self.tree.get(cumsum)

            priorities[i] = priority
            tree_idxs.append(tree_idx)
            sample_idxs.append(sample_idx)

        # The probability of sampling transition i is
        # P(i) = p_i^alpha / sum_k p_k^alpha, where p_i > 0 is the priority of
        # transition i (Section 3.3). The exponent is applied when priorities
        # are written in update_priorities.
        probs = priorities / total

        # Prioritized replay changes the distribution the updates are drawn
        # from, which biases the expected-value estimate. The bias is corrected
        # with importance-sampling weights w_i = (1/N * 1/P(i))^beta, which
        # fully compensate for the non-uniform probabilities when beta = 1.
        # (Section 3.4, first paragraph)
        weights = (self.real_size * probs) ** -self.beta

        # All weights are scaled so that max_i w_i = 1, which keeps them in a
        # reasonable range and only scales updates downwards.
        # (Appendix B.2.1, Proportional prioritization)
        weights = weights / weights.max()

        batch = (
            self.state[sample_idxs].to(self.device),
            self.action[sample_idxs].to(self.device),
            self.reward[sample_idxs].to(self.device),
            self.next_state[sample_idxs].to(self.device),
            self.done[sample_idxs].to(self.device)
        )
        return batch, weights, tree_idxs

    def update_priorities(self, data_idxs, priorities):
        """Write new priorities for previously sampled transitions.

        Args:
            data_idxs: Leaf indices as returned by :meth:`sample`.
            priorities: Raw priorities (e.g. absolute TD errors) as a tensor,
                array, or sequence. Any shape is accepted; the values are
                flattened and paired with ``data_idxs`` in order.
        """
        if isinstance(priorities, torch.Tensor):
            priorities = priorities.detach().cpu().numpy()

        # Flatten so that 0-d input (a squeezed single-sample batch) and
        # (N, 1)-shaped input both iterate as N scalars.
        flat_priorities = np.asarray(priorities, dtype=np.float64).reshape(-1)

        for data_idx, raw_priority in zip(data_idxs, flat_priorities):
            # Proportional prioritization: p_i = |delta_i| + eps, where eps is
            # a small positive constant that prevents transitions from never
            # being revisited once their error is zero (Section 3.3). The
            # result is raised to alpha before it is stored. Using a plain
            # Python float keeps the tree's sums in double precision.
            priority = (float(raw_priority) + self.eps) ** self.alpha

            self.tree.update(data_idx, priority)
            self.max_priority = max(self.max_priority, priority)

    def __len__(self):
        return self.real_size

class ReplayMemory:
    """Uniform replay memory holding at most ``capacity`` Transition records.

    Attributes:
        capacity: Maximum number of stored records.
        memory: List of stored records.
        position: Index that the next :meth:`push` writes to.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition.

        The positional arguments are forwarded to ``Transition``. Once the
        memory is full the oldest record is overwritten.
        """
        slot = self.position
        # Grow the list by one placeholder until capacity is reached.
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[slot] = Transition(*args)
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct records chosen uniformly at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

class OrnsteinUhlenbeckActionNoise:
    """Temporally correlated noise following a discretised OU process.

    Each call to :meth:`noise` advances the state by
    ``x += theta * (mu - x) * dt + sigma * sqrt(dt) * N(0, 1)``.

    Args:
        mu: Array the process reverts to; its shape sets the noise shape.
        sigma: Scale of the random term.
        theta: Mean-reversion rate.
        dt: Step size.
        x0: Optional initial state; an all-zero array shaped like ``mu`` is
            used when omitted.
    """

    def __init__(self, mu, sigma, theta=.15, dt=1e-2, x0=None):
        self.theta = theta
        self.mu = mu
        self.sigma = sigma
        self.dt = dt
        self.x0 = x0
        self.reset()

    def noise(self):
        """Advance the process by one step and return the new state."""
        drift = self.theta * (self.mu - self.x_prev) * self.dt
        diffusion = self.sigma * np.sqrt(self.dt) * np.random.normal(size=self.mu.shape)
        x = self.x_prev + drift + diffusion
        self.x_prev = x
        return x

    def reset(self):
        """Return the process to ``x0`` (or to zeros when ``x0`` is unset)."""
        if self.x0 is not None:
            self.x_prev = self.x0
        else:
            self.x_prev = np.zeros_like(self.mu)

    def __repr__(self):
        return 'OrnsteinUhlenbeckActionNoise(mu={}, sigma={})'.format(self.mu, self.sigma)

class Actor(nn.Module):
    """Deterministic policy network mapping states to actions in [-1, 1].

    Architecture: BatchNorm1d on the input, two tanh hidden layers of 512 and
    256 units, and a tanh output layer of ``action_dim`` units.

    Args:
        state_dim_actor: Number of input features.
        action_dim: Number of action components produced.
    """

    def __init__(self, state_dim_actor, action_dim):
        super().__init__()

        self.bn1 = nn.BatchNorm1d(state_dim_actor)
        self.linear1 = nn.Linear(state_dim_actor, 512)
        self.linear2 = nn.Linear(512, 256)
        self.mu = nn.Linear(256, action_dim)

        # Weight init. A bias vector's last dimension is the layer's *output*
        # width, so the layer's fan-in is passed explicitly for biases;
        # otherwise their bound would be derived from the wrong size.
        fan_in_uniform_init(self.linear1.weight)
        fan_in_uniform_init(self.linear1.bias, fan_in=self.linear1.in_features)
        fan_in_uniform_init(self.linear2.weight)
        fan_in_uniform_init(self.linear2.bias, fan_in=self.linear2.in_features)
        # The output layer starts near zero so initial actions are small.
        nn.init.uniform_(self.mu.weight, -3e-3, 3e-3)
        nn.init.uniform_(self.mu.bias, -3e-4, 3e-4)

    def forward(self, inputs):
        """Compute actions for a batch of states.

        Args:
            inputs: Batch of states with ``state_dim_actor`` features each.

        Returns:
            Tensor of tanh-squashed actions with ``action_dim`` features.
        """
        # Layer 1
        x = self.bn1(inputs)
        x = torch.tanh(self.linear1(x))

        # Layer 2
        x = torch.tanh(self.linear2(x))

        # Output
        return torch.tanh(self.mu(x))

class Critic(nn.Module):
    """Action-value network producing one value per (state, action) pair.

    Architecture: BatchNorm1d on the state, a tanh hidden layer of 512 units,
    concatenation with the action, a tanh hidden layer of 256 units, and a
    linear output of width 1.

    Args:
        state_dim_critic: Number of state features.
        action_dim: Number of action components.
    """

    def __init__(self, state_dim_critic, action_dim):
        super().__init__()

        self.bn1 = nn.BatchNorm1d(state_dim_critic)
        self.linear1 = nn.Linear(state_dim_critic, 512)
        self.linear2 = nn.Linear(512 + action_dim, 256)
        self.V = nn.Linear(256, 1)

        # Weight init. A bias vector's last dimension is the layer's *output*
        # width, so the layer's fan-in is passed explicitly for biases;
        # otherwise their bound would be derived from the wrong size.
        fan_in_uniform_init(self.linear1.weight)
        fan_in_uniform_init(self.linear1.bias, fan_in=self.linear1.in_features)
        fan_in_uniform_init(self.linear2.weight)
        fan_in_uniform_init(self.linear2.bias, fan_in=self.linear2.in_features)
        # The output layer starts near zero so initial value estimates are small.
        nn.init.uniform_(self.V.weight, -3e-3, 3e-3)
        nn.init.uniform_(self.V.bias, -3e-4, 3e-4)

    def forward(self, inputs, actions):
        """Evaluate a batch of (state, action) pairs.

        Args:
            inputs: Batch of states with ``state_dim_critic`` features each.
            actions: Batch of actions with ``action_dim`` features each.

        Returns:
            Tensor of estimated values whose last dimension has size 1.
        """
        # Layer 1 (state only)
        x = self.bn1(inputs)
        x = torch.tanh(self.linear1(x))

        # Layer 2: the actions are inserted here, after the first state layer.
        x = torch.cat((x, actions), 1)
        x = torch.tanh(self.linear2(x))

        # Output
        return self.V(x)

class DDPG:
    """DDPG agent: actor/critic networks, their targets, and a replay buffer.

    Args:
        device: Device the networks and sampled batches live on.
        state_dim_actor: Input width of the actor; also the state width of the
            replay buffer.
        state_dim_critic: State input width of the critic.
        action_dim: Number of action components.
        lr_actor: Learning rate of the actor optimizer.
        lr_critic: Learning rate of the critic optimizer.
        gamma: Discount factor.
        tau: Soft-update interpolation factor for the target networks.
        memory_capacity: Capacity of the replay buffer.
        batch_size: Minibatch size used by the update methods.
        K_epochs: Number of gradient iterations per update call.
    """

    def __init__(self, device, state_dim_actor, state_dim_critic, action_dim, lr_actor, lr_critic, gamma, tau, memory_capacity, batch_size, K_epochs):
        self.device = device
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.K_epochs = K_epochs

        # Exploration noise: zero-mean OU process with sigma 0.2 per action.
        self.ou_noise = OrnsteinUhlenbeckActionNoise(
            mu=np.zeros(action_dim),
            sigma=float(0.2) * np.ones(action_dim),
        )

        self.memory = PrioritizedReplayBuffer(device, state_dim_actor, action_dim, memory_capacity)

        # Actor and its target
        self.actor = Actor(state_dim_actor, action_dim).to(self.device)
        self.actor_target = Actor(state_dim_actor, action_dim).to(self.device)

        # Critic and its target
        self.critic = Critic(state_dim_critic, action_dim).to(self.device)
        self.critic_target = Critic(state_dim_critic, action_dim).to(self.device)

        # One Adam optimizer per online network
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=lr_actor, weight_decay=0.01
        )
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=lr_critic, weight_decay=0.01
        )

        # Targets start as exact copies of the online networks.
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target.load_state_dict(self.critic.state_dict())

        # Networks stay in eval mode outside of the update methods.
        self._set_training(False)

    def _set_training(self, mode):
        """Put all four networks into train (``True``) or eval (``False``) mode."""
        for net in (self.actor, self.actor_target, self.critic, self.critic_target):
            net.train(mode)

    def select_action(self, state, action_noise=None):
        """Compute the actor's action for a single state.

        Args:
            state: One state; a leading batch dimension of 1 is added
                before it is fed to the actor.
            action_noise: Optional object with a ``noise()`` method whose
                output is added to the action (exploration during training).

        Returns:
            Flattened NumPy array holding the (optionally noisy) action.
        """
        with torch.no_grad():
            obs = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            action = self.actor(obs)
            # During training noise is added for exploration.
            if action_noise is not None:
                action += torch.Tensor(action_noise.noise()).to(self.device)
            action = action.detach()
        return action.cpu().numpy().flatten()

    def update(self):
        """
        Run ``K_epochs`` training iterations on prioritized minibatches.

        Each iteration:
            1. Samples a batch with importance-sampling weights
            2. Updates the critic with the weighted squared TD error
            3. Updates the actor by maximising the critic's value
            4. Soft-updates both target networks
            5. Writes the absolute TD errors back as new priorities

        Returns the critic loss and actor loss of the last iteration as floats.
        """
        self._set_training(True)

        for _ in range(self.K_epochs):

            # Prioritized minibatch
            batch, is_weights, leaf_idxs = self.memory.sample(self.batch_size)
            states, actions, rewards, next_states, dones = batch

            # Column vectors so they broadcast against the critic output
            rewards = rewards.view(-1, 1)
            dones = dones.view(-1, 1)

            # Bootstrapped targets from the target networks
            with torch.no_grad():
                target_actions = self.actor_target(next_states)
                target_q = self.critic_target(next_states, target_actions)
                expected_q = rewards + (1.0 - dones) * self.gamma * target_q

            # Critic step
            self.critic_optimizer.zero_grad()
            predicted_q = self.critic(states, actions)
            td_errors = predicted_q - expected_q
            is_weights = is_weights.to(td_errors.device).view(-1, 1)
            value_loss = (is_weights * td_errors.pow(2)).mean()
            value_loss.backward()
            self.critic_optimizer.step()

            # Actor step
            self.actor_optimizer.zero_grad()
            actor_q = self.critic(states, self.actor(states))
            policy_loss = -actor_q.mean()
            policy_loss.backward()
            self.actor_optimizer.step()

            # Move the targets towards the online networks
            soft_update(self.actor_target, self.actor, self.tau)
            soft_update(self.critic_target, self.critic, self.tau)

            # Refresh replay priorities with the TD errors of this batch
            abs_td = td_errors.detach().abs().cpu().numpy().squeeze()
            self.memory.update_priorities(leaf_idxs, abs_td)

        self._set_training(False)

        return value_loss.item(), policy_loss.item()

    def update_old(self):
        """
        Run ``K_epochs`` training iterations on unweighted minibatches.

        Each iteration:
            1. Computes the targets
            2. Updates the Q-function/critic by one step of gradient descent
            3. Updates the policy/actor by one step of gradient ascent
            4. Soft-updates the target networks

        NOTE(review): this method treats ``self.memory.sample(...)`` as an
        iterable of ``Transition`` records, which is what ``ReplayMemory``
        returns. ``__init__`` installs a ``PrioritizedReplayBuffer``, whose
        ``sample`` returns ``(batch, weights, tree_idxs)`` instead.

        Returns the critic loss and actor loss of the last iteration as floats.
        """
        self._set_training(True)

        for _ in range(self.K_epochs):

            # Sample records and regroup them field by field
            records = self.memory.sample(self.batch_size)
            batch = Transition(*zip(*records))

            # Convert each field to a float tensor on the agent's device
            states = torch.tensor(np.array(batch.state), dtype=torch.float32).to(self.device)
            actions = torch.tensor(np.array(batch.action), dtype=torch.float32).to(self.device)
            rewards = torch.tensor(np.array(batch.reward), dtype=torch.float32).reshape(-1, 1).to(self.device)
            dones = torch.tensor(np.array(batch.done), dtype=torch.float32).reshape(-1, 1).to(self.device)
            next_states = torch.tensor(np.array(batch.next_state), dtype=torch.float32).to(self.device)

            # Defensive reshaping to avoid broadcasting issues
            rewards = rewards.view(-1, 1)
            dones = dones.view(-1, 1)

            # Bootstrapped targets from the target networks
            target_actions = self.actor_target(next_states)
            target_q = self.critic_target(next_states, target_actions.detach())
            expected_q = rewards + (1.0 - dones) * self.gamma * target_q

            # Critic step
            self.critic_optimizer.zero_grad()
            predicted_q = self.critic(states, actions)
            value_loss = F.mse_loss(predicted_q, expected_q.detach())
            value_loss.backward()
            self.critic_optimizer.step()

            # Actor step
            self.actor_optimizer.zero_grad()
            actor_q = self.critic(states, self.actor(states))
            policy_loss = (-actor_q).mean()
            policy_loss.backward()
            self.actor_optimizer.step()

            # Move the targets towards the online networks
            soft_update(self.actor_target, self.actor, self.tau)
            soft_update(self.critic_target, self.critic, self.tau)

        self._set_training(False)

        return value_loss.item(), policy_loss.item()

    def save(self, checkpoint_path):
        """Write actor and critic weights to ``<checkpoint_path>_actor.pth`` / ``_critic.pth``."""
        actor_file = checkpoint_path + "_actor.pth"
        critic_file = checkpoint_path + "_critic.pth"
        torch.save(self.actor.state_dict(), actor_file)
        torch.save(self.critic.state_dict(), critic_file)

    def load(self, checkpoint_path):
        """Load actor and critic weights and copy them into the target networks too."""
        actor_state = torch.load(checkpoint_path + "_actor.pth", map_location=self.device, weights_only=True)
        critic_state = torch.load(checkpoint_path + "_critic.pth", map_location=self.device, weights_only=True)

        self.actor.load_state_dict(actor_state)
        self.actor_target.load_state_dict(actor_state)
        self.critic.load_state_dict(critic_state)
        self.critic_target.load_state_dict(critic_state)

        for net in (self.actor, self.actor_target, self.critic, self.critic_target):
            net.to(self.device)