import os
import numpy as np
import torch
import torch.nn.functional as F
import wandb
import utils
import hydra

from algorithm import Agent


class SACAgent(Agent):
    """Soft Actor-Critic (SAC) agent.

    Holds an actor, a critic that returns two Q estimates (``Q1``/``Q2``), a
    target copy of the critic that is soft-updated through
    ``utils.soft_update_params``, and an entropy temperature ``alpha`` stored
    as ``log_alpha`` that is optionally learned toward ``target_entropy``.
    """

    def __init__(self, obs_dim, action_dim, action_range, device, critic_cfg,
                 actor_cfg, discount, init_temperature, alpha_lr, alpha_betas,
                 actor_lr, actor_betas, actor_update_frequency, critic_lr,
                 critic_betas, critic_tau, critic_target_update_frequency,
                 batch_size, learnable_temperature):
        """Build the networks, the temperature parameter and the optimizers.

        Args:
            obs_dim: Observation dimension (not referenced by this class).
            action_dim: Action dimension; the target entropy is ``-action_dim``.
            action_range: Stored as ``self.action_range`` (not referenced by
                this class).
            device: Anything accepted by ``torch.device``; the networks and
                ``log_alpha`` are moved onto it.
            critic_cfg: Hydra config passed to ``hydra.utils.instantiate`` to
                build both the critic and the target critic.
            actor_cfg: Hydra config passed to ``hydra.utils.instantiate`` to
                build the actor.
            discount: Discount factor used in the critic's bootstrap target.
            init_temperature: Initial alpha; clipped to ``[1e-8, 1e8]`` before
                taking its log.
            alpha_lr, alpha_betas: Adam settings for ``log_alpha``.
            actor_lr, actor_betas: Adam settings for the actor.
            actor_update_frequency: Actor and alpha are updated on steps where
                ``step % actor_update_frequency == 0``.
            critic_lr, critic_betas: Adam settings for the critic.
            critic_tau: Coefficient passed to ``utils.soft_update_params`` for
                the target-critic update.
            critic_target_update_frequency: The target critic is updated on
                steps where ``step % critic_target_update_frequency == 0``.
            batch_size: Number of transitions sampled per ``update`` call.
            learnable_temperature: If true, ``log_alpha`` is optimized in
                ``update_actor_and_alpha``; otherwise alpha stays fixed.
        """
        super().__init__()

        self.action_range = action_range
        self.device = torch.device(device)
        self.discount = discount
        self.critic_tau = critic_tau
        self.actor_update_frequency = actor_update_frequency
        self.critic_target_update_frequency = critic_target_update_frequency
        self.batch_size = batch_size
        self.learnable_temperature = learnable_temperature

        # Networks; the target critic starts as an exact copy of the critic.
        self.critic = hydra.utils.instantiate(critic_cfg).to(self.device)
        self.critic_target = hydra.utils.instantiate(critic_cfg).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.actor = hydra.utils.instantiate(actor_cfg).to(self.device)

        # Temperature is optimized in log-space so alpha = exp(log_alpha) > 0.
        self.log_alpha = torch.tensor(np.log(np.clip(init_temperature, 1e-8, 1e8))).to(self.device)
        self.log_alpha.requires_grad = True

        self.target_entropy = -action_dim  # Target entropy -|A|

        # Optimizers
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr, betas=actor_betas)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr, betas=critic_betas)
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr, betas=alpha_betas)

        # NOTE(review): explicit `self.train()` / `self.critic_target.train()`
        # calls were disabled here; freshly constructed nn.Modules already
        # start in training mode.

    def train(self, training=True):
        """Set training mode on the actor and critic and record it in ``self.training``.

        The target critic is not touched.
        """
        self.training = training
        self.actor.train(training)
        self.critic.train(training)

    @property
    def alpha(self):
        """Current temperature, ``exp(log_alpha)`` (keeps the autograd graph)."""
        return self.log_alpha.exp()

    def act(self, obs, sample=False):
        """Select an action for a single, unbatched observation.

        Args:
            obs: One observation (array-like); a batch dimension is added here.
            sample: If true, use ``actor.get_action``; otherwise use
                ``actor.get_det_action``.

        Returns:
            The first (only) row of the actor's output as a NumPy array.
        """
        obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            if sample:
                action = self.actor.get_action(obs)
            else:
                action = self.actor.get_det_action(obs)
        # .cpu() is required before .numpy() when the agent runs on a GPU.
        return action[0].cpu().numpy()

    def update_critic(self, obs, action, reward, next_obs, not_done, logger, step):
        """Take one gradient step on both critic heads.

        Target (no gradient):
            ``reward + not_done * discount * (min(Q1', Q2') - alpha * log_prob')``
        where ``Q1', Q2'`` come from the target critic at ``next_obs`` and an
        action/log-prob pair from ``actor.evaluate(next_obs)``. The loss is the
        sum of the MSE of each critic head against that target.
        """
        with torch.no_grad():
            next_action, log_prob = self.actor.evaluate(next_obs)
            target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
            target_V = torch.min(target_Q1, target_Q2) - self.alpha * log_prob
            target_Q = reward + (not_done * self.discount * target_V)

        current_Q1, current_Q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        logger.log('train/critic_loss', critic_loss, step)
        wandb.log({'sac/critic_loss': critic_loss.item()}, step=step)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(logger, step)

    def update_actor_and_alpha(self, obs, logger, step):
        """Take one gradient step on the actor and, optionally, on ``log_alpha``.

        Actor loss: ``mean(alpha.detach() * log_prob - min(Q1, Q2))``.
        Alpha loss (only if ``learnable_temperature``):
            ``mean(alpha * detach(-log_prob - target_entropy))``.
        """
        action, log_prob = self.actor.evaluate(obs)
        actor_Q1, actor_Q2 = self.critic(obs, action)
        actor_Q = torch.min(actor_Q1, actor_Q2)

        actor_loss = (self.alpha.detach() * log_prob - actor_Q).mean()
        entropy = -log_prob.mean()  # batch estimate of the policy entropy

        logger.log('train/actor_loss', actor_loss, step)
        logger.log('train/target_entropy', self.target_entropy, step)
        logger.log('train/actor_entropy', entropy, step)
        wandb.log({
            'sac/actor_loss': actor_loss.item(),
            'sac/target_entropy': self.target_entropy,
            'sac/entropy': entropy.item(),
        }, step=step)

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(logger, step)

        if self.learnable_temperature:
            # log_prob is detached so only log_alpha receives gradient here.
            alpha_loss = (self.alpha * (-log_prob - self.target_entropy).detach()).mean()
            logger.log('train/alpha_loss', alpha_loss, step)
            logger.log('train/alpha_value', self.alpha, step)
            wandb.log({
                'sac/alpha_loss': alpha_loss.item(),
                'sac/alpha_value': self.alpha.item(),
            }, step=step)

            self.log_alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.log_alpha_optimizer.step()

    def update(self, replay_buffer, logger, step):
        """Run one SAC training iteration on a batch from ``replay_buffer``.

        The critic is updated every call. The actor/alpha and the target
        critic are updated only on steps divisible by their configured
        frequencies.
        """
        obs, action, reward, cost, next_obs, not_done, not_done_no_max = replay_buffer.sample(self.batch_size)

        logger.log('train/batch_reward', reward.mean(), step)
        wandb.log({'sac/batch_reward': reward.mean().item()}, step=step)

        # NOTE(review): `not_done_no_max` is used as the bootstrap mask;
        # presumably it ignores time-limit terminations — confirm against
        # the replay buffer. `cost` is sampled but unused here.
        self.update_critic(obs, action, reward, next_obs, not_done_no_max, logger, step)

        if step % self.actor_update_frequency == 0:
            self.update_actor_and_alpha(obs, logger, step)

        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.critic, self.critic_target, self.critic_tau)

    def save(self, path):
        """Save the actor trunk and both critic heads into directory ``path``.

        Writes ``actor.pth``, ``critic_q1.pth`` and ``critic_q2.pth``. The
        target critic, ``log_alpha`` and optimizer states are NOT saved.
        """
        os.makedirs(path, exist_ok=True)
        torch.save(self.actor.trunk.state_dict(), os.path.join(path, "actor.pth"))
        torch.save(self.critic.Q1.state_dict(), os.path.join(path, "critic_q1.pth"))
        torch.save(self.critic.Q2.state_dict(), os.path.join(path, "critic_q2.pth"))

    def load(self, path):
        """Load weights previously written by ``save`` from directory ``path``.

        Tensors are mapped onto ``self.device``, so a checkpoint saved on
        another device (e.g. GPU) loads on this one. Afterwards the target
        critic is re-synced to the loaded critic, as in ``__init__``, so that
        resumed training does not bootstrap from a randomly initialised target.
        """
        self.actor.trunk.load_state_dict(
            torch.load(os.path.join(path, "actor.pth"), map_location=self.device))
        self.critic.Q1.load_state_dict(
            torch.load(os.path.join(path, "critic_q1.pth"), map_location=self.device))
        self.critic.Q2.load_state_dict(
            torch.load(os.path.join(path, "critic_q2.pth"), map_location=self.device))
        self.critic_target.load_state_dict(self.critic.state_dict())
