from copy import deepcopy
import numpy as np
import torch
from torch import optim
from torch.distributions import Normal
from modules import SACAgent

def to_zeros(x):
    """Return a float32 tensor of zeros with shape ``x``.

    Args:
        x: Size specification accepted by ``torch.zeros`` (an int or a
            sequence of ints).

    Returns:
        A ``torch.float32`` tensor filled with zeros.
    """
    return torch.zeros(x, dtype=torch.float32)

class ReplayBuffer:
    """Fixed-capacity ring buffer of vectorized-environment transitions.

    Every stored quantity lives in a preallocated float32 tensor whose first
    two axes are ``(buffer_size, n_envs)``.  Once ``buffer_size`` steps have
    been added, the oldest slot is overwritten.
    """

    def __init__(self, buffer_size, n_envs, observation_space, action_space, envs):
        """Allocate storage for observations, actions, rewards and dones.

        Args:
            buffer_size: Number of time steps kept before wrapping around.
            n_envs: Number of parallel environments stored per time step.
            observation_space: Object exposing ``shape`` of one observation.
            action_space: Object exposing ``shape`` of one action.
            envs: Vector environment; ``envs.call('unwrapped')[0]`` must
                provide ``get_ldba()`` and ``get_jump_mask()``.
        """
        base_env = envs.call('unwrapped')[0]
        self._ldba = deepcopy(base_env.get_ldba())
        self.jump_mask = torch.tensor(base_env.get_jump_mask())
        self.n_states = self.jump_mask.shape[0]
        # Capacity attributes must exist before _zeros() is used below.
        self.buffer_size, self.n_envs = buffer_size, n_envs
        self.observations = self._zeros(observation_space.shape)
        self.next_observations = self._zeros(observation_space.shape)
        self.actions = self._zeros(action_space.shape)
        self.rewards, self.dones = self._zeros(), self._zeros()
        # Info storage is created lazily on the first add(), once the keys
        # and value shapes are known.
        self.infos, self.next_infos = None, None
        self.size, self.pos = 0, 0

    def _zeros(self, item_shape=()):
        """Return zeroed float32 storage of shape (buffer_size, n_envs, *item_shape)."""
        return torch.zeros((self.buffer_size, self.n_envs, *item_shape), dtype=torch.float32)

    def add(self, obs, next_obs, action, reward, done, infos, next_infos):
        """Write one time step into the current slot and advance the cursor.

        Args:
            obs, next_obs, action, reward, done: Values for all environments
                at this step; each is copied into its storage tensor.
            infos, next_infos: Dicts of tensors.  The keys of the first
                ``infos`` passed in fix which entries are stored from then
                on; ``next_infos`` must contain the same keys.

        Raises:
            KeyError: If ``infos`` or ``next_infos`` lacks a stored key.
        """
        slot = self.pos
        self.observations[slot] = obs
        self.next_observations[slot] = next_obs
        self.actions[slot] = action
        self.rewards[slot] = reward
        self.dones[slot] = done

        if self.infos is None:
            # NOTE(review): the full shape of each value is appended after the
            # (buffer_size, n_envs) axes.  If the values already carry a
            # leading n_envs axis this adds a redundant axis that is filled
            # by broadcasting -- confirm against the caller.
            self.infos = {k: self._zeros(v.shape) for k, v in infos.items()}
            # Size the next-step storage from the next-step values themselves.
            self.next_infos = {k: self._zeros(next_infos[k].shape) for k in infos}
        for k in self.infos:
            self.infos[k][slot] = infos[k]
            self.next_infos[k][slot] = next_infos[k]

        self.pos = (slot + 1) % self.buffer_size
        self.size = min(self.size + 1, self.buffer_size)

    def sample(self, batch_size):
        """Draw ``batch_size`` random (time step, environment) entries.

        Indices are drawn independently and with replacement from the filled
        part of the buffer.

        Returns:
            Tuple ``(observations, actions, next_observations, dones,
            rewards, infos, next_infos)``; the last two are dicts of tensors.

        Raises:
            RuntimeError: If nothing has been added yet.
        """
        if self.size == 0:
            raise RuntimeError("cannot sample from an empty ReplayBuffer; call add() first")
        step_idx = torch.randint(low=0, high=self.size, size=(batch_size,))
        env_idx = torch.randint(low=0, high=self.n_envs, size=(batch_size,))
        idxs = (step_idx, env_idx)
        infos = {k: v[idxs] for k, v in self.infos.items()}
        next_infos = {k: v[idxs] for k, v in self.next_infos.items()}
        return (
            self.observations[idxs],
            self.actions[idxs],
            self.next_observations[idxs],
            self.dones[idxs],
            self.rewards[idxs],
            infos,
            next_infos,
        )

