import copy
import gc
import math
import os
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

# Module-wide torch device used by most buffers below and by SAC for tensor placement.
device = torch.device("cpu")
class RunningNormNormalizer:
    """Online per-feature observation normalizer based on Welford's algorithm.

    Every call to :meth:`normalize` first folds the observation into the
    running statistics and then standardizes it with the updated mean and
    standard deviation.

    The statistics start from a single pseudo-sample of zeros (``count`` is
    initialised to 1 with a zero mean), which keeps the variance finite from
    the very first update.
    """

    def __init__(self, observation_size):
        # Plain NumPy arrays in host memory; nothing is placed on a torch device.
        self.mean = np.zeros(observation_size)
        # Running sum of squared deviations (Welford's "M2"), NOT the variance.
        # The attribute name is kept for backward compatibility; use
        # ``self.variance`` for the actual (population) variance.
        self.var = np.zeros(observation_size)
        self.count = 1

    @property
    def variance(self):
        """Population variance of all samples folded in so far (``M2 / count``)."""
        return self.var / self.count

    def update_stats(self, observation):
        """Fold one observation into the running mean and the M2 accumulator."""
        self.count += 1
        delta = observation - self.mean
        self.mean += delta / self.count
        # The second factor uses the *updated* mean, as Welford's update requires.
        self.var += delta * (observation - self.mean)

    def normalize(self, observation):
        """Update the statistics with *observation*, then return it standardized.

        Returns ``(observation - mean) / (std + 1e-8)``, where ``std`` is the
        square root of the running population variance.
        """
        self.update_stats(observation)
        # Divide by sqrt(M2 / count), not sqrt(M2): M2 grows with the number of
        # samples, so dividing by it directly would shrink outputs towards zero.
        return (observation - self.mean) / (np.sqrt(self.variance) + 1e-8)

