import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.distributions import Normal, MultivariateNormal 
import os
import math
from typing import Optional

# from utils.util import unpack_batch, RunningMeanStd
from utils.util import unpack_batch
from agent.sac.sac_agent import SACAgent
from agent.sac.actor import DiagGaussianActor
from agent.rffsac.rffsac_agent import RFFSACAgent

from torchinfo import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class RFFSACAgentBonus(RFFSACAgent):
    """RFF-SAC agent with an intrinsic exploration bonus following CTRL-UCB.

    For a feature vector ``phi`` the raw bonus is ``sqrt(phi^T P phi)``. The
    precision matrix ``P`` starts at ``I / bonus_lambda`` and receives one
    Sherman-Morrison rank-one update per feature vector passed to
    ``update_precision_matrix``. Numerically unsafe updates are skipped.
    Up to floating-point error, ``P = (bonus_lambda * I + sum phi phi^T)^-1``.
    The raw bonus is scaled by ``bonus_coef``, optionally clipped to
    ``[0, bonus_clip]``, and added to the reward in the critic's TD target.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        action_space,
        bonus_coef: float = 0.0,
        bonus_clip: Optional[float] = 1.0,
        bonus_lambda: float = 1.0,
        **kwargs,
    ):
        """Build the base RFF-SAC agent and initialise the bonus state.

        Args:
            state_dim: forwarded to ``RFFSACAgent``.
            action_dim: forwarded to ``RFFSACAgent``.
            action_space: forwarded to ``RFFSACAgent``.
            bonus_coef: non-negative multiplier applied to the raw bonus.
            bonus_clip: upper bound for the scaled bonus. ``None`` disables
                clipping.
            bonus_lambda: positive ridge regulariser. The initial precision
                matrix is ``I / bonus_lambda``.
            **kwargs: forwarded to ``RFFSACAgent``.

        Raises:
            ValueError: if ``bonus_lambda <= 0``, ``bonus_coef < 0`` or
                ``bonus_clip < 0``.
        """
        # Validate before building the base agent so bad configs fail fast.
        # Raise instead of assert so the checks survive ``python -O``.
        if not bonus_lambda > 0.0:
            raise ValueError("bonus_lambda must be positive.")
        if not bonus_coef >= 0.0:
            raise ValueError("bonus_coef must be non-negative.")
        if bonus_clip is not None and bonus_clip < 0.0:
            # torch.clamp with max < min would return all-negative "bonuses".
            raise ValueError("bonus_clip must be non-negative or None.")

        super().__init__(
            state_dim=state_dim,
            action_dim=action_dim,
            action_space=action_space,
            **kwargs,
        )

        self.bonus_coef = bonus_coef
        self.bonus_clip = bonus_clip
        self.bonus_lambda = bonus_lambda
        # Floor for phi^T P phi before the square root.
        self._bonus_eps = 1e-6

        # Place P on the same device that update_precision_matrix uses for its
        # inputs (self.device). Fall back to the module-level default.
        precision_device = getattr(self, "device", device)
        self.precision_matrix = (1.0 / self.bonus_lambda) * torch.eye(
            self.feature_dim, device=precision_device
        )
        self._precision_update_step = 0
        self._precision_metrics = {
            'step': [],
            'diag_min': [],
            'diag_max': [],
            'fro_norm': [],
        }
        self.use_bonus = True
        self._negative_quad_count = 0

        self.flag = True  # For debugging purpose only

    def compute_bonus(self, z_phi):
        """Return the scaled exploration bonus for each feature row.

        Args:
            z_phi: features of shape ``(batch, feature_dim)``. A single
                ``(feature_dim,)`` vector is treated as a batch of one.

        Returns:
            Tensor of shape ``(batch, 1)``. It is all zeros when the bonus has
            been disabled. Otherwise it is
            ``_scale_bonus(sqrt(max(phi^T P phi, eps)))`` per row.
        """
        features = z_phi.detach()
        # Promote a single vector to a batch of one so both branches accept the
        # same shapes. Previously only the disabled branch did this, and the
        # enabled branch crashed on sum(dim=1).
        if features.dim() == 1:
            features = features.unsqueeze(0)
        if not self.use_bonus:
            return torch.zeros(
                (features.shape[0], 1),
                device=features.device,
                dtype=features.dtype,
            )
        with torch.no_grad():
            precision_projection = torch.matmul(features, self.precision_matrix.t())  # batch x feature_dim
            quad = (features * precision_projection).sum(dim=1, keepdim=True)  # batch x 1
            # Floating-point drift can push quad slightly below zero; floor it so
            # the square root stays real and finite.
            quad = torch.clamp(quad, min=self._bonus_eps)
            bonus_tensor = self._scale_bonus(torch.sqrt(quad))

        return bonus_tensor

    def update_precision_matrix(self, state, action, z_phi=None):
        """Fold feature vector(s) into the precision matrix via Sherman-Morrison.

        For each feature row ``c`` this applies
        ``P <- P - (P c)(P c)^T / (1 + c^T P c)``.

        Args:
            state: used with ``self.frozen_phi`` to compute features when
                ``z_phi`` is ``None``.
            action: used with ``self.frozen_phi`` to compute features when
                ``z_phi`` is ``None``.
            z_phi: optional precomputed features, either ``(feature_dim,)`` or
                ``(batch, feature_dim)``. Rows are applied sequentially, which
                is equivalent to one Woodbury update with the whole batch.

        Side effects:
            Replaces ``self.precision_matrix`` and appends one entry to the
            precision metrics. When ``c^T P c < 0``, increments
            ``_negative_quad_count``, prints a warning, and may disable the
            bonus.
        """
        if z_phi is None:
            state_tensor = torch.as_tensor(state, device=self.device, dtype=self.precision_matrix.dtype)
            action_tensor = torch.as_tensor(action, device=self.device, dtype=self.precision_matrix.dtype)
            with torch.no_grad():
                z_phi = self.frozen_phi(state_tensor, action_tensor)
            if self.flag:
                print(z_phi.shape)
                self.flag = False

        precision = self.precision_matrix
        # torch.dot only accepts 1-D tensors, so flatten any leading batch
        # dimensions and process one feature row at a time.
        columns = z_phi.detach().to(device=precision.device, dtype=precision.dtype)
        columns = columns.reshape(-1, precision.shape[0])

        for column in columns:
            with torch.no_grad():
                precision_column = precision @ column
                quad = torch.dot(column, precision_column)

            quad_value = float(quad.item())
            if quad_value < 0.0:
                # A positive-definite P gives quad >= 0, so this signals
                # accumulated numerical error in P.
                self._negative_quad_count += 1
                print("Warning: negative quad encountered in precision matrix update.")
                if self._negative_quad_count >= 100 * self.feature_dim and self.use_bonus:
                    self.use_bonus = False
                    print("[RFFSACBonus] Disabling bonus usage due to repeated negative quad events.")
            if 1.0 + quad_value <= self._bonus_eps:
                # The denominator is ~0 or negative: the rank-one update would
                # explode or flip sign and permanently corrupt P. Skip it.
                continue

            with torch.no_grad():
                rank_update = torch.outer(precision_column, precision_column) / (1.0 + quad)
                precision = precision - rank_update

        self.precision_matrix = precision
        self._log_precision_metrics()

    def _log_precision_metrics(self):
        """Record diag min/max and the Frobenius norm of the precision matrix.

        Appends one entry to each list in ``self._precision_metrics`` per call.
        The lists are never trimmed, so they grow with the number of updates.
        The ``.cpu()`` / ``.item()`` calls copy results from the matrix's
        device to the host on every call.
        """
        self._precision_update_step += 1
        diag = torch.diag(self.precision_matrix)
        diag_cpu = diag.detach().cpu()
        fro = torch.linalg.norm(self.precision_matrix).detach().cpu()
        self._precision_metrics['step'].append(self._precision_update_step)
        self._precision_metrics['diag_min'].append(float(diag_cpu.min().item()))
        self._precision_metrics['diag_max'].append(float(diag_cpu.max().item()))
        self._precision_metrics['fro_norm'].append(float(fro.item()))

    # def save_precision_metrics_plot(self, filepath):
    #     if not self._precision_metrics['step']:
    #         return
    #     import matplotlib.pyplot as plt

    #     steps = self._precision_metrics['step']
    #     diag_min = self._precision_metrics['diag_min']
    #     diag_max = self._precision_metrics['diag_max']
    #     fro_norm = self._precision_metrics['fro_norm']

    #     plt.figure(figsize=(10, 6))
    #     plt.plot(steps, diag_min, label='Diag min')
    #     plt.plot(steps, diag_max, label='Diag max')
    #     plt.plot(steps, fro_norm, label='Frobenius norm')
    #     plt.xlabel('Precision update step')
    #     plt.ylabel('Metric value')
    #     plt.title('Precision matrix diagnostics')
    #     plt.legend()
    #     plt.grid(True, alpha=0.3)
    #     plt.tight_layout()
    #     plt.savefig(filepath)
    #     plt.close()

    def _scale_bonus(self, bonus):
        """Multiply by ``bonus_coef`` and, if ``bonus_clip`` is set, clamp to ``[0, bonus_clip]``."""
        scaled_bonus = self.bonus_coef * bonus
        if self.bonus_clip is not None:
            scaled_bonus = torch.clamp(scaled_bonus, min=0.0, max=self.bonus_clip)
        return scaled_bonus

    def critic_step(self, batch):
        """Take one optimiser step on the twin critics with a bonus-augmented target.

        The TD target is ``reward + bonus(s, a) + (1 - done) * discount *
        (min(Q1', Q2') - alpha * log_pi(a' | s'))``, with ``a'`` sampled from
        the current actor and ``Q1', Q2'`` taken from ``critic_target``.

        Args:
            batch: replay batch understood by ``unpack_batch``.

        Returns:
            Dict of scalar diagnostics: ``q1_loss``, ``q2_loss``, ``q1``,
            ``q2``, ``bonus_mean`` and ``bonus_abs_mean``.
        """
        state, action, next_state, reward, done = unpack_batch(batch)

        with torch.no_grad():
            dist = self.actor(next_state)
            next_action = dist.rsample()
            next_action_log_pi = dist.log_prob(next_action).sum(-1, keepdim=True)

            if self.use_feature_target:
                z_phi = self.frozen_phi_target(state, action)
                z_phi_next = self.frozen_phi_target(next_state, next_action)
            else:
                z_phi = self.frozen_phi(state, action)
                z_phi_next = self.frozen_phi(next_state, next_action)

            next_q1, next_q2 = self.critic_target(z_phi_next, next_state, next_action)
            next_q = torch.min(next_q1, next_q2) - self.alpha * next_action_log_pi

            # NOTE(review): with use_feature_target the bonus is evaluated on
            # target features, while update_precision_matrix builds P from
            # self.frozen_phi features. Confirm this mismatch is intended.
            bonus = self.compute_bonus(z_phi)
            bonus = bonus.reshape_as(reward)
            augmented_reward = reward + bonus

            target_q = augmented_reward + (1.0 - done) * self.discount * next_q

        q1, q2 = self.critic(z_phi, state, action)
        # (input, target) argument order; MSE is symmetric, so the loss is unchanged.
        q1_loss = F.mse_loss(q1, target_q)
        q2_loss = F.mse_loss(q2, target_q)
        q_loss = q1_loss + q2_loss

        self.critic_optimizer.zero_grad()
        q_loss.backward()
        self.critic_optimizer.step()

        return {
            'q1_loss': q1_loss.item(),
            'q2_loss': q2_loss.item(),
            'q1': q1.mean().item(),
            'q2': q2.mean().item(),
            'bonus_mean': bonus.mean().item(),
            'bonus_abs_mean': bonus.abs().mean().item(),
        }

 