import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from collections import namedtuple

def fan_in_uniform_init(tensor, fan_in=None):
    """Fill ``tensor`` in place from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).

    Utility used to initialise the actor and critic layers.

    Args:
        tensor: tensor to overwrite in place.
        fan_in: fan-in used for the bound; when omitted, the size of the
            tensor's last dimension is used.
    """
    effective_fan_in = tensor.size(-1) if fan_in is None else fan_in
    bound = 1. / np.sqrt(effective_fan_in)
    nn.init.uniform_(tensor, -bound, bound)

def soft_update(target, source, tau):
    """Blend ``source``'s parameters into ``target`` in place (Polyak averaging).

    Every target parameter becomes ``(1 - tau) * target + tau * source``.
    Only ``parameters()`` are blended; registered buffers (for example
    BatchNorm running statistics) are not touched.
    """
    keep = 1.0 - tau
    for tgt, src in zip(target.parameters(), source.parameters()):
        blended = tgt.data * keep + src.data * tau
        tgt.data.copy_(blended)

# Record type stored by ReplayMemory. Positional construction (Transition(*args))
# depends on this field order: state, action, done, next_state, reward.
# NOTE(review): PrioritizedReplayBuffer.add unpacks (state, action, reward, next_state, done),
# a different order -- confirm each caller builds its tuples in the matching order.
Transition = namedtuple('Transition', ['state', 'action', 'done', 'next_state', 'reward'])

class SumTree:
    """Array-backed binary sum tree over ``size`` leaf slots.

    ``nodes`` stores the tree in heap order (children of ``i`` are ``2i+1`` and
    ``2i+2``). The last ``size`` entries are leaves holding per-slot values.
    Every internal node holds the sum of its children, so ``nodes[0]`` is the
    total. ``data`` holds one payload per leaf and is written as a ring buffer
    by :meth:`add`.
    """

    def __init__(self, size):
        self.nodes = [0] * (2 * size - 1)
        self.data = [None] * size

        self.size = size
        self.count = 0  # next slot add() writes to (wraps around)
        self.real_size = 0  # number of slots filled so far, capped at size

    @property
    def total(self):
        """Sum of all leaf values (the root node)."""
        return self.nodes[0]

    def update(self, data_idx, value):
        """Set leaf ``data_idx`` to ``value`` and add the difference to every ancestor."""
        idx = data_idx + self.size - 1  # leaf position in the heap array
        change = value - self.nodes[idx]

        self.nodes[idx] = value

        parent = (idx - 1) // 2
        while parent >= 0:
            self.nodes[parent] += change
            parent = (parent - 1) // 2

    def add(self, value, data):
        """Write ``data`` with leaf value ``value`` into the next ring-buffer slot."""
        self.data[self.count] = data
        self.update(self.count, value)

        self.count = (self.count + 1) % self.size
        self.real_size = min(self.size, self.real_size + 1)

    def get(self, cumsum):
        """Find the leaf whose cumulative-sum interval contains ``cumsum``.

        Returns:
            ``(data_idx, leaf_value, data)`` for the selected slot.

        A subtree whose stored sum is <= 0 is never entered while its sibling
        is positive. Slots that were never written stay exactly 0, because
        ``update`` only touches ancestors of written leaves. Skipping them
        prevents two failures that would otherwise return an empty slot whose
        data is ``None``:

        * float drift between a parent and its children;
        * ``cumsum == 0`` when ``size`` is not a power of two, where slot 0
          sits in the root's right subtree.
        """
        assert cumsum <= self.total

        idx = 0
        while 2 * idx + 1 < len(self.nodes):
            left, right = 2 * idx + 1, 2 * idx + 2
            left_sum, right_sum = self.nodes[left], self.nodes[right]

            go_left = right_sum <= 0 or (left_sum > 0 and cumsum <= left_sum)
            if go_left:
                idx = left
            else:
                idx = right
                cumsum = cumsum - left_sum

        data_idx = idx - self.size + 1

        return data_idx, self.nodes[idx], self.data[data_idx]

    def __repr__(self):
        return f"SumTree(nodes={self.nodes.__repr__()}, data={self.data.__repr__()})"