class EpochBuffer:
    """Fixed-capacity transition store that is drained once per epoch.

    Transitions are written into pre-allocated NumPy arrays. When more than
    ``max_size`` transitions are added, the oldest rows are overwritten. Rows
    are kept in storage order, so after a wrap-around the oldest surviving
    transition is not necessarily the first row. :meth:`get` hands back the
    stored transitions and resets the buffer.
    """

    def __init__(self, observation_size, action_size, max_size):
        self.size = max_size
        self.current_idx = 0  # next row to write
        self.filled = 0  # number of valid rows (never exceeds ``size``)

        # np.empty: rows beyond ``filled`` contain uninitialised memory and must
        # never be handed out to callers.
        self.states = np.empty((max_size, observation_size))
        self.actions = np.empty((max_size, action_size))
        self.rewards = np.empty((max_size, 1))
        self.next_states = np.empty((max_size, observation_size))
        self.dones = np.empty((max_size, 1))

        self.device = device

    def add(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest row once the buffer is full."""
        self.states[self.current_idx] = state
        self.actions[self.current_idx] = action
        self.rewards[self.current_idx] = reward
        self.next_states[self.current_idx] = next_state
        self.dones[self.current_idx] = done

        self.current_idx = (self.current_idx + 1) % self.size
        self.filled = min(self.size, self.filled + 1)

    def get(self):
        """Return ``(states, actions, rewards, next_states, dones)`` and reset the buffer.

        Only the ``filled`` valid rows are returned, and they are copies. The
        buffer is reset here, and later :meth:`add` calls overwrite the internal
        arrays, which would otherwise silently mutate the returned data.
        """
        n_valid = self.filled
        batch = tuple(
            column[:n_valid].copy()
            for column in (self.states, self.actions, self.rewards, self.next_states, self.dones)
        )
        self.clear()
        return batch

    def clear(self):
        """Forget all stored transitions (the arrays are reused, not zeroed)."""
        self.current_idx = 0
        self.filled = 0

class OptimizedReplayBuffer:
    """Ring-buffer replay memory stored as pre-allocated torch tensors.

    All storage is allocated on ``device``, so sampled batches are produced by
    indexing on that device without a host-to-device copy.
    """

    def __init__(self, observation_size, action_size, max_size, device=torch.device('cpu')):
        self.size = max_size
        self.current_idx = 0  # next row to write
        self.filled = 0  # number of valid rows (never exceeds ``size``)

        # Pre-allocate storage directly on the target device.
        self.states = torch.empty((max_size, observation_size), device=device)
        self.actions = torch.empty((max_size, action_size), device=device)
        self.rewards = torch.empty((max_size, 1), device=device)
        self.next_states = torch.empty((max_size, observation_size), device=device)
        self.dones = torch.empty((max_size, 1), device=device)

        self.device = device

    def _to_row(self, value, like):
        """Convert *value* to a detached tensor with *like*'s dtype on the buffer device.

        ``torch.as_tensor`` avoids the extra copy and the UserWarning that
        ``torch.tensor`` emits when *value* is already a tensor. ``detach``
        keeps a grad-requiring input from pulling the storage into autograd.
        """
        return torch.as_tensor(value, dtype=like.dtype, device=self.device).detach()

    def add(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest row once the buffer is full."""
        idx = self.current_idx
        self.states[idx] = self._to_row(state, self.states)
        self.actions[idx] = self._to_row(action, self.actions)
        self.rewards[idx] = self._to_row(reward, self.rewards)
        self.next_states[idx] = self._to_row(next_state, self.next_states)
        self.dones[idx] = self._to_row(done, self.dones)

        self.current_idx = (idx + 1) % self.size
        self.filled = min(self.size, self.filled + 1)

    def sample(self, batch_size):
        """Draw ``batch_size`` stored transitions uniformly, with replacement.

        Returns ``(states, actions, rewards, next_states, dones)``.

        Raises:
            RuntimeError: if the buffer is empty.
        """
        if self.filled == 0:
            raise RuntimeError("cannot sample from an empty OptimizedReplayBuffer")
        idxs = torch.randint(0, self.filled, (batch_size,), device=self.device)

        # Advanced indexing already yields fresh, contiguous tensors.
        return (
            self.states[idxs],
            self.actions[idxs],
            self.rewards[idxs],
            self.next_states[idxs],
            self.dones[idxs],
        )

    def clear(self):
        """Forget all stored transitions (storage is reused, not zeroed)."""
        self.current_idx = 0
        self.filled = 0

class DequeReplayBuffer:
    """Replay memory backed by bounded deques; the oldest transitions are evicted first.

    Transitions are stored as given and converted to float tensors only when a
    batch is sampled. Note that deque indexing is O(n) away from the ends, so
    sampling from very large buffers is slow.
    """

    def __init__(self, observation_size, action_size, max_size, device=None):
        self.buffer_size = max_size
        self.states = deque(maxlen=max_size)
        self.actions = deque(maxlen=max_size)
        self.rewards = deque(maxlen=max_size)
        self.next_states = deque(maxlen=max_size)
        self.dones = deque(maxlen=max_size)
        self.size = 0
        self.observation_size = observation_size
        self.action_size = action_size
        self.max_size = max_size
        # Target device for sampled batches. None means "use the module-level
        # ``device`` at sampling time", which matches the original behaviour.
        self.device = device

    def add(self, state, action, reward, next_state, done):
        """Append one transition; deque ``maxlen`` drops the oldest when full."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.next_states.append(next_state)
        self.dones.append(done)
        self.size = min(self.size + 1, self.max_size)

    @staticmethod
    def _gather(column, idxs, width):
        """Stack ``column[i]`` for every i in *idxs* into a ``(len(idxs), width)`` float tensor."""
        rows = np.array([column[i] for i in idxs])
        return torch.reshape(torch.FloatTensor(rows), (len(idxs), width))

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions uniformly, with replacement.

        Returns ``(states, actions, rewards, next_states, dones)`` float tensors.

        Raises:
            ValueError: if the buffer is empty.
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty DequeReplayBuffer")
        # randint's upper bound is exclusive: use ``size`` (not ``size - 1``) so
        # the newest transition can be drawn and a 1-element buffer works.
        idxs = np.random.randint(0, self.size, size=batch_size)
        target = self.device if self.device is not None else device
        return (
            self._gather(self.states, idxs, self.observation_size).to(target),
            self._gather(self.actions, idxs, self.action_size).to(target),
            self._gather(self.rewards, idxs, 1).to(target),
            self._gather(self.next_states, idxs, self.observation_size).to(target),
            self._gather(self.dones, idxs, 1).to(target),
        )

    def __len__(self):
        return len(self.states)

    def clear(self):
        """Drop every stored transition."""
        self.states.clear()
        self.actions.clear()
        self.rewards.clear()
        self.next_states.clear()
        self.dones.clear()
        self.size = 0

class RandomBuffer(object):
    """Uniform-sampling replay buffer backed by pre-allocated NumPy arrays.

    ``save``/``load`` persist the buffer as ``.npy`` files under a ``buffer/``
    directory relative to the current working directory.
    """

    def __init__(self, state_dim, action_dim, max_size=int(1e6)):
        self.max_size = max_size
        self.state_dim = state_dim
        self.action_dim = action_dim
        self._allocate()
        self.device = device

    def _allocate(self):
        """(Re)create zero-filled storage and reset the write pointer and size."""
        self.ptr = 0
        self.size = 0
        self.state = np.zeros((self.max_size, self.state_dim))
        self.action = np.zeros((self.max_size, self.action_dim))
        self.reward = np.zeros((self.max_size, 1))
        self.next_state = np.zeros((self.max_size, self.state_dim))
        self.dead = np.zeros((self.max_size, 1), dtype=np.uint8)

    def add(self, state, action, reward, next_state, dead):
        """Store one transition, overwriting the oldest row once the buffer is full."""
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.reward[self.ptr] = reward
        self.next_state[self.ptr] = next_state
        self.dead[self.ptr] = dead

        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        """Draw ``batch_size`` stored transitions uniformly, with replacement.

        Returns float32 tensors ``(state, action, reward, next_state, dead)`` on
        ``self.device``.
        """
        ind = np.random.randint(0, self.size, size=batch_size)
        with torch.no_grad():
            return (
                torch.FloatTensor(self.state[ind]).to(self.device),
                torch.FloatTensor(self.action[ind]).to(self.device),
                torch.FloatTensor(self.reward[ind]).to(self.device),
                torch.FloatTensor(self.next_state[ind]).to(self.device),
                torch.FloatTensor(self.dead[ind]).to(self.device)
            )

    def clear(self):
        """Drop all transitions by reallocating zero-filled storage."""
        self._allocate()

    def save(self):
        """Write the buffer and its bookkeeping scalars to ``buffer/*.npy``.

        ``buffer/scaller.npy`` holds ``[max_size, ptr, size, Env_with_dead]`` as uint32.
        """
        os.makedirs("buffer", exist_ok=True)
        # ``Env_with_dead`` is only ever set by ``load``; default it to 0 so that
        # saving a freshly constructed buffer does not raise AttributeError.
        env_with_dead = int(getattr(self, "Env_with_dead", 0))
        scaller = np.array([self.max_size, self.ptr, self.size, env_with_dead], dtype=np.uint32)
        np.save("buffer/scaller.npy", scaller)
        np.save("buffer/state.npy", self.state)
        np.save("buffer/action.npy", self.action)
        np.save("buffer/reward.npy", self.reward)
        np.save("buffer/next_state.npy", self.next_state)
        np.save("buffer/dead.npy", self.dead)

    def load(self):
        """Restore a buffer previously written by :meth:`save`."""
        scaller = np.load("buffer/scaller.npy")

        # Cast to Python ints so later pointer/size arithmetic is not done in uint32.
        self.max_size, self.ptr, self.size, self.Env_with_dead = (int(v) for v in scaller[:4])

        self.state = np.load("buffer/state.npy")
        self.action = np.load("buffer/action.npy")
        self.reward = np.load("buffer/reward.npy")
        self.next_state = np.load("buffer/next_state.npy")
        self.dead = np.load("buffer/dead.npy")
        # Keep the dimensions consistent with the loaded arrays so clear() reallocates correctly.
        self.state_dim = self.state.shape[1]
        self.action_dim = self.action.shape[1]


def build_net(layer_shape, activation, output_activation):
    """Build an MLP as ``nn.Sequential`` of alternating Linear and activation modules.

    ``layer_shape`` lists layer widths (input first). Each consecutive pair of
    widths becomes an ``nn.Linear`` followed by an instance of ``activation``;
    the final Linear is followed by ``output_activation`` instead.
    """
    last_idx = len(layer_shape) - 2
    modules = []
    pairs = zip(layer_shape[:-1], layer_shape[1:])
    for idx, (fan_in, fan_out) in enumerate(pairs):
        act_cls = output_activation if idx == last_idx else activation
        modules.append(nn.Linear(fan_in, fan_out))
        modules.append(act_cls())
    return nn.Sequential(*modules)


class Actor(nn.Module):
    """Gaussian policy whose samples are squashed into (-1, 1) with tanh.

    The trunk ``a_net`` maps states to features. ``mu_layer`` and
    ``log_std_layer`` map those features to the mean and clamped log standard
    deviation of a Normal distribution over pre-squash actions.
    """

    def __init__(self, state_dim, action_dim, hid_shape, h_acti=nn.ReLU, o_acti=nn.ReLU, LOG_STD_MAX=2, LOG_STD_MIN=-20, sigmoid = False):
        super().__init__()

        widths = [state_dim, *hid_shape]
        feat_dim = widths[-1]
        if sigmoid:
            # "sigmoid" mode: tanh on the last trunk layer and on the mean head.
            self.a_net = build_net(widths, h_acti, nn.Tanh)
            self.mu_layer = nn.Sequential(nn.Linear(feat_dim, action_dim), nn.Tanh())
        else:
            self.a_net = build_net(widths, h_acti, o_acti)
            self.mu_layer = nn.Linear(feat_dim, action_dim)
        self.log_std_layer = nn.Linear(feat_dim, action_dim)

        self.LOG_STD_MAX = LOG_STD_MAX
        self.LOG_STD_MIN = LOG_STD_MIN

    def _mean_and_std(self, features):
        """Return ``(mean, std)`` of the pre-squash Gaussian for trunk *features*."""
        log_std = self.log_std_layer(features).clamp(self.LOG_STD_MIN, self.LOG_STD_MAX)
        return self.mu_layer(features), log_std.exp()

    def forward(self, state, deterministic=False, with_logprob=True):
        """Return ``(action, log_prob)`` for *state*.

        ``action`` is ``tanh`` of the mean (deterministic) or of a
        reparameterised sample. ``log_prob`` is summed over action dimensions
        (``keepdim=True``), or ``None`` when ``with_logprob`` is False.
        """
        mu, std = self._mean_and_std(self.a_net(state))
        dist = Normal(mu, std)

        pre_tanh = mu if deterministic else dist.rsample()
        action = torch.tanh(pre_tanh)

        if not with_logprob:
            return action, None

        # tanh change-of-variables correction; 1e-6 guards log(0) as |action| -> 1.
        per_dim = dist.log_prob(pre_tanh) - torch.log(1 - action.pow(2) + 1e-6)
        return action, per_dim.sum(dim=1, keepdim=True)

    def get_action_dual(self, state, state_dual):
        """Return a sampled action for *state* and the deterministic action for *state_dual*.

        The sampled action comes from the Gaussian predicted for ``state``. The
        deterministic action is ``tanh`` of the mean predicted for ``state_dual``.
        """
        mu, std = self._mean_and_std(self.a_net(state))
        mu_dual = self.mu_layer(self.a_net(state_dual))

        deterministic_action = torch.tanh(mu_dual)
        sampled_action = torch.tanh(Normal(mu, std).rsample())
        return sampled_action, deterministic_action


class Q_Critic(nn.Module):
    """Two independent Q-networks evaluated on the concatenated (state, action) input."""

    def __init__(self, state_dim, action_dim, hid_shape):
        super().__init__()
        sizes = [state_dim + action_dim, *hid_shape, 1]

        self.Q_1 = build_net(sizes, nn.ReLU, nn.Identity)
        self.Q_2 = build_net(sizes, nn.ReLU, nn.Identity)

    def forward(self, state, action):
        """Return ``(Q_1(s, a), Q_2(s, a))``; each has one output column for 2-D inputs."""
        joint = torch.cat((state, action), dim=1)
        return self.Q_1(joint), self.Q_2(joint)


class SAC(object):
    """Soft Actor-Critic agent with twin Q critics and optional automatic entropy tuning.

    Also provides reward-tracking helpers (``track_performance*``) and a
    heuristic ``adjust_exploration`` that widens or narrows the actor's
    log-std clamp range.
    """

    def __init__(
            self,
            state_dim,
            action_dim,
            gamma=0.99,
            hid_shape=(256, 256),
            a_lr=3e-4,
            c_lr=3e-4,
            batch_size=256,
            alpha=0.2,
            adaptive_alpha=True,
            adam_params = None,
            LOG_STD_MAX=10,
            LOG_STD_MIN=-5,
            sigmoid = False,
            performance_tracking_window = 10,
            optimizer_name = "Adam"
    ):
        if optimizer_name not in ("Adam", "AdamW"):
            # Previously an unknown name silently left the optimizers undefined,
            # so the failure only surfaced later in train().
            raise ValueError(f"unsupported optimizer_name {optimizer_name!r}; expected 'Adam' or 'AdamW'")

        self.LOG_STD_MAX = LOG_STD_MAX
        self.LOG_STD_MIN = LOG_STD_MIN
        # Baseline clamp range that adjust_exploration decays back towards.
        self._base_log_std_max = LOG_STD_MAX
        self._base_log_std_min = LOG_STD_MIN
        # Optimizer configuration, kept so reset_parameters can rebuild identical optimizers.
        self._optimizer_name = optimizer_name
        self._adam_params = adam_params
        self._a_lr = a_lr
        self._c_lr = c_lr

        self.actor = Actor(state_dim, action_dim, hid_shape, LOG_STD_MAX=LOG_STD_MAX, LOG_STD_MIN=LOG_STD_MIN, sigmoid=sigmoid).to(device)
        self.q_critic = Q_Critic(state_dim, action_dim, hid_shape).to(device)
        self.actor_optimizer = self._make_optimizer(self.actor.parameters(), a_lr)
        self.q_critic_optimizer = self._make_optimizer(self.q_critic.parameters(), c_lr)

        # The target critic starts as an exact copy and is only moved by Polyak averaging.
        self.q_critic_target = copy.deepcopy(self.q_critic)
        for p in self.q_critic_target.parameters():
            p.requires_grad = False

        self.action_dim = action_dim
        self.gamma = gamma
        self.tau = 0.005
        self.batch_size = batch_size

        self.alpha = alpha
        self.adaptive_alpha = adaptive_alpha
        if adaptive_alpha:
            self.target_entropy = torch.tensor(-action_dim, dtype=float, requires_grad=True, device=device)
            self.log_alpha = torch.tensor(np.log(alpha), dtype=float, requires_grad=True, device=device)
            self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=c_lr)

        self.performance_tracking_window = performance_tracking_window
        self.recent_rewards = deque(maxlen=self.performance_tracking_window)
        self.longer_rewards = deque(maxlen=10*self.performance_tracking_window)

    def _make_optimizer(self, params, lr):
        """Build the optimizer selected at construction time for *params*."""
        if self._optimizer_name == "AdamW":
            return torch.optim.AdamW(params, lr=lr, weight_decay=0.1)
        if self._adam_params is None:
            return torch.optim.Adam(params, lr=lr)
        # NOTE(review): the contents of adam_params are ignored; any non-None
        # value selects these fixed betas. Confirm this is intended.
        return torch.optim.Adam(params, lr=lr, betas=[0.5, 0.85])

    def track_performance(self, reward):
        """Record *reward*; return the window mean once the window is full, else None."""
        self.recent_rewards.append(reward)
        if len(self.recent_rewards) == self.performance_tracking_window:
            return sum(self.recent_rewards) / self.performance_tracking_window
        return None

    def track_performance_std(self, reward):
        """Record *reward*; return the window std once the window is full, else None.

        Shares ``recent_rewards`` with :meth:`track_performance`, so calling both
        for the same reward appends it twice.
        """
        self.recent_rewards.append(reward)
        if len(self.recent_rewards) == self.performance_tracking_window:
            return np.std(self.recent_rewards)
        return None

    def adjust_exploration(self, avg_reward, exploration_threshold):
        """Widen the actor's log-std range when reward is below threshold, otherwise decay it.

        Below the threshold, both clamp bounds move up by 0.2 and ``alpha``
        grows by 10%. Otherwise, both bounds move down by 0.1 (never below their
        construction-time values) and ``alpha`` shrinks by 10%. With
        ``adaptive_alpha``, the next :meth:`train` call overwrites ``alpha``.
        """
        if avg_reward < exploration_threshold:
            self.LOG_STD_MAX += 0.2
            self.LOG_STD_MIN += 0.2
            self.alpha *= 1.1
        else:
            # Decay back towards the constructor bounds. Both use ``max``, so
            # neither bound drifts below its baseline.
            self.LOG_STD_MAX = max(self.LOG_STD_MAX - 0.1, self._base_log_std_max)
            self.LOG_STD_MIN = max(self.LOG_STD_MIN - 0.1, self._base_log_std_min)
            self.alpha *= 0.9
        self.actor.LOG_STD_MAX = self.LOG_STD_MAX
        self.actor.LOG_STD_MIN = self.LOG_STD_MIN

    def reset_parameters(self):
        """Re-initialise actor and critic weights and restart their optimizers.

        The target critic is re-synchronised to the freshly initialised critic,
        as in ``__init__``, rather than receiving independent random weights.
        Both optimizers are rebuilt with the constructor configuration, so no
        stale moment estimates survive. The long-horizon reward history is
        also cleared.
        """
        self.actor.apply(self.weight_reset)
        self.q_critic.apply(self.weight_reset)
        self.q_critic_target.load_state_dict(self.q_critic.state_dict())
        self.actor_optimizer = self._make_optimizer(self.actor.parameters(), self._a_lr)
        self.q_critic_optimizer = self._make_optimizer(self.q_critic.parameters(), self._c_lr)
        self.longer_rewards = deque(maxlen=10*self.performance_tracking_window)

    @staticmethod
    def weight_reset(layer):
        """Re-initialise *layer* in place if it is an ``nn.Linear``; ignore other modules."""
        if isinstance(layer, nn.Linear):
            layer.reset_parameters()

    def select_action(self, state, deterministic, with_logprob=False):
        """Return the actor's action for a single NumPy *state* as a flat NumPy array."""
        with torch.no_grad():
            state = torch.FloatTensor(state.reshape(1, -1)).to(device)
            a, _ = self.actor(state, deterministic, with_logprob)
        return a.cpu().numpy().flatten()

    def select_action_gpu(self, state, deterministic, with_logprob=False):
        """Return the actor's action for an already-on-device tensor *state*, as a tensor."""
        with torch.no_grad():
            a, _ = self.actor(state, deterministic, with_logprob)
        return a.detach()

    def select_action_dual(self, state, state_dual):
        """Return (sampled action for *state*, deterministic action for *state_dual*) as flat arrays."""
        with torch.no_grad():
            state = torch.FloatTensor(state.reshape(1, -1)).to(device)
            state_dual = torch.FloatTensor(state_dual.reshape(1, -1)).to(device)
            a, a_det = self.actor.get_action_dual(state, state_dual)
        return a.cpu().numpy().flatten(), a_det.cpu().numpy().flatten()

    def train(self, replay_buffer):
        """Run one gradient step for the critic, the actor and (optionally) the temperature.

        ``replay_buffer.sample(self.batch_size)`` must return the tensors
        ``(s, a, r, s_prime, dead_mask)``.
        """
        s, a, r, s_prime, dead_mask = replay_buffer.sample(self.batch_size)

        # ---- Critic: regress both Q heads onto the soft Bellman target ----
        with torch.no_grad():
            a_prime, log_pi_a_prime = self.actor(s_prime)
            target_Q1, target_Q2 = self.q_critic_target(s_prime, a_prime)
            target_Q = torch.min(target_Q1, target_Q2)
            # dead_mask == 1 stops bootstrapping past terminal transitions.
            target_Q = r + (1 - dead_mask) * self.gamma * (
                    target_Q - self.alpha * log_pi_a_prime)

        current_Q1, current_Q2 = self.q_critic(s, a)
        q_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
        self.q_critic_optimizer.zero_grad()
        q_loss.backward()
        self.q_critic_optimizer.step()

        # ---- Actor ----
        # Freeze the critic so the actor loss does not accumulate critic gradients.
        for params in self.q_critic.parameters():
            params.requires_grad = False

        a_new, log_pi_a = self.actor(s)
        current_Q1, current_Q2 = self.q_critic(s, a_new)
        Q = torch.min(current_Q1, current_Q2)

        a_loss = (self.alpha * log_pi_a - Q).mean()
        self.actor_optimizer.zero_grad()
        a_loss.backward()
        self.actor_optimizer.step()

        for params in self.q_critic.parameters():
            params.requires_grad = True

        # ---- Temperature ----
        if self.adaptive_alpha:
            alpha_loss = -(self.log_alpha * (log_pi_a + self.target_entropy).detach()).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            # Detach so alpha does not carry log_alpha's autograd graph into the
            # next step's actor loss, and so in-place updates in
            # adjust_exploration do not trip autograd's in-place check.
            self.alpha = self.log_alpha.exp().detach()

        # ---- Polyak averaging of the target critic ----
        with torch.no_grad():
            for param, target_param in zip(self.q_critic.parameters(), self.q_critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def save(self, fname):
        """Save critic and actor weights to ``<fname>_critic.pth`` / ``<fname>_actor.pth``."""
        torch.save(self.q_critic.state_dict(), fname + '_critic.pth')
        torch.save(self.actor.state_dict(), fname + '_actor.pth')

    def load_model_from_file(self, fname):
        """Load weights written by :meth:`save` and re-sync the target critic.

        NOTE: ``torch.load`` unpickles data; only load files from trusted sources.
        """
        # map_location lets checkpoints saved on another device (e.g. CUDA) load here.
        self.q_critic.load_state_dict(torch.load(fname + '_critic.pth', map_location=device))
        self.actor.load_state_dict(torch.load(fname + '_actor.pth', map_location=device))
        # The target critic is not saved, so start it from the loaded critic
        # rather than keeping whatever weights it held before.
        self.q_critic_target.load_state_dict(self.q_critic.state_dict())

