import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from typing import Dict, Union, Tuple
from offlinerlkit.policy import BasePolicy
from offlinerlkit.modules.rfqi_module import RFQIActor, RFQICritic, RFQIVAE, RFQIETA

class RFQIPolicy(BasePolicy):
    """
    Robust Fitted Q-Iteration (RFQI) policy.

    Action generation follows a BCQ-like scheme: a VAE proposes candidate
    actions for a state, the actor perturbs them, and the critic's first head
    picks the best candidate. The critic is regressed onto a robust target

        y = r + gamma * (1 - done) * (-max(eta - V', 0) + (1 - rho) * eta)

    where V' is a soft-clipped double-Q estimate of the next-state value and
    ``eta`` is a per-(s, a) dual variable that is re-fitted from scratch on
    every batch by :meth:`optimize_eta` and clamped to
    ``[0, 1 / (rho * (1 - gamma))]``.
    """

    def __init__(
        self,
        actor: RFQIActor,
        critic: RFQICritic,
        vae: RFQIVAE,
        actor_optim: torch.optim.Optimizer,
        critic_optim: torch.optim.Optimizer,
        vae_optim: torch.optim.Optimizer,
        state_dim: int,
        action_dim: int,
        tau: float = 0.005,
        gamma: float = 0.99,
        rho: float = 0.5,
        lmbda: float = 0.75,
        adam_lr: float = 3e-4,
        adam_eps: float = 1e-6,
        device: str = "cpu"
    ) -> None:
        """
        Args:
            actor: network that perturbs VAE-decoded actions; a deep copy is
                kept as the target actor.
            critic: twin Q-network (called for both heads, and via ``q1``);
                a deep copy is kept as the target critic.
            vae: generative action model, used via ``forward`` and ``decode``.
            actor_optim, critic_optim, vae_optim: optimizers for each network.
            state_dim, action_dim: sizes used to build the per-batch eta network.
            tau: Polyak coefficient for the target-network updates.
            gamma: discount factor.
            rho: robustness parameter of the target; it also sets the eta upper
                bound ``1 / (rho * (1 - gamma))``, so rho must be non-zero and
                gamma must differ from 1 (otherwise ZeroDivisionError).
            lmbda: weight on min(Q1, Q2) versus max(Q1, Q2) in the target.
            adam_lr, adam_eps: Adam hyper-parameters for the eta optimizer.
            device: device for observations in ``select_action`` and for the
                eta network.
        """
        super().__init__()

        self.actor = actor
        self.actor_target = copy.deepcopy(self.actor)
        self.critic = critic
        self.critic_target = copy.deepcopy(self.critic)
        self.vae = vae

        self.actor_optim = actor_optim
        self.critic_optim = critic_optim
        self.vae_optim = vae_optim

        self.state_dim = state_dim
        self.action_dim = action_dim
        self._tau = tau
        self._gamma = gamma
        self.rho = rho
        self.lmbda = lmbda
        self.adam_lr = adam_lr
        self.adam_eps = adam_eps

        self.device = device
        # Feasible interval for the dual variable: [0, 1 / (rho * (1 - gamma))].
        self.eta_low = 0
        self.eta_high = 1 / (self.rho * (1 - self._gamma))

        # Intentionally no persistent eta network/optimizer: optimize_eta builds
        # fresh ones on every call because eta solves a batch-specific problem.

    def train(self) -> None:
        """Put the online actor, critic and VAE into training mode."""
        self.actor.train()
        self.critic.train()
        self.vae.train()

    def eval(self) -> None:
        """Put the online actor, critic and VAE into evaluation mode."""
        self.actor.eval()
        self.critic.eval()
        self.vae.eval()

    def _sync_weight(self) -> None:
        """Polyak-average online weights into the target actor and critic."""
        for o, n in zip(self.actor_target.parameters(), self.actor.parameters()):
            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
        for o, n in zip(self.critic_target.parameters(), self.critic.parameters()):
            o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)

    def select_action(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray:
        """
        Choose an action for a single observation.

        The observation is tiled 100 times; for each copy the VAE decodes a
        candidate action, the actor perturbs it, and the candidate with the
        highest ``critic.q1`` value is returned.

        Args:
            obs: one observation, either flat or with a leading batch dimension
                of 1. The argmax runs over all tiled rows, so passing several
                observations would still yield a single action.
            deterministic: accepted for interface compatibility; not used.

        Returns:
            The chosen action as a numpy array. This is presumably
            ``(1, action_dim)`` if ``critic.q1`` returns ``(N, 1)``; confirm.
        """
        with torch.no_grad():
            obs = torch.FloatTensor(obs).to(self.device)
            if obs.dim() == 1:
                obs = obs.unsqueeze(0)

            obs_rep = obs.repeat(100, 1)

            decoded_action = self.vae.decode(obs_rep)
            perturbed_action = self.actor(obs_rep, decoded_action)
            q1 = self.critic.q1(obs_rep, perturbed_action)
            ind = q1.argmax(0)
            action = perturbed_action[ind]

        return action.cpu().numpy()

    def optimize_eta(self, V_ns, s, a, tol=1e-3, max_iter=100):
        """
        Fit the dual variable eta(s, a) for the current batch.

        A fresh ``RFQIETA`` network and Adam optimizer (``maximize=True``) are
        created on every call. Gradient ascent is then run on

            g(eta) = sum( -max(eta - V_ns, 0) + (1 - rho) * eta )

        for at most ``max_iter`` steps. Every 10 steps, g is compared with the
        value recorded 10 steps earlier, and the loop stops once the change is
        below ``tol``. The original RFQI implementation is noted to use 10000
        iterations; raise ``max_iter`` if eta has not converged.

        Args:
            V_ns: next-state value estimates; reshaped to a (batch, 1) column.
            s, a: batch of states and actions fed to the eta network.
            tol: convergence tolerance on the change of g.
            max_iter: maximum number of optimizer steps.

        Returns:
            Detached eta values of shape (batch, 1), clamped to
            ``[self.eta_low, self.eta_high]``.
        """
        eta_func = RFQIETA(self.state_dim, self.action_dim).to(self.device)
        eta_optimizer = torch.optim.Adam(
            eta_func.parameters(),
            lr=self.adam_lr,
            eps=self.adam_eps,
            maximize=True,
        )

        # Force both operands into (batch, 1) columns so `eta - V` is elementwise.
        # A (batch,) eta against a (batch, 1) V would otherwise silently
        # broadcast to a (batch, batch) matrix and optimize the wrong objective.
        V_col = V_ns.reshape(-1, 1)
        zero = V_col.new_tensor(0)

        def g(eta_net):
            eta = eta_net(s, a).reshape(-1, 1)
            g_val = -torch.maximum(eta - V_col, zero) + (1 - self.rho) * eta
            return g_val.sum()

        prev_gval = torch.tensor([float('inf')], device=self.device)

        for i in range(max_iter):
            gval = g(eta_func)
            eta_optimizer.zero_grad()
            gval.backward()
            eta_optimizer.step()

            # Convergence is only checked every 10 steps to save compute.
            if i % 10 == 0:
                if torch.norm(gval - prev_gval) < tol:
                    break
                prev_gval = gval.detach()

        # The final evaluation needs no graph; the result is used as a constant.
        with torch.no_grad():
            etas = eta_func(s, a).reshape(-1, 1)
        return torch.clamp(etas, self.eta_low, self.eta_high)

    def learn(self, batch: Dict) -> Dict:
        """
        Run one RFQI update on ``batch``.

        The update trains the VAE, the critic (against the robust target) and
        the actor in turn, then Polyak-updates the target networks.

        Args:
            batch: dict with "observations", "actions", "next_observations",
                "rewards" and "terminals" tensors. Rewards and terminals are
                assumed to be (batch, 1) float tensors on ``self.device``;
                confirm against the replay buffer.

        Returns:
            Scalar losses and eta statistics for logging.
        """
        obss, actions, next_obss, rewards, terminals = \
            batch["observations"], batch["actions"], batch["next_observations"], batch["rewards"], batch["terminals"]
        batch_size = obss.shape[0]

        # 1. VAE: reconstruction loss plus 0.5-weighted KL term.
        recon, mean, std = self.vae(obss, actions)
        recon_loss = F.mse_loss(recon, actions)
        KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * KL_loss

        self.vae_optim.zero_grad()
        vae_loss.backward()
        self.vae_optim.step()

        # 2. Critic.
        with torch.no_grad():
            # 10 VAE-proposed, target-actor-perturbed candidates per next state.
            next_obss_rep = torch.repeat_interleave(next_obss, 10, 0)
            sampled_actions = self.vae.decode(next_obss_rep)
            perturbed_actions = self.actor_target(next_obss_rep, sampled_actions)

            target_Q1, target_Q2 = self.critic_target(next_obss_rep, perturbed_actions)
            # Soft-clipped double Q: lmbda-weighted mix of min and max heads.
            target_Q = self.lmbda * torch.min(target_Q1, target_Q2) + (1. - self.lmbda) * torch.max(target_Q1, target_Q2)
            # Keep the best candidate for each next state -> (batch, 1).
            target_Q = target_Q.reshape(batch_size, -1).max(1)[0].reshape(-1, 1)

        # Per-batch dual variable (fresh network each call). This dominates the
        # step's cost; see optimize_eta's max_iter (default 100).
        etas = self.optimize_eta(target_Q, obss, actions)

        # Robust Bellman target.
        with torch.no_grad():
            robust_term = -torch.maximum(etas - target_Q, etas.new_tensor(0)) + (1 - self.rho) * etas
            # Terminal transitions do not bootstrap (standard RL practice). The
            # original RFQI snippet reportedly ignores this mask; drop
            # (1 - terminals) if exact reproduction requires it.
            q_target = rewards + self._gamma * (1 - terminals) * robust_term

        current_Q1, current_Q2 = self.critic(obss, actions)
        critic_loss = F.mse_loss(current_Q1, q_target) + F.mse_loss(current_Q2, q_target)

        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        # 3. Actor. VAE proposals are plain inputs to the actor, so no graph is
        # needed through the VAE (its grads were discarded by the next
        # vae_optim.zero_grad() anyway).
        with torch.no_grad():
            sampled_actions_curr = self.vae.decode(obss)
        perturbed_actions_curr = self.actor(obss, sampled_actions_curr)

        actor_loss = -self.critic.q1(obss, perturbed_actions_curr).mean()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        self._sync_weight()

        return {
            "loss/actor": actor_loss.item(),
            "loss/critic": critic_loss.item(),
            "loss/vae": vae_loss.item(),
            "metrics/max_eta": etas.max().item(),
            "metrics/mean_eta": etas.mean().item(),
        }