class PrioritizedReplayBuffer:
    """Proportional prioritized experience replay (Schaul et al., 2016).

    Transitions are kept in preallocated tensors (created without an explicit
    device) indexed by a ring-buffer position. They are moved to ``device``
    only when a batch is sampled. A ``SumTree`` of the same capacity stores
    each slot's priority (already raised to ``alpha``) together with the slot
    index, so sampling proportional to priority is a tree descent.
    """

    def __init__(self, device, state_size, action_size, buffer_size, eps=1e-2, alpha=0.6, beta=0.4):
        self.tree = SumTree(size=buffer_size)

        self.device = device

        # PER params
        self.eps = eps  # minimal priority, prevents zero probabilities
        self.alpha = alpha  # how much prioritization is used; alpha = 0 is the uniform case
        self.beta = beta  # amount of importance-sampling correction; beta = 1 fully compensates
        self.max_priority = eps  # priority given to new samples, initialised to eps

        # transition storage: state, action, reward, next_state, done
        self.state = torch.empty(buffer_size, state_size, dtype=torch.float)
        self.action = torch.empty(buffer_size, action_size, dtype=torch.float)
        self.reward = torch.empty(buffer_size, dtype=torch.float)
        self.next_state = torch.empty(buffer_size, state_size, dtype=torch.float)
        self.done = torch.empty(buffer_size, dtype=torch.int)

        self.count = 0
        self.real_size = 0
        self.size = buffer_size

    def add(self, transition):
        """Store ``transition = (state, action, reward, next_state, done)`` at the current maximum priority."""
        state, action, reward, next_state, done = transition

        # store the transition's buffer index in the sum tree with maximum priority
        self.tree.add(self.max_priority, self.count)

        # store the transition in the buffer
        self.state[self.count] = torch.as_tensor(state)
        self.action[self.count] = torch.as_tensor(action)
        self.reward[self.count] = torch.as_tensor(reward)
        self.next_state[self.count] = torch.as_tensor(next_state)
        self.done[self.count] = torch.as_tensor(done)

        # update counters
        self.count = (self.count + 1) % self.size
        self.real_size = min(self.size, self.real_size + 1)

    def sample(self, batch_size):
        """Sample ``batch_size`` transitions with probability proportional to priority.

        Returns:
            ``(batch, weights, tree_idxs)``:

            * ``batch`` is ``(state, action, reward, next_state, done)``
              tensors on ``self.device``.
            * ``weights`` is a ``(batch_size, 1)`` float tensor of
              importance-sampling weights on ``self.device``, normalised so
              that its maximum is 1.
            * ``tree_idxs`` are the sum-tree slot indices to pass back to
              :meth:`update_priorities`.
        """
        assert self.real_size >= batch_size, "buffer contains less samples than batch size"

        sample_idxs, tree_idxs = [], []
        priorities = torch.empty(batch_size, 1, dtype=torch.float)

        # Divide [0, p_total] into batch_size equal ranges, draw one value
        # uniformly from each, and retrieve the matching transition from the
        # tree (Appendix B.2.1, Proportional prioritization).
        total = self.tree.total
        segment = total / batch_size
        for i in range(batch_size):
            a, b = segment * i, segment * (i + 1)

            # segment * batch_size can round a hair above total; clamp so the
            # tree's range assertion cannot fire on rounding alone.
            cumsum = min(random.uniform(a, b), total)
            # tree_idx: slot index in the sum tree, needed later to update priorities.
            # sample_idx: the stored buffer index, used to gather the transition.
            tree_idx, priority, sample_idx = self.tree.get(cumsum)

            priorities[i] = priority
            tree_idxs.append(tree_idx)
            sample_idxs.append(sample_idx)

        # P(i) = p_i^alpha / sum_k p_k^alpha. Stored priorities already
        # include the alpha exponent (Section 3.3).
        probs = priorities / total

        # Importance-sampling weights w_i = (1/N * 1/P(i))^beta correct the
        # bias that non-uniform sampling introduces (Section 3.4).
        weights = (self.real_size * probs) ** -self.beta

        # Scale so that max_i w_i = 1, so weights only ever scale updates
        # downwards (Appendix B.2.1).
        weights = weights / weights.max()

        batch = (
            self.state[sample_idxs].to(self.device),
            self.action[sample_idxs].to(self.device),
            self.reward[sample_idxs].to(self.device),
            self.next_state[sample_idxs].to(self.device),
            self.done[sample_idxs].to(self.device)
        )
        # Weights must live on the same device as the batch; they multiply
        # the per-sample losses computed from it.
        return batch, weights.to(self.device), tree_idxs

    def update_priorities(self, data_idxs, priorities):
        """Set new priorities ``(|delta| + eps) ** alpha`` for the given sum-tree slots.

        Args:
            data_idxs: slot indices as returned by :meth:`sample`.
            priorities: absolute TD errors. A tensor, array, or sequence; any
                shape is flattened to one value per index.
        """
        if isinstance(priorities, torch.Tensor):
            priorities = priorities.detach().cpu().numpy()
        # Flatten so a scalar / 0-d array (batch of one) or an (N, 1) column
        # still pairs element-wise with data_idxs and stores plain floats.
        priorities = np.asarray(priorities, dtype=np.float64).reshape(-1)

        for data_idx, priority in zip(data_idxs, priorities):
            # Proportional prioritization: p_i = |delta_i| + eps, where eps
            # keeps zero-error transitions revisitable (Section 3.3).
            priority = (float(priority) + self.eps) ** self.alpha

            self.tree.update(data_idx, priority)
            self.max_priority = max(self.max_priority, priority)

    def __len__(self):
        return self.real_size