class SAC:
    """Soft actor-critic learner for observations augmented with an LDBA state.

    The actor output is split into a continuous part, squashed by ``tanh``,
    and ``n_jumps`` logits for a discrete "jump" choice.  Jumps are masked
    per LDBA state with ``jump_mask`` (rows index LDBA states, columns index
    jumps).  An action is the squashed continuous part concatenated with a
    one-hot encoding of the chosen jump.
    """

    def __init__(self, envs, args):
        """Build the agent, its optimizers and the replay buffer.

        Args:
            envs: Vector environment.  ``envs.call('unwrapped')[0]`` must
                provide ``get_jump_mask()`` and ``aps``; observation and
                action spaces are read from ``envs``.
            args: Configuration namespace.  This class reads ``device``,
                ``learning_rate``, ``autotune``, ``buffer_size``,
                ``num_envs``, ``batch_size``, ``gamma``,
                ``policy_frequency``, ``target_network_frequency`` and
                ``tau``.
        """
        self.args = args
        base_env = envs.call('unwrapped')[0]
        self.jump_mask = torch.tensor(base_env.get_jump_mask(), device=args.device)
        self.aps = base_env.aps.copy()
        self.n_jumps = self.jump_mask.shape[-1]
        self.n_states = self.jump_mask.shape[0]
        # The networks see the flattened observation plus a one-hot LDBA state.
        self.agent = SACAgent(
            np.array(envs.unwrapped.single_observation_space.shape).prod() + self.n_states,
            envs.unwrapped.single_action_space,
            args
        ).to(args.device)
        critic_params = list(self.agent.qf1.parameters()) + list(self.agent.qf2.parameters())
        self.q_optimizer = optim.Adam(critic_params, lr=args.learning_rate)
        self.actor_optimizer = optim.Adam(list(self.agent.actor.parameters()), lr=args.learning_rate)
        if self.args.autotune:
            self.alpha_optimizer = optim.Adam([self.agent.log_alpha], lr=args.learning_rate)
        self.rb = ReplayBuffer(args.buffer_size, args.num_envs, envs.single_observation_space, envs.single_action_space, envs)
        self.update_frequency = 1
        self.train_metrics = {}

    def anneal(self, update, num_updates):
        """No-op; this learner has nothing to anneal."""
        pass  # nothing to anneal!

    def store(self, obs, info, done, action, reward, logprob, value, next_obs, next_info, next_done, step):
        """Add one transition to the replay buffer.

        Only ``obs``, ``next_obs``, ``action``, ``reward``, ``next_done``,
        ``info`` and ``next_info`` are stored; the other arguments are
        accepted and ignored.
        """
        self.rb.add(obs, next_obs, action, reward, next_done, info, next_info)

    # ------------------------------------------------------------------ helpers

    def _one_hot_state(self, ldba_idx):
        """One-hot encode LDBA state indices over ``n_states`` classes."""
        return torch.eye(self.n_states, device=self.args.device)[ldba_idx.long()]

    def _one_hot_jump(self, jump_idx):
        """One-hot encode jump indices over ``n_jumps`` classes (float)."""
        return torch.eye(self.n_jumps, device=self.args.device)[jump_idx.long()].float()

    def _jump_mask_rows(self, info):
        """Return the jump-mask row for every LDBA state in ``info``.

        The indices are flattened with ``reshape(-1)`` so that a single entry
        still yields a ``(1, n_jumps)`` mask instead of collapsing to a
        scalar index.
        """
        return self.jump_mask[info['ldba_obs'].long().reshape(-1)].float()

    def _infos_to_device(self, info):
        """Return a shallow copy of ``info`` with the tensors used in training on ``args.device``.

        A copy is made so the caller's dict is not modified.
        """
        moved = dict(info)
        for key in ('ldba_obs', 'accepting'):
            if key in moved:
                moved[key] = moved[key].to(self.args.device)
        return moved

    @staticmethod
    def _mean_abs_grad(module):
        """Mean over parameters of the mean absolute gradient (a diagnostic)."""
        return np.mean([p.grad.abs().mean().item() for p in module.parameters()])

    # ------------------------------------------------------------------ acting

    def get_action_and_value(self, x, action=None, info=None):
        """Sample an action for raw observations ``x``.

        Args:
            x: Batch of observations without the LDBA one-hot.
            action: Ignored.
            info: Dict with key ``'ldba_obs'`` holding the LDBA state index
                of each row of ``x``.

        Returns:
            ``(action, logprob, entropy, value)`` where ``entropy`` and
            ``value`` are all-zero placeholders.

        Raises:
            KeyError: If ``info`` is omitted or lacks ``'ldba_obs'``.
        """
        info = {} if info is None else info
        x = torch.cat([x, self._one_hot_state(info['ldba_obs'])], -1)
        action, logprob, _ = self.get_action(x, info)
        dummy_entropy, dummy_value = action.sum(-1) * 0, action.sum(-1, keepdim=True) * 0
        return action, logprob, dummy_entropy, dummy_value

    def eval_action(self, x, info=None):
        """Return the deterministic action for raw observations ``x``.

        The continuous part is ``tanh`` of the actor mean and the jump is the
        allowed jump with the highest probability.

        Raises:
            KeyError: If ``info`` is omitted or lacks ``'ldba_obs'``.
        """
        info = {} if info is None else info
        x = torch.cat([x, self._one_hot_state(info['ldba_obs'])], -1)
        mean, _ = self.agent.actor(x)
        cutoff = mean.shape[-1] - self.n_jumps
        mean, jump_logit = mean[:, :cutoff], mean[:, cutoff:]
        jump_probs = torch.nn.functional.softmax(jump_logit, dim=-1) * self._jump_mask_rows(info)
        jump = self._one_hot_jump(jump_probs.argmax(-1))
        return torch.cat([torch.tanh(mean), jump], dim=-1)

    def get_action(self, x, info):
        """Sample an action for LDBA-augmented observations ``x``.

        Args:
            x: Batch of observations already concatenated with the LDBA
                one-hot.
            info: Dict with key ``'ldba_obs'`` (LDBA state index per row).

        Returns:
            ``(action, log_prob, mean)``: the sampled action, its log
            probability with shape ``(batch, 1)``, and the deterministic
            action (``tanh`` of the mean plus the most likely jump).
        """
        mean, log_std = self.agent.actor(x)
        cutoff = mean.shape[-1] - self.n_jumps
        mean, jump_logit = mean[:, :cutoff], mean[:, cutoff:]
        std = log_std[:, :cutoff].exp()

        # Continuous part: reparameterized Gaussian sample squashed by tanh,
        # with the change-of-variables correction on the log density.
        action_probs = Normal(mean, std)
        x_t = action_probs.rsample()
        y_t = torch.tanh(x_t)
        act_log_prob = action_probs.log_prob(x_t) - torch.log((1 - y_t.pow(2)) + 1e-6)
        act_log_prob = act_log_prob.sum(1, keepdim=True)

        # Discrete part: mask disallowed jumps, renormalize, then mix in a
        # tiny uniform component so no probability is exactly zero.
        jump_probs = torch.nn.functional.softmax(jump_logit, dim=-1) * self._jump_mask_rows(info)
        jump_probs = jump_probs / jump_probs.sum(-1, keepdim=True)
        jump_probs = jump_probs * (1 - 1e-6) + (1e-6 / self.n_jumps)
        jump = torch.multinomial(jump_probs, 1, replacement=True)[:, 0]
        # gather picks each row's own probability for any n_jumps, including 1.
        jump_log_probs = jump_probs.gather(1, jump.unsqueeze(1)).clamp(min=1e-6).log()

        action = torch.cat([y_t, self._one_hot_jump(jump)], dim=-1)
        mean = torch.cat([torch.tanh(mean), self._one_hot_jump(jump_probs.argmax(-1))], dim=-1)
        log_prob = act_log_prob + jump_log_probs
        return action, log_prob, mean

    # ------------------------------------------------------------------ training

    def _td_target(self, next_obs, rewards, infos, next_infos):
        """Compute the flat ``(batch,)`` critic regression target.

        The discount is ``gamma ** infos['accepting']`` when that key is
        present, otherwise ``gamma`` for positive rewards and 1 elsewhere.
        """
        with torch.no_grad():
            next_state_actions, next_state_log_pi, _ = self.get_action(next_obs, next_infos)
            qf1_next_target = self.agent.qf1_target(next_obs, next_state_actions)
            qf2_next_target = self.agent.qf2_target(next_obs, next_state_actions)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - (self.agent.alpha * next_state_log_pi)
            flat_rewards = rewards.flatten()
            if 'accepting' in infos:
                # Flatten so a (batch, 1) flag cannot broadcast the target to
                # (batch, batch).
                exponent = infos['accepting'].reshape(-1)
            else:
                exponent = (flat_rewards > 0).float()
            discount = self.args.gamma ** exponent
            # could potentially supervise the value for sink LDBA states
            return flat_rewards + discount * min_qf_next_target.view(-1)

    def _update_critic(self, obs, acts, next_q_value):
        """Take one optimizer step on both Q networks and return metrics."""
        qf1_a_values = self.agent.qf1(obs, acts).view(-1)
        qf2_a_values = self.agent.qf2(obs, acts).view(-1)
        qf1_loss = torch.nn.functional.mse_loss(qf1_a_values, next_q_value)
        qf2_loss = torch.nn.functional.mse_loss(qf2_a_values, next_q_value)
        qf_loss = qf1_loss + qf2_loss
        self.q_optimizer.zero_grad()
        qf_loss.backward()
        critic_grad = self._mean_abs_grad(self.agent.qf1)
        self.q_optimizer.step()
        return {
            "losses/qf1_values": qf1_a_values.mean().item(),
            "losses/qf2_values": qf2_a_values.mean().item(),
            "losses/qf1_loss": qf1_loss.item(),
            "losses/qf2_loss": qf2_loss.item(),
            "losses/qf_loss": qf_loss.item() / 2.0,
            "losses/critic_grad": critic_grad,
            "losses/alpha": self.agent.alpha
        }

    def _update_actor(self, obs, infos):
        """Run ``policy_frequency`` actor (and optional alpha) updates.

        Returns the actor metrics of the last iteration.
        """
        metrics = {}
        for _ in range(self.args.policy_frequency):
            pi, log_pi, _ = self.get_action(obs, infos)
            qf1_pi = self.agent.qf1(obs, pi)
            qf2_pi = self.agent.qf2(obs, pi)
            min_qf_pi = torch.min(qf1_pi, qf2_pi)
            actor_loss = ((self.agent.alpha * log_pi) - min_qf_pi).mean()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            actor_grad = self._mean_abs_grad(self.agent.actor)
            self.actor_optimizer.step()
            metrics.update({
                "losses/actor_loss": actor_loss.item(),
                "losses/log_pi": log_pi.mean().item(),
                "losses/actor_grad": actor_grad,
            })

            if self.args.autotune:
                # Re-sample from the just-updated policy for the alpha loss.
                with torch.no_grad():
                    _, log_pi, _ = self.get_action(obs, infos)
                alpha_loss = (-self.agent.log_alpha.exp() * (log_pi + self.agent.target_entropy)).mean()
                self.alpha_optimizer.zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer.step()
                self.agent.alpha = self.agent.log_alpha.exp().item()
        return metrics

    def _soft_update_targets(self):
        """Blend each target Q network toward its online network with weight ``tau``."""
        tau = self.args.tau
        pairs = (
            (self.agent.qf1, self.agent.qf1_target),
            (self.agent.qf2, self.agent.qf2_target),
        )
        for online, target in pairs:
            for param, target_param in zip(online.parameters(), target.parameters()):
                target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    def train(self, step, new_q_value, batch=None):
        """Run one training step and return a dict of scalar metrics.

        Args:
            step: Global step counter; gates the actor update
                (``policy_frequency``) and the target update
                (``target_network_frequency``).
            new_q_value: Unused by this class.
            batch: Optional tuple in the order returned by
                ``ReplayBuffer.sample``; sampled from ``self.rb`` when
                omitted.
        """
        if batch is None:
            batch = self.rb.sample(self.args.batch_size)
        device = self.args.device

        # dones is unpacked but unused: the target always bootstraps from the
        # next state.
        obs, acts, next_obs, _dones, rewards, infos, next_infos = batch
        obs, acts, next_obs, rewards = (t.to(device) for t in (obs, acts, next_obs, rewards))
        infos = self._infos_to_device(infos)
        next_infos = self._infos_to_device(next_infos)

        obs = torch.cat([obs, self._one_hot_state(infos['ldba_obs'].reshape(-1))], -1)
        next_obs = torch.cat([next_obs, self._one_hot_state(next_infos['ldba_obs'].reshape(-1))], -1)

        next_q_value = self._td_target(next_obs, rewards, infos, next_infos)
        metrics = self._update_critic(obs, acts, next_q_value)

        if step % self.args.policy_frequency == 0:
            metrics.update(self._update_actor(obs, infos))

        if step % self.args.target_network_frequency == 0:
            self._soft_update_targets()

        return metrics

    def _sample_batch(self):
        """Sample ``args.batch_size`` transitions from the replay buffer."""
        return self.rb.sample(self.args.batch_size)

    def save(self, working_dir):
        """Save the agent's ``state_dict`` to ``<working_dir>/agent.pth``."""
        torch.save(self.agent.state_dict(), f'{working_dir}/agent.pth')


