import torch
import torch.nn.functional as F
from torch.distributions import Normal
from itertools import chain
import numpy as np
import wandb
import os
import hydra

import utils
from algorithm.sac import SACAgent


class GMMSSACAgent(SACAgent):
    def __init__(self, obs_dim, action_dim, action_range, device, critic_cfg, safety_critic_cfg, actor_cfg, discount,
                 init_temperature, alpha_lr, alpha_betas, actor_lr, actor_betas, actor_update_frequency, critic_lr,
                 critic_betas, critic_tau, critic_target_update_frequency, batch_size, learnable_temperature,
                 cost_limit, max_episode_len, risk_level, damp_scale, lr_scale):
        """Initialise the base SAC agent, then add the safety critic and cost multiplier.

        All arguments except the ones listed below are forwarded unchanged to
        ``SACAgent.__init__``.

        Args:
            safety_critic_cfg: hydra config instantiated twice, once for the
                online safety critic and once for its target copy.
            cost_limit: cost budget used to derive ``self.target_cost``.
            max_episode_len: episode horizon ``T`` used in ``self.target_cost``.
            risk_level: stored on ``self.risk_level``.
            damp_scale: stored on ``self.damp_scale`` (scales the damping term
                in the actor loss).
            lr_scale: multiplies ``alpha_lr`` to give the learning rate of the
                ``log_beta`` optimizer.
        """
        super().__init__(obs_dim, action_dim, action_range, device, critic_cfg, actor_cfg, discount, init_temperature,
                         alpha_lr, alpha_betas, actor_lr, actor_betas, actor_update_frequency, critic_lr, critic_betas,
                         critic_tau, critic_target_update_frequency, batch_size, learnable_temperature)

        # Constraint-related hyper-parameters.
        self.cost_limit = cost_limit
        self.max_episode_len = max_episode_len
        self.risk_level = risk_level
        self.damp_scale = damp_scale
        self.cost_lr_scale = lr_scale

        # Online safety critic plus a target copy that starts from the same weights.
        self.safety_critic = hydra.utils.instantiate(safety_critic_cfg).to(self.device)
        self.safety_critic_target = hydra.utils.instantiate(safety_critic_cfg).to(self.device)
        self.safety_critic_target.load_state_dict(self.safety_critic.state_dict())

        # Cost multiplier, kept in log space; starts at the (clipped) initial temperature.
        initial_beta = np.clip(init_temperature, 1e-8, 1e8)
        self.log_beta = torch.tensor(np.log(initial_beta), device=self.device).requires_grad_(True)

        # cost_limit * (1 - discount**T) / (1 - discount) / T
        gamma, horizon = self.discount, self.max_episode_len
        self.target_cost = self.cost_limit * (1 - gamma ** horizon) / (1 - gamma) / horizon

        # A single optimizer steps the reward critic and the safety critic together.
        critic_params = chain(self.critic.parameters(), self.safety_critic.parameters())
        self.all_critics_optimizer = torch.optim.Adam(critic_params, lr=critic_lr, betas=critic_betas)
        self.log_beta_optimizer = torch.optim.Adam(
            [self.log_beta],
            lr=alpha_lr * self.cost_lr_scale,
            betas=alpha_betas,
        )

        self.train()
        self.critic_target.train()
        self.safety_critic_target.train()

    def train(self, training=True):
        """Put the agent's networks, including the safety critic, in train or eval mode.

        Args:
            training: forwarded to the base class ``train`` and to
                ``self.safety_critic.train``.
        """
        # Base class first, then the network this subclass adds.
        super().train(training)
        self.safety_critic.train(training)

    @property
    def beta(self):
        return self.log_beta.exp()

    def sample_gmm(self, means, variances, weights, num_samples=500):
        batch_size, num_components = weights.shape
        num_samples_component = (weights * num_samples).long()  # (batch_size, n_components)
        num_samples_component[:, -1] += num_samples - num_samples_component.sum(dim=1)
        current_samples = torch.zeros(batch_size, num_samples, device=means.device)  # (batch_size, num_samples)
        for k in range(num_components):
            dist = Normal(means[:, k], torch.sqrt(torch.clamp(variances[:, k], min=1e-6)))  # (batch_size,)
            samples_k = dist.sample(
                (num_samples_component[:, k].max().item(),))  # (max_samples_per_component, batch_size)

            for b in range(batch_size):
                num_k = num_samples_component[b, k].item()
                start_idx = current_samples[b].nonzero().size(0)
                current_samples[b, start_idx:start_idx + num_k] = samples_k[:num_k, b]

        samples = current_samples[
            torch.arange(batch_size, device=means.device).unsqueeze(1), torch.randperm(num_samples,
                                                                                       device=means.device).repeat(
                batch_size, 1)]
        return samples

    def bellman_transform(self, next_samples, cost, not_done):
        return cost + self.discount * next_samples * not_done

    def incremental_refinement(self, current_samples, bellman_samples, beta=0.6):
        batch_size = current_samples.size(0)
        num_samples = current_samples.size(1)

        # Calculate number of Bellman and current samples to select based on beta
        num_bellman_samples = int(beta * num_samples)  # Number of Bellman samples
        num_current_samples = num_samples - num_bellman_samples  # Remaining for current samples

        # Sample indices for current samples and Bellman samples
        current_indices = torch.randperm(num_samples, device=current_samples.device)[:num_current_samples]
        bellman_indices = torch.randperm(num_samples, device=current_samples.device)[:num_bellman_samples]

        # Sample from current samples and Bellman samples based on the indices
        current_samples_selected = current_samples[:, current_indices]
        bellman_samples_selected = bellman_samples[:, bellman_indices]

        # Combine the selected samples to form the updated samples
        updated_samples = torch.cat([current_samples_selected, bellman_samples_selected], dim=1)

        return updated_samples

    def em_update(self, current_means, current_variances, current_weights, target_samples, prob_clip_value=1e-8):
        batch_size, gaussian_num = current_means.shape
        target_sample_num = target_samples.shape[1]
        current_stds = torch.sqrt(current_variances)

        # E
        current_means = current_means.unsqueeze(-1)  # (batch_size, gaussian_num, 1)
        current_variances = current_variances.unsqueeze(-1)  # (batch_size, gaussian_num, 1)
        current_stds = current_stds.unsqueeze(-1)  # (batch_size, gaussian_num, 1)
        current_weights = current_weights.unsqueeze(-1)  # (batch_size, gaussian_num, 1)
        target_samples = target_samples.unsqueeze(1)  # (batch_size, 1, target_sample_num)

        pi = torch.tensor(np.pi, device=current_means.device)
        prob_list = current_weights * (1 / (torch.sqrt(2 * pi) * current_stds)) * \
                    torch.exp(-(target_samples - current_means) ** 2 / (
                            2 * current_variances))  # (batch_size, gaussian_num, target_sample_num)

        prob_list = torch.clamp(prob_list, min=prob_clip_value, max=3.0)

        total_weight = prob_list.sum(dim=1, keepdim=True)  # (batch_size, 1, target_sample_num)
        weight_list = prob_list / total_weight  # (batch_size, gaussian_num, target_sample_num)

        random_coef_list = torch.rand((gaussian_num, batch_size, target_sample_num),
                                      device=current_means.device)  # (gaussian_num, batch_size, target_sample_num)
        random_coef_list = random_coef_list.permute(1, 0, 2)  # (batch_size, gaussian_num, target_sample_num)
        weight_list = (weight_list > random_coef_list).float()  # (batch_size, gaussian_num, target_sample_num)

        # M
        weight_sum_list = torch.clamp(weight_list.sum(dim=2), min=1,
                                      max=target_sample_num)  # (batch_size, gaussian_num)
        updated_means = (weight_list * target_samples).sum(dim=2) / weight_sum_list  # (batch_size, gaussian_num)

        updated_variances = (weight_list * (target_samples - updated_means.unsqueeze(2)) ** 2).sum(
            dim=2) / weight_sum_list  # (batch_size, gaussian_num)
        # updated_stds = torch.sqrt(updated_variances)  # (batch_size, gaussian_num)

        updated_weights = weight_sum_list / target_sample_num  # (batch_size, gaussian_num)

        sorted_indices = torch.argsort(updated_means, dim=1)  # (batch_size, gaussian_num)
        batch_indices = torch.arange(batch_size, device=current_means.device).unsqueeze(-1).expand(-1,
                                                                                                   gaussian_num)  # (batch_size, gaussian_num)

        updated_means = updated_means[batch_indices, sorted_indices]  # (batch_size, gaussian_num)
        updated_variances = updated_variances[batch_indices, sorted_indices]  # (batch_size, gaussian_num)
        updated_weights = updated_weights[batch_indices, sorted_indices]  # (batch_size, gaussian_num)

        return updated_means, updated_variances, updated_weights

    def calculate_cvar(self, mu, sigma, risk_level):
        """Calculate the CVaR for a given Gaussian component."""
        normal_dist = Normal(mu, sigma)
        norm_inv_risk_level = normal_dist.icdf(torch.tensor(risk_level, device=mu.device))
        cvar = mu + sigma * (normal_dist.log_prob(norm_inv_risk_level).exp() / risk_level)
        return cvar

    def calculate_scvar(self, means, variances, weights, risk_level=0.95):
        """
        Calculate SCVaR from GMM parameters (means, variances, and weights).
        """
        variances = torch.clamp(variances, min=1e-8, max=1e8).to(means.device)
        cvars = torch.stack([self.calculate_cvar(mu, sigma, risk_level) for mu, sigma in zip(means, variances)])
        scvar = torch.max(cvars)
        return scvar

    def update_critic(self, obs, action, reward, cost, next_obs, not_done):
        """Update the reward critic and the safety critic with one joint optimizer step.

        Reward critic: both Q heads are regressed (MSE) onto the soft Bellman
        target built from the target critic and the current policy.

        Safety critic: samples drawn from the current mixture are blended with
        Bellman-backed samples drawn from the target mixture, one EM step is
        fitted to the blend, and the predicted means, standard deviations and
        weights are regressed (MSE) onto the EM result.

        Everything that only serves as a regression target (the target-network
        forward pass, sampling, Bellman backup and EM step) runs under
        ``torch.no_grad()`` so no autograd graph is built for it.

        Args:
            obs: batch of observations.
            action: batch of actions taken in ``obs``.
            reward: batch of rewards.
            cost: batch of costs.
            next_obs: batch of next observations.
            not_done: continuation mask used in both Bellman backups.
        """
        with torch.no_grad():
            next_action, log_prob = self.actor.evaluate(next_obs)
            target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
            target_V = torch.min(target_Q1, target_Q2) - self.alpha * log_prob
            target_Q = reward + (not_done * self.discount * target_V)

        current_Q1, current_Q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
        wandb.log({'sac/critic_loss': critic_loss.item()})

        # Prediction of the online safety critic; this is the only part that needs gradients.
        current_means, current_variances, current_weights = self.safety_critic(obs, action)
        current_variances = torch.clamp(current_variances, min=1e-8, max=1e8)

        with torch.no_grad():
            next_means, next_variances, next_weights = self.safety_critic_target(next_obs, next_action)
            next_variances = torch.clamp(next_variances, min=1e-8, max=1e8)

            current_samples = self.sample_gmm(
                current_means.detach(), current_variances.detach(), current_weights.detach())
            next_samples = self.sample_gmm(next_means.detach(), next_variances.detach(), next_weights.detach())

            bellman_samples = self.bellman_transform(next_samples, cost, not_done)
            updated_samples = self.incremental_refinement(current_samples, bellman_samples, beta=0.6)

            # The EM result is a fixed regression target, so it is fed detached inputs.
            updated_means, updated_variances, updated_weights = self.em_update(
                current_means.detach(), current_variances.detach(), current_weights.detach(), updated_samples
            )
            updated_variances = torch.clamp(updated_variances, min=1e-8, max=1e8)

        mu_loss = F.mse_loss(current_means, updated_means)
        sigma_loss = F.mse_loss(torch.sqrt(current_variances), torch.sqrt(updated_variances))
        omega_loss = F.mse_loss(current_weights, updated_weights)

        safety_critic_loss = mu_loss + sigma_loss + omega_loss
        wandb.log({'sac/safety_critic_loss': safety_critic_loss.item()})

        total_loss = critic_loss + safety_critic_loss
        self.all_critics_optimizer.zero_grad()
        total_loss.backward()
        self.all_critics_optimizer.step()

    def update_actor_and_alpha(self, obs, action):
        """Update the actor, and optionally the temperature alpha and cost multiplier beta.

        The actor minimises
        ``alpha * log_prob - Q + (beta - damp) * SCVaR(obs, pi(obs))``
        where ``damp = damp_scale * mean(target_cost - SCVaR(obs, action))`` is
        evaluated on the replayed actions. SCVaR is evaluated at the agent's
        configured ``self.risk_level`` rather than the default of
        ``calculate_scvar``.

        Alpha and beta are updated only when ``self.learnable_temperature`` is
        set.

        Args:
            obs: batch of observations.
            action: batch of replayed actions taken in ``obs``.
        """
        action_pi, log_prob = self.actor.evaluate(obs)
        actor_Q1, actor_Q2 = self.critic(obs, action_pi)
        actor_Q = torch.min(actor_Q1, actor_Q2)

        actor_means, actor_variances, actor_weights = self.safety_critic(obs, action_pi)
        actor_variances = torch.clamp(actor_variances, min=1e-8, max=1e8)

        # SCVaR of the replayed actions does not depend on the actor; it is only a
        # constant in the actor and beta losses, so no graph is needed.
        with torch.no_grad():
            current_means, current_variances, current_weights = self.safety_critic(obs, action)
            current_variances = torch.clamp(current_variances, min=1e-8, max=1e8)
            current_scvar = self.calculate_scvar(
                current_means, current_variances, current_weights, risk_level=self.risk_level)
            damp = self.damp_scale * torch.mean(self.target_cost - current_scvar)

        actor_scvar = self.calculate_scvar(
            actor_means, actor_variances, actor_weights, risk_level=self.risk_level)
        actor_loss = torch.mean(
            self.alpha.detach() * log_prob
            - actor_Q
            + (self.beta.detach() - damp) * actor_scvar
        )

        wandb.log({'sac/actor_loss': actor_loss.item()})
        wandb.log({'sac/actor_cost': actor_scvar.item()})

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        if self.learnable_temperature:
            self.log_alpha_optimizer.zero_grad()
            alpha_loss = (self.alpha * (-log_prob - self.target_entropy).detach()).mean()
            wandb.log({'sac/alpha_loss': alpha_loss.item()})
            wandb.log({'sac/alpha_value': self.alpha.item()})
            alpha_loss.backward()
            self.log_alpha_optimizer.step()

            self.log_beta_optimizer.zero_grad()
            beta_loss = (self.beta * (self.target_cost - current_scvar).detach()).mean()
            wandb.log({'sac/beta_loss': beta_loss.item()})
            wandb.log({'sac/beta_value': self.beta.item()})
            beta_loss.backward()
            self.log_beta_optimizer.step()

    def update(self, replay_buffer, step):
        """Run one training iteration on a batch sampled from ``replay_buffer``.

        The critics are updated on every call. The actor (with alpha/beta) and
        the target networks are updated only on steps that are multiples of
        ``actor_update_frequency`` and ``critic_target_update_frequency``
        respectively.

        Args:
            replay_buffer: object whose ``sample(batch_size)`` returns the tuple
                ``(obs, action, reward, cost, next_obs, not_done, not_done_no_max)``.
            step: current training step, used for the update schedules.
        """
        batch = replay_buffer.sample(self.batch_size)
        # The plain not_done mask is unused; the critics receive not_done_no_max.
        obs, action, reward, cost, next_obs, _not_done, not_done_no_max = batch

        for metric, values in (('sac/batch_reward', reward), ('sac/batch_cost', cost)):
            wandb.log({metric: values.mean().item()})

        self.update_critic(obs, action, reward, cost, next_obs, not_done_no_max)

        if step % self.actor_update_frequency == 0:
            self.update_actor_and_alpha(obs, action)

        if step % self.critic_target_update_frequency != 0:
            return

        # Soft-update both target networks towards their online counterparts.
        network_pairs = (
            (self.critic, self.critic_target),
            (self.safety_critic, self.safety_critic_target),
        )
        for online, target in network_pairs:
            utils.soft_update_params(online, target, self.critic_tau)

    def save(self, path):
        """Save the base agent, then the safety critic sub-modules, under ``path``.

        Writes one state-dict file per safety-critic sub-module (``QCs``,
        ``VCs``, ``Weights``) into the directory ``path``.

        Args:
            path: directory that receives the checkpoint files.
        """
        super().save(path)

        checkpoint_files = (
            ("QCs", "safety_critic_qc.pth"),
            ("VCs", "safety_critic_vc.pth"),
            ("Weights", "safety_critic_weights.pth"),
        )
        for submodule_name, filename in checkpoint_files:
            submodule = getattr(self.safety_critic, submodule_name)
            torch.save(submodule.state_dict(), os.path.join(path, filename))

    def load(self, path):
        """Load the base agent, then the safety critic sub-modules, from ``path``.

        Checkpoint tensors are mapped onto ``self.device`` while loading, so a
        checkpoint written on one device can be restored on another. After the
        online safety critic is restored, the target safety critic is
        re-synchronised with it; the checkpoint holds no separate target
        weights, and ``__init__`` starts the target from the same copy.

        Args:
            path: directory containing the files written by ``save``.
        """
        super().load(path)

        checkpoint_files = (
            ("QCs", "safety_critic_qc.pth"),
            ("VCs", "safety_critic_vc.pth"),
            ("Weights", "safety_critic_weights.pth"),
        )
        for submodule_name, filename in checkpoint_files:
            state = torch.load(os.path.join(path, filename), map_location=self.device)
            getattr(self.safety_critic, submodule_name).load_state_dict(state)

        # Keep the target network consistent with the restored online network.
        self.safety_critic_target.load_state_dict(self.safety_critic.state_dict())