class ReplayMemory(object):
    """Fixed-capacity ring buffer of ``Transition`` tuples with uniform random sampling."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0  # slot the next push writes to; wraps to 0 after capacity

    def push(self, *args):
        """Store ``Transition(*args)``, overwriting the oldest entry once the memory is full."""
        still_growing = len(self.memory) < self.capacity
        if still_growing:
            # Reserve the new slot first; it is filled on the next line.
            self.memory.append(None)
        slot = self.position
        self.memory[slot] = Transition(*args)
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen uniformly at random."""
        picked = random.sample(self.memory, k=batch_size)
        return picked

    def __len__(self):
        return len(self.memory)

class OrnsteinUhlenbeckActionNoise:
    """Temporally correlated exploration noise from a discretised Ornstein-Uhlenbeck process.

    Each call to :meth:`noise` advances the state one step:
    ``x <- x + theta * (mu - x) * dt + sigma * sqrt(dt) * N(0, 1)``,
    with the Gaussian draw shaped like ``mu``.
    """

    def __init__(self, mu, sigma, theta=.15, dt=1e-2, x0=None):
        self.theta = theta
        self.mu = mu
        self.sigma = sigma
        self.dt = dt
        self.x0 = x0
        self.reset()

    def noise(self):
        """Advance the process by one step and return the new state."""
        previous = self.x_prev
        drift = self.theta * (self.mu - previous) * self.dt
        diffusion = self.sigma * np.sqrt(self.dt) * np.random.normal(size=self.mu.shape)
        self.x_prev = previous + drift + diffusion
        return self.x_prev

    def reset(self):
        """Restart the process at ``x0``, or at zeros shaped like ``mu`` when ``x0`` is None."""
        if self.x0 is None:
            self.x_prev = np.zeros_like(self.mu)
        else:
            self.x_prev = self.x0

    def __repr__(self):
        return f'OrnsteinUhlenbeckActionNoise(mu={self.mu}, sigma={self.sigma})'

class Actor(nn.Module):
    """Deterministic policy network.

    Architecture: BatchNorm -> Linear(512) -> tanh -> Linear(256) -> tanh ->
    Linear(action_dim) -> tanh, so every output lies in [-1, 1].
    """

    def __init__(self, state_dim_actor, action_dim):

        super(Actor, self).__init__()

        self.bn1 = nn.BatchNorm1d(state_dim_actor)
        self.linear1 = nn.Linear(state_dim_actor, 512)
        self.linear2 = nn.Linear(512, 256)
        self.mu = nn.Linear(256, action_dim)

        # Hidden layers draw from U(-1/sqrt(f), 1/sqrt(f)) with f = the
        # layer's fan-in (in_features). The bias must use that same fan-in
        # explicitly: a bias tensor's size(-1) is out_features, which is what
        # the helper would fall back to.
        for layer in (self.linear1, self.linear2):
            fan_in_uniform_init(layer.weight)
            fan_in_uniform_init(layer.bias, fan_in=layer.in_features)
        # Small output-layer init keeps initial actions near zero.
        nn.init.uniform_(self.mu.weight, -3e-3, 3e-3)
        nn.init.uniform_(self.mu.bias, -3e-4, 3e-4)

    def forward(self, inputs):
        """Map states to actions in [-1, 1].

        ``inputs`` is presumably (batch, state_dim_actor) -- TODO confirm. In
        train mode BatchNorm1d needs more than one sample per batch.
        """
        x = self.bn1(inputs)
        x = torch.tanh(self.linear1(x))
        x = torch.tanh(self.linear2(x))
        return torch.tanh(self.mu(x))