class CycleSAC:
    """SAC variant that overrides the critic target for selected LDBA states.

    It exposes the same interface as ``SAC``.  The only difference is in
    ``train``: for transitions whose augmented observation marks a specific
    LDBA state, the regression target is replaced by the caller-supplied
    ``new_q_value``.

    The actor output is split into a continuous part, squashed by ``tanh``,
    and ``n_jumps`` logits for a discrete "jump" choice.  Jumps are masked
    per LDBA state with ``jump_mask`` (rows index LDBA states, columns index
    jumps).
    """

    def __init__(self, envs, args):
        """Build the agent, its optimizers and the replay buffer.

        Args:
            envs: Vector environment.  ``envs.call('unwrapped')[0]`` must
                provide ``get_jump_mask()`` and ``aps``; observation and
                action spaces are read from ``envs``.
            args: Configuration namespace.  This class reads ``device``,
                ``learning_rate``, ``autotune``, ``buffer_size``,
                ``num_envs``, ``batch_size``, ``gamma``,
                ``policy_frequency``, ``target_network_frequency`` and
                ``tau``.
        """
        self.args = args
        base_env = envs.call('unwrapped')[0]
        self.jump_mask = torch.tensor(base_env.get_jump_mask(), device=args.device)
        self.aps = base_env.aps.copy()
        self.n_jumps = self.jump_mask.shape[-1]
        self.n_states = self.jump_mask.shape[0]

        # The networks see the flattened observation plus a one-hot LDBA state.
        self.agent = SACAgent(
            np.array(envs.unwrapped.single_observation_space.shape).prod() + self.n_states,
            envs.unwrapped.single_action_space,
            args
        ).to(args.device)
        critic_params = list(self.agent.qf1.parameters()) + list(self.agent.qf2.parameters())
        self.q_optimizer = optim.Adam(critic_params, lr=args.learning_rate)
        self.actor_optimizer = optim.Adam(list(self.agent.actor.parameters()), lr=args.learning_rate)
        if self.args.autotune:
            self.alpha_optimizer = optim.Adam([self.agent.log_alpha], lr=args.learning_rate)
        self.rb = ReplayBuffer(args.buffer_size, args.num_envs, envs.single_observation_space, envs.single_action_space, envs)
        self.update_frequency = 1
        self.train_metrics = {}

    def anneal(self, update, num_updates):
        """No-op; this learner has nothing to anneal."""
        pass  # nothing to anneal!

    def store(self, obs, info, done, action, reward, logprob, value, next_obs, next_info, next_done, step):
        """Add one transition to the replay buffer.

        Only ``obs``, ``next_obs``, ``action``, ``reward``, ``next_done``,
        ``info`` and ``next_info`` are stored; the other arguments are
        accepted and ignored.
        """
        self.rb.add(obs, next_obs, action, reward, next_done, info, next_info)

    # ------------------------------------------------------------------ helpers

    def _one_hot_state(self, ldba_idx):
        """One-hot encode LDBA state indices over ``n_states`` classes."""
        return torch.eye(self.n_states, device=self.args.device)[ldba_idx.long()]

    def _one_hot_jump(self, jump_idx):
        """One-hot encode jump indices over ``n_jumps`` classes (float)."""
        return torch.eye(self.n_jumps, device=self.args.device)[jump_idx.long()].float()

    def _jump_mask_rows(self, info):
        """Return the jump-mask row for every LDBA state in ``info``.

        The indices are flattened with ``reshape(-1)`` so that a single entry
        still yields a ``(1, n_jumps)`` mask instead of collapsing to a
        scalar index.
        """
        return self.jump_mask[info['ldba_obs'].long().reshape(-1)].float()

    def _infos_to_device(self, info):
        """Return a shallow copy of ``info`` with the tensors used in training on ``args.device``.

        A copy is made so the caller's dict is not modified.
        """
        moved = dict(info)
        for key in ('ldba_obs', 'accepting'):
            if key in moved:
                moved[key] = moved[key].to(self.args.device)
        return moved

    @staticmethod
    def _mean_abs_grad(module):
        """Mean over parameters of the mean absolute gradient (a diagnostic)."""
        return np.mean([p.grad.abs().mean().item() for p in module.parameters()])

    # ------------------------------------------------------------------ acting

    def get_action_and_value(self, x, action=None, info=None):
        """Sample an action for raw observations ``x``.

        Args:
            x: Batch of observations without the LDBA one-hot.
            action: Ignored.
            info: Dict with key ``'ldba_obs'`` holding the LDBA state index
                of each row of ``x``.

        Returns:
            ``(action, logprob, entropy, value)`` where ``entropy`` and
            ``value`` are all-zero placeholders.

        Raises:
            KeyError: If ``info`` is omitted or lacks ``'ldba_obs'``.
        """
        info = {} if info is None else info
        x = torch.cat([x, self._one_hot_state(info['ldba_obs'])], -1)
        action, logprob, _ = self.get_action(x, info)
        dummy_entropy, dummy_value = action.sum(-1) * 0, action.sum(-1, keepdim=True) * 0
        return action, logprob, dummy_entropy, dummy_value

    def eval_action(self, x, info=None):
        """Return the deterministic action for raw observations ``x``.

        The continuous part is ``tanh`` of the actor mean and the jump is the
        allowed jump with the highest probability.

        Raises:
            KeyError: If ``info`` is omitted or lacks ``'ldba_obs'``.
        """
        info = {} if info is None else info
        x = torch.cat([x, self._one_hot_state(info['ldba_obs'])], -1)
        mean, _ = self.agent.actor(x)
        cutoff = mean.shape[-1] - self.n_jumps
        mean, jump_logit = mean[:, :cutoff], mean[:, cutoff:]
        jump_probs = torch.nn.functional.softmax(jump_logit, dim=-1) * self._jump_mask_rows(info)
        jump = self._one_hot_jump(jump_probs.argmax(-1))
        return torch.cat([torch.tanh(mean), jump], dim=-1)

    def get_action(self, x, info):
        """Sample an action for LDBA-augmented observations ``x``.

        Args:
            x: Batch of observations already concatenated with the LDBA
                one-hot.
            info: Dict with key ``'ldba_obs'`` (LDBA state index per row).

        Returns:
            ``(action, log_prob, mean)``: the sampled action, its log
            probability with shape ``(batch, 1)``, and the deterministic
            action (``tanh`` of the mean plus the most likely jump).
        """
        mean, log_std = self.agent.actor(x)
        cutoff = mean.shape[-1] - self.n_jumps
        mean, jump_logit = mean[:, :cutoff], mean[:, cutoff:]
        std = log_std[:, :cutoff].exp()

        # Continuous part: reparameterized Gaussian sample squashed by tanh,
        # with the change-of-variables correction on the log density.
        action_probs = Normal(mean, std)
        x_t = action_probs.rsample()
        y_t = torch.tanh(x_t)
        act_log_prob = action_probs.log_prob(x_t) - torch.log((1 - y_t.pow(2)) + 1e-6)
        act_log_prob = act_log_prob.sum(1, keepdim=True)

        # Discrete part: mask disallowed jumps, renormalize, then mix in a
        # tiny uniform component so no probability is exactly zero.
        jump_probs = torch.nn.functional.softmax(jump_logit, dim=-1) * self._jump_mask_rows(info)
        jump_probs = jump_probs / jump_probs.sum(-1, keepdim=True)
        jump_probs = jump_probs * (1 - 1e-6) + (1e-6 / self.n_jumps)
        jump = torch.multinomial(jump_probs, 1, replacement=True)[:, 0]
        # gather picks each row's own probability for any n_jumps, including 1.
        jump_log_probs = jump_probs.gather(1, jump.unsqueeze(1)).clamp(min=1e-6).log()

        action = torch.cat([y_t, self._one_hot_jump(jump)], dim=-1)
        mean = torch.cat([torch.tanh(mean), self._one_hot_jump(jump_probs.argmax(-1))], dim=-1)
        log_prob = act_log_prob + jump_log_probs
        return action, log_prob, mean

    # ------------------------------------------------------------------ training

    def _td_target(self, obs, next_obs, rewards, infos, next_infos, new_q_value):
        """Compute the flat ``(batch,)`` critic target, with the cycle override.

        The discount is ``gamma ** infos['accepting']`` when that key is
        present, otherwise ``gamma`` for positive rewards and 1 elsewhere.
        Rows of ``obs`` (the LDBA-augmented observation) whose marker column
        equals 1 get their target replaced by ``new_q_value``.
        """
        with torch.no_grad():
            next_state_actions, next_state_log_pi, _ = self.get_action(next_obs, next_infos)
            qf1_next_target = self.agent.qf1_target(next_obs, next_state_actions)
            qf2_next_target = self.agent.qf2_target(next_obs, next_state_actions)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - (self.agent.alpha * next_state_log_pi)
            flat_rewards = rewards.flatten()
            if 'accepting' in infos:
                # Flatten so a (batch, 1) flag cannot broadcast the target to
                # (batch, batch).
                exponent = infos['accepting'].reshape(-1)
            else:
                exponent = (flat_rewards > 0).float()
            discount = self.args.gamma ** exponent
            next_q_value = flat_rewards + discount * min_qf_next_target.view(-1)

            # NOTE(review): the marker columns are hard-coded per environment
            # and index the augmented observation -- confirm they point at the
            # intended LDBA state before using this with a new environment.
            if self.n_states == 6:
                # for flat cycle
                marker = obs[:, 5]
            else:
                # for cheetah flip and carlo
                marker = obs[:, -2]
            after_indices = (marker == 1).nonzero(as_tuple=True)[0]

            next_q_value[after_indices] = torch.tensor(new_q_value, dtype=next_q_value.dtype, device=self.args.device)
            # could potentially supervise the value for sink LDBA states
            return next_q_value

    def _update_critic(self, obs, acts, next_q_value):
        """Take one optimizer step on both Q networks and return metrics."""
        qf1_a_values = self.agent.qf1(obs, acts).view(-1)
        qf2_a_values = self.agent.qf2(obs, acts).view(-1)
        qf1_loss = torch.nn.functional.mse_loss(qf1_a_values, next_q_value)
        qf2_loss = torch.nn.functional.mse_loss(qf2_a_values, next_q_value)
        qf_loss = qf1_loss + qf2_loss
        self.q_optimizer.zero_grad()
        qf_loss.backward()
        critic_grad = self._mean_abs_grad(self.agent.qf1)
        self.q_optimizer.step()
        return {
            "losses/qf1_values": qf1_a_values.mean().item(),
            "losses/qf2_values": qf2_a_values.mean().item(),
            "losses/qf1_loss": qf1_loss.item(),
            "losses/qf2_loss": qf2_loss.item(),
            "losses/qf_loss": qf_loss.item() / 2.0,
            "losses/critic_grad": critic_grad,
            "losses/alpha": self.agent.alpha
        }

    def _update_actor(self, obs, infos):
        """Run ``policy_frequency`` actor (and optional alpha) updates.

        Returns the actor metrics of the last iteration.
        """
        metrics = {}
        for _ in range(self.args.policy_frequency):
            pi, log_pi, _ = self.get_action(obs, infos)
            qf1_pi = self.agent.qf1(obs, pi)
            qf2_pi = self.agent.qf2(obs, pi)
            min_qf_pi = torch.min(qf1_pi, qf2_pi)
            actor_loss = ((self.agent.alpha * log_pi) - min_qf_pi).mean()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            actor_grad = self._mean_abs_grad(self.agent.actor)
            self.actor_optimizer.step()
            metrics.update({
                "losses/actor_loss": actor_loss.item(),
                "losses/log_pi": log_pi.mean().item(),
                "losses/actor_grad": actor_grad,
            })

            if self.args.autotune:
                # Re-sample from the just-updated policy for the alpha loss.
                with torch.no_grad():
                    _, log_pi, _ = self.get_action(obs, infos)
                alpha_loss = (-self.agent.log_alpha.exp() * (log_pi + self.agent.target_entropy)).mean()
                self.alpha_optimizer.zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer.step()
                self.agent.alpha = self.agent.log_alpha.exp().item()
        return metrics

    def _soft_update_targets(self):
        """Blend each target Q network toward its online network with weight ``tau``."""
        tau = self.args.tau
        pairs = (
            (self.agent.qf1, self.agent.qf1_target),
            (self.agent.qf2, self.agent.qf2_target),
        )
        for online, target in pairs:
            for param, target_param in zip(online.parameters(), target.parameters()):
                target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    def train(self, step, new_q_value, batch=None):
        """Run one training step and return a dict of scalar metrics.

        Args:
            step: Global step counter; gates the actor update
                (``policy_frequency``) and the target update
                (``target_network_frequency``).
            new_q_value: Value written into the critic target for the
                transitions selected by the marker column (see
                ``_td_target``).
            batch: Optional tuple in the order returned by
                ``ReplayBuffer.sample``; sampled from ``self.rb`` when
                omitted.
        """
        if batch is None:
            batch = self.rb.sample(self.args.batch_size)
        device = self.args.device

        # dones is unpacked but unused: the target always bootstraps from the
        # next state.
        obs, acts, next_obs, _dones, rewards, infos, next_infos = batch
        obs, acts, next_obs, rewards = (t.to(device) for t in (obs, acts, next_obs, rewards))
        infos = self._infos_to_device(infos)
        next_infos = self._infos_to_device(next_infos)

        obs = torch.cat([obs, self._one_hot_state(infos['ldba_obs'].reshape(-1))], -1)
        next_obs = torch.cat([next_obs, self._one_hot_state(next_infos['ldba_obs'].reshape(-1))], -1)

        next_q_value = self._td_target(obs, next_obs, rewards, infos, next_infos, new_q_value)
        metrics = self._update_critic(obs, acts, next_q_value)

        if step % self.args.policy_frequency == 0:
            metrics.update(self._update_actor(obs, infos))

        if step % self.args.target_network_frequency == 0:
            self._soft_update_targets()

        return metrics

    def _sample_batch(self):
        """Sample ``args.batch_size`` transitions from the replay buffer."""
        return self.rb.sample(self.args.batch_size)

    def save(self, working_dir):
        """Save the agent's ``state_dict`` to ``<working_dir>/agent.pth``."""
        torch.save(self.agent.state_dict(), f'{working_dir}/agent.pth')