class Critic(nn.Module):
    """Action-value network Q(s, a).

    The state passes through BatchNorm -> Linear(512) -> tanh. The action is
    then concatenated in, followed by Linear(256) -> tanh -> Linear(1).
    """

    def __init__(self, state_dim_critic, action_dim):

        super(Critic, self).__init__()

        self.bn1 = nn.BatchNorm1d(state_dim_critic)
        self.linear1 = nn.Linear(state_dim_critic, 512)
        self.linear2 = nn.Linear(512 + action_dim, 256)
        self.V = nn.Linear(256, 1)

        # Hidden layers draw from U(-1/sqrt(f), 1/sqrt(f)) with f = the
        # layer's fan-in (in_features). The bias must use that same fan-in
        # explicitly: a bias tensor's size(-1) is out_features, which is what
        # the helper would fall back to.
        for layer in (self.linear1, self.linear2):
            fan_in_uniform_init(layer.weight)
            fan_in_uniform_init(layer.bias, fan_in=layer.in_features)
        # Small output-layer init keeps initial Q estimates near zero.
        nn.init.uniform_(self.V.weight, -3e-3, 3e-3)
        nn.init.uniform_(self.V.bias, -3e-4, 3e-4)

    def forward(self, inputs, actions):
        """Return Q-values with a trailing dimension of 1.

        ``inputs`` is presumably (batch, state_dim_critic) and ``actions`` is
        (batch, action_dim) -- TODO confirm. They are concatenated along
        dim 1 after the first hidden layer.
        """
        x = self.bn1(inputs)
        x = torch.tanh(self.linear1(x))
        x = torch.cat((x, actions), 1)  # insert the actions
        x = torch.tanh(self.linear2(x))
        return self.V(x)

class TD3:
    """TD3-style agent (twin critics, delayed actor updates) trained from a prioritized replay buffer.

    All networks are kept in eval mode between calls to :meth:`update`, which
    switches them to train mode for the duration of the update.
    """

    def __init__(self, device, state_dim_actor, state_dim_critic, action_dim, lr_actor, lr_critic, gamma, tau, memory_capacity, batch_size, K_epochs, policy_delay):

        self.device = device
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.K_epochs = K_epochs
        self.policy_delay = policy_delay
        self.ou_noise = OrnsteinUhlenbeckActionNoise(mu=np.zeros(action_dim), sigma=float(0.2) * np.ones(action_dim))

        # NOTE(review): the buffer stores states of width state_dim_actor and
        # update() feeds those same states to the critics, so this presumably
        # requires state_dim_critic == state_dim_actor -- confirm.
        self.memory = PrioritizedReplayBuffer(device, state_dim_actor, action_dim, memory_capacity)

        self.update_counter = 0

        # Define the actor
        self.actor = Actor(state_dim_actor, action_dim).to(self.device)
        self.actor_target = Actor(state_dim_actor, action_dim).to(self.device)

        # Define the twin critics
        self.critic_1 = Critic(state_dim_critic, action_dim).to(self.device)
        self.critic_target_1 = Critic(state_dim_critic, action_dim).to(self.device)
        self.critic_2 = Critic(state_dim_critic, action_dim).to(self.device)
        self.critic_target_2 = Critic(state_dim_critic, action_dim).to(self.device)

        # Define the optimizers for all networks
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr_actor, weight_decay=0.01)
        self.critic_optimizer_1 = torch.optim.Adam(self.critic_1.parameters(), lr=lr_critic, weight_decay=0.01)
        self.critic_optimizer_2 = torch.optim.Adam(self.critic_2.parameters(), lr=lr_critic, weight_decay=0.01)

        # Make sure all targets start with the same weights as their online networks
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target_1.load_state_dict(self.critic_1.state_dict())
        self.critic_target_2.load_state_dict(self.critic_2.state_dict())

        self._set_train_mode(False)

    def _networks(self):
        """Return every online and target network managed by the agent."""
        return (
            self.actor,
            self.actor_target,
            self.critic_1,
            self.critic_target_1,
            self.critic_2,
            self.critic_target_2,
        )

    def _checkpoint_modules(self):
        """Return ``(file suffix, network)`` pairs used by :meth:`save` and :meth:`load`."""
        return (
            ("_actor.pth", self.actor),
            ("_actor_target.pth", self.actor_target),
            ("_critic_1.pth", self.critic_1),
            ("_critic_target_1.pth", self.critic_target_1),
            ("_critic_2.pth", self.critic_2),
            ("_critic_target_2.pth", self.critic_target_2),
        )

    def _set_train_mode(self, mode):
        """Put all six networks in train (``mode=True``) or eval (``mode=False``) mode."""
        for net in self._networks():
            net.train(mode)

    def select_action(self, state, action_noise=None):
        """Return the actor's action for a single ``state`` as a flat NumPy array.

        When ``action_noise`` is given, its ``noise()`` sample is added for
        exploration. No clipping is applied afterwards.
        """
        with torch.no_grad():
            state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            action = self.actor(state)
            # During training we add noise for exploration
            if action_noise is not None:
                noise = torch.Tensor(action_noise.noise()).to(self.device)
                action += noise
            action = action.detach()
        return action.cpu().numpy().flatten()

    def update(self):
        """Run one round of updates from prioritized batches.

        1. For each of ``K_epochs`` iterations:
           * sample a prioritized batch with importance weights;
           * update both critics on the weighted squared TD error against
             the clipped double-Q target;
           * refresh the sampled transitions' priorities from critic 1's
             TD errors.
        2. Every ``policy_delay``-th call, update the actor on the last
           sampled batch and soft-update all target networks.

        Returns:
            ``(avg_value_loss, avg_policy_loss)``. The policy loss is 0.0 on
            calls without an actor update.
        """
        self._set_train_mode(True)

        total_value_loss = 0.0
        total_policy_loss = 0.0
        num_policy_updates = 0
        state_batch = None  # stays None when K_epochs <= 0 (no batch sampled)

        self.update_counter += 1

        for _ in range(self.K_epochs):

            # Sample prioritized batch
            batch, weights, tree_idxs = self.memory.sample(self.batch_size)
            (state_batch, action_batch, reward_batch, next_state_batch, done_batch) = batch

            # Ensure column shapes so they broadcast against (batch, 1) Q-values
            reward_batch = reward_batch.view(-1, 1)
            done_batch = done_batch.view(-1, 1)
            # The IS weights multiply device-resident losses; move them
            # there in case the buffer returned them on the CPU.
            weights = weights.to(self.device).view(-1, 1)

            # Compute target Q-values (clipped double-Q)
            with torch.no_grad():
                next_actions = self.actor_target(next_state_batch)
                target_q1 = self.critic_target_1(next_state_batch, next_actions)
                target_q2 = self.critic_target_2(next_state_batch, next_actions)
                target_q = torch.min(target_q1, target_q2)
                expected_q_values = reward_batch + (1.0 - done_batch) * self.gamma * target_q

            # Critic 1 update
            self.critic_optimizer_1.zero_grad()
            current_q1 = self.critic_1(state_batch, action_batch)
            loss_q1 = (weights * (current_q1 - expected_q_values).pow(2)).mean()
            loss_q1.backward()
            self.critic_optimizer_1.step()

            # Critic 2 update
            self.critic_optimizer_2.zero_grad()
            current_q2 = self.critic_2(state_batch, action_batch)
            loss_q2 = (weights * (current_q2 - expected_q_values).pow(2)).mean()
            loss_q2.backward()
            self.critic_optimizer_2.step()

            total_value_loss += 0.5 * (loss_q1.item() + loss_q2.item())

            # Update priorities in the replay buffer. reshape(-1) (not
            # squeeze) keeps a 1-D array even when batch_size == 1.
            td_errors = (current_q1 - expected_q_values).detach()
            new_priorities = td_errors.abs().cpu().numpy().reshape(-1)
            self.memory.update_priorities(tree_idxs, new_priorities)

        if state_batch is not None and self.update_counter % self.policy_delay == 0:
            # Actor update on the last batch sampled above
            self.actor_optimizer.zero_grad()
            actor_loss = -self.critic_1(state_batch, self.actor(state_batch)).mean()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Soft update targets
            soft_update(self.actor_target, self.actor, self.tau)
            soft_update(self.critic_target_1, self.critic_1, self.tau)
            soft_update(self.critic_target_2, self.critic_2, self.tau)

            total_policy_loss += actor_loss.item()
            num_policy_updates += 1

        self._set_train_mode(False)

        avg_value_loss = total_value_loss / max(1, self.K_epochs)
        avg_policy_loss = total_policy_loss / max(1, num_policy_updates)

        return avg_value_loss, avg_policy_loss

    def save(self, checkpoint_path):
        """Write each network's state_dict to ``checkpoint_path`` + its suffix (e.g. ``_actor.pth``)."""
        for suffix, net in self._checkpoint_modules():
            torch.save(net.state_dict(), checkpoint_path + suffix)

    def load(self, checkpoint_path):
        """Load all six state_dicts written by :meth:`save` and move the networks to ``self.device``."""
        # Read every file before touching any network, so a missing or
        # unreadable checkpoint raises without partially loading the agent.
        states = [
            (net, torch.load(checkpoint_path + suffix, map_location=self.device, weights_only=True))
            for suffix, net in self._checkpoint_modules()
        ]

        for net, state in states:
            net.load_state_dict(state)

        for net in self._networks():
            net.to(self.device)