from typing import NamedTuple, Tuple, Dict, Any
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from gops.algorithm.base import Algorithm
from gops.network.dacer import DACERNet, DACERParams
from gops.utils.experience import Experience
from gops.utils.typing import Metric
from gops.utils.gmm import fit_gmm_and_estimate_entropy_batch_torch
from gops.utils.util import fit_param

class DACER2OptStates:
    """Container grouping the optimizer states of the four trained components.

    Attributes:
        q1: optimizer state for the first Q-network.
        q2: optimizer state for the second Q-network.
        policy: optimizer state for the policy.
        log_alpha: optimizer state for ``log_alpha``.
    """

    def __init__(self, q1, q2, policy, log_alpha):
        # Values are stored as given; no copying or validation is performed.
        self.q1, self.q2 = q1, q2
        self.policy, self.log_alpha = policy, log_alpha

class DACER2TrainState:
    """Snapshot of everything carried from one training update to the next.

    Attributes:
        params: the trainable parameter bundle.
        opt_state: the optimizer states belonging to ``params``.
        step: number of updates performed so far.
        mean_q1_std: running average kept for the first Q-network's std.
        mean_q2_std: running average kept for the second Q-network's std.
        entropy: most recent entropy estimate.
    """

    def __init__(self, params, opt_state, step, mean_q1_std, mean_q2_std, entropy):
        # Values are stored as given; no copying or validation is performed.
        self.params, self.opt_state, self.step = params, opt_state, step
        self.mean_q1_std, self.mean_q2_std = mean_q1_std, mean_q2_std
        self.entropy = entropy

class DACER2(Algorithm):
    """Trainer for DACER2: two Q-networks, a policy and ``log_alpha``.

    Every call to :meth:`update` takes one gradient step on both Q-networks.
    On the configured schedules it additionally

    * steps the policy on ``policy_loss + score_lambda * score_loss`` and
      soft-updates the target Q-networks (every ``delay_update`` steps), and
    * re-estimates the policy entropy with a GMM fit and steps ``log_alpha``
      (every ``delay_alpha_update`` steps).

    The mutable training state is kept in ``self.state``
    (a :class:`DACER2TrainState`).
    """

    def __init__(
        self,
        agent: DACERNet,
        params: DACERParams,
        *,
        gamma: float = 0.99,
        score_lambda: float = 0.5,
        lr: float = 1e-4,
        alpha_lr: float = 3e-2,
        tau: float = 0.005,
        delay_alpha_update: int = 10000,
        delay_update: int = 2,
        reward_scale: float = 0.2,
        num_samples: int = 200,
        decay_steps: int = 50000,
        device: str = 'cuda'
    ):
        """Create the optimizers and the initial training state.

        Args:
            agent: network wrapper; moved to ``device`` with ``.to(device)``.
            params: bundle exposing ``q1``, ``q2``, ``target_q1``,
                ``target_q2``, ``policy`` and ``log_alpha``.
            gamma: discount factor used in the Q backup.
            score_lambda: initial weight of the score loss; it is decayed
                linearly to 10% of this value over ``decay_steps`` updates.
            lr: Adam learning rate for ``q1``, ``q2`` and ``policy``.
            alpha_lr: Adam learning rate for ``log_alpha``.
            tau: coefficient of the target-network soft update; also used as
                the smoothing coefficient of the running mean of the Q std.
            delay_alpha_update: number of updates between two entropy
                estimates / ``log_alpha`` steps.
            delay_update: number of updates between two policy steps /
                target-network soft updates.
            reward_scale: factor the rewards are multiplied by.
            num_samples: number of actions sampled per observation for the
                entropy estimate.
            decay_steps: number of updates over which ``score_lambda`` decays.
            device: torch device on which the batch tensors are created.
        """
        self.gamma = gamma
        self.tau = tau
        self.delay_alpha_update = delay_alpha_update
        self.delay_update = delay_update
        self.reward_scale = reward_scale
        self.num_samples = num_samples
        self.score_lambda = score_lambda
        self.decay_steps = decay_steps
        self.device = device
        # Last entropy estimate; reused between two entropy re-estimations.
        self.entropy = 0.0

        self.agent = agent.to(device)
        self.q1_optim = optim.Adam(params.q1.parameters(), lr=lr)
        self.q2_optim = optim.Adam(params.q2.parameters(), lr=lr)
        self.policy_optim = optim.Adam(params.policy.parameters(), lr=lr)
        self.alpha_optim = optim.Adam([params.log_alpha], lr=alpha_lr)

        self.state = DACER2TrainState(
            params=params,
            opt_state=self._snapshot_opt_state(),
            step=0,
            # -1.0 is the "not initialised yet" sentinel checked in
            # _compute_q_loss.
            mean_q1_std=-1.0,
            mean_q2_std=-1.0,
            entropy=0.0,
        )

    def _snapshot_opt_state(self) -> DACER2OptStates:
        """Return the current ``state_dict()`` of all four optimizers."""
        return DACER2OptStates(
            q1=self.q1_optim.state_dict(),
            q2=self.q2_optim.state_dict(),
            policy=self.policy_optim.state_dict(),
            log_alpha=self.alpha_optim.state_dict(),
        )

    def _to_tensor(self, value) -> torch.Tensor:
        """Convert ``value`` to a float32 tensor on ``self.device``.

        ``torch.as_tensor`` avoids a copy when it can, so the result may share
        memory with ``value``; callers must not modify it in place.
        """
        return torch.as_tensor(value, dtype=torch.float32, device=self.device)

    def _compute_q_loss(self, q_net, mean_q_std, obs, action, q_backup, q_backup_sample):
        """Compute the loss of one Q-network and its updated running std.

        Args:
            q_net: callable returning ``(q_mean, q_std)`` for ``(obs, action)``.
            mean_q_std: running mean of the batch-averaged ``q_std`` as a
                float, or ``-1.0`` if it has not been initialised yet.
            obs: batch of observations.
            action: batch of actions.
            q_backup: backup target built from the target networks' mean.
            q_backup_sample: backup target built from the target networks'
                sample.

        Returns:
            ``(q_loss, q_mean, q_std, mean_q_std)`` where ``mean_q_std`` is
            the updated running mean (a Python float).
        """
        q_mean, q_std = q_net(obs, action)
        batch_q_std = torch.mean(q_std).item()

        if mean_q_std == -1.0:
            # First update: start the running mean at the current value.
            mean_q_std = batch_q_std
        else:
            mean_q_std = self.tau * batch_q_std + (1 - self.tau) * mean_q_std

        # Keep the sampled target within three running stds of the prediction.
        q_backup_bounded = q_mean + torch.clamp(
            q_backup_sample - q_mean, -3 * mean_q_std, 3 * mean_q_std
        )
        q_std_detach = torch.clamp(q_std.detach(), min=0)
        epsilon = 0.1  # keeps the denominators away from zero

        # Only the bare q_mean / q_std factors carry gradient; every other
        # factor is detached or built under no_grad by the caller.
        q_loss = -(mean_q_std ** 2 + epsilon) * torch.mean(
            q_mean * (q_backup - q_mean).detach() / (q_std_detach ** 2 + epsilon) +
            q_std * ((q_mean.detach() - q_backup_bounded) ** 2 - q_std_detach ** 2) / (q_std_detach ** 3 + epsilon)
        )

        return q_loss, q_mean, q_std, mean_q_std

    def _estimate_entropy(self, params, obs) -> float:
        """Estimate the policy entropy on ``obs`` and return it as a float.

        Draws ``self.num_samples`` actions per observation, stacks them along
        dim 1 and averages the per-observation entropies returned by the
        GMM-based estimator. Runs entirely without gradient tracking.
        """
        with torch.no_grad():
            sampled = [
                self.agent.get_action(
                    params.policy, params.log_alpha.detach(), params.q1, params.q2, obs
                )
                for _ in range(self.num_samples)
            ]
            actions = torch.stack(sampled, dim=1)

            entropy = fit_gmm_and_estimate_entropy_batch_torch(
                actions_batch=actions,
                n_components=3,
                covariance_type='full',
                n_init=1,
                n_iter=100,
                reg_covar=1e-6,
                device=self.device
            )
            return torch.mean(entropy).item()

    def _compute_policy_loss(self, params, obs) -> torch.Tensor:
        """Return the negated batch mean of ``min(q1_mean, q2_mean)`` at the policy's action."""
        new_action = self.agent.get_action(params.policy, params.log_alpha, params.q1, params.q2, obs)
        q1_mean, _ = params.q1(obs, new_action)
        q2_mean, _ = params.q2(obs, new_action)
        return torch.mean(-torch.minimum(q1_mean, q2_mean))

    def _compute_score_loss(self, params, obs, action) -> torch.Tensor:
        """Return the score loss for a noised copy of the batch actions.

        A single timestep ``t`` is drawn for the whole batch. The policy's
        noise prediction, divided by ``sqrt_one_minus_alphas_cumprod[t]``, is
        regressed onto the negated, step-size-scaled, unit-normalised average
        of the two critics' gradients with respect to the noisy action.
        """
        t = torch.randint(0, self.agent.num_timesteps, (1,), device=self.device).item()
        noise = torch.randn_like(action)
        noisy_action = self.agent.diffusion.q_sample(t, action, noise)

        B = self.agent.diffusion.beta_schedule()
        eps_pred = self.agent.get_eps_pred(
            params.policy, params.log_alpha, params.q1, params.q2, t, noisy_action, obs
        )
        eps_pred_loss = eps_pred / B.sqrt_one_minus_alphas_cumprod[t]

        def action_gradient(q_net):
            # d(sum of q_mean) / d(noisy action). The result is only used as a
            # constant regression target, so no higher-order graph is needed.
            probe = noisy_action.detach().requires_grad_(True)
            q_sum = q_net(obs, probe)[0].sum()
            return torch.autograd.grad(q_sum, probe)[0]

        critic_1_jacobian = action_gradient(params.q1)
        critic_2_jacobian = action_gradient(params.q2)

        c, d = fit_param(self.agent.num_timesteps)
        step_sizes = torch.exp(
            torch.linspace(c, d, self.agent.num_timesteps, device=self.device)
        )
        alpha_t = step_sizes[t]

        critic_jacobian = (critic_1_jacobian + critic_2_jacobian) / 2
        # Normalise to unit length along the last dim; 1e-8 guards against a
        # zero gradient.
        critic_jacobian_norm = torch.sqrt(torch.sum(critic_jacobian**2, dim=-1, keepdim=True)) + 1e-8
        critic_jacobian = critic_jacobian / critic_jacobian_norm

        return torch.mean((alpha_t * critic_jacobian.detach() + eps_pred_loss) ** 2)

    def _soft_update(self, target_net, online_net) -> None:
        """Blend ``online_net`` into ``target_net``: ``tau * online + (1 - tau) * target``."""
        with torch.no_grad():
            for target_param, param in zip(target_net.parameters(), online_net.parameters()):
                target_param.copy_(self.tau * param + (1 - self.tau) * target_param)

    def update(self, data: Experience) -> Tuple['DACER2TrainState', Metric]:
        """Run one training update on a batch and return the new state and metrics.

        Args:
            data: batch exposing ``obs``, ``action``, ``reward``, ``next_obs``
                and ``done``. The inputs are not modified.

        Returns:
            ``(state, info)``: the new :class:`DACER2TrainState` (also stored
            in ``self.state``) and a dict of scalar metrics. ``policy_loss``
            is reported as ``0.0`` on steps where the policy is not updated.
        """
        obs = self._to_tensor(data.obs)
        action = self._to_tensor(data.action)
        # Scale out of place: the converted tensor may share memory with
        # data.reward, and an in-place multiply would then rescale the
        # caller's stored rewards on every update.
        reward = self._to_tensor(data.reward) * self.reward_scale
        next_obs = self._to_tensor(data.next_obs)
        done = self._to_tensor(data.done)

        params = self.state.params
        step = self.state.step

        # Linear decay of the score-loss weight from 100% to 10%.
        score_lambda = self.score_lambda * (
            1.0 - 0.9 * min(float(step) / self.decay_steps, 1.0)
        )
        update_alpha = step % self.delay_alpha_update == 0
        update_policy = step % self.delay_update == 0

        # ---- Q backup from the target networks (no gradient) ----
        with torch.no_grad():
            next_action = self.agent.get_action(params.policy, params.log_alpha, params.q1, params.q2, next_obs)
            next_q1_mean, _, next_q1_sample = self.agent.q_evaluate(params.target_q1, next_obs, next_action)
            next_q2_mean, _, next_q2_sample = self.agent.q_evaluate(params.target_q2, next_obs, next_action)
            next_q_mean = torch.minimum(next_q1_mean, next_q2_mean)
            # Take the sample from whichever target has the smaller mean.
            next_q_sample = torch.where(next_q1_mean < next_q2_mean, next_q1_sample, next_q2_sample)

            q_backup = reward + (1 - done) * self.gamma * next_q_mean
            q_backup_sample = reward + (1 - done) * self.gamma * next_q_sample

        # ---- Q-network updates ----
        q1_loss, q1_mean, q1_std, mean_q1_std = self._compute_q_loss(
            params.q1, self.state.mean_q1_std, obs, action, q_backup, q_backup_sample
        )
        q2_loss, q2_mean, q2_std, mean_q2_std = self._compute_q_loss(
            params.q2, self.state.mean_q2_std, obs, action, q_backup, q_backup_sample
        )

        self.q1_optim.zero_grad()
        q1_loss.backward()
        self.q1_optim.step()

        self.q2_optim.zero_grad()
        q2_loss.backward()
        self.q2_optim.step()

        # ---- Entropy estimate (refreshed on the alpha schedule) ----
        if update_alpha:
            entropy = self._estimate_entropy(params, obs)
            self.entropy = entropy
        else:
            entropy = self.entropy

        # ---- Policy update ----
        if update_policy:
            policy_loss = self._compute_policy_loss(params, obs)
            score_loss = self._compute_score_loss(params, obs, action)
            total_loss = policy_loss + score_loss * score_lambda

            # One backward pass over the weighted sum yields the same
            # gradient as back-propagating both losses separately and
            # combining them per parameter.
            self.policy_optim.zero_grad()
            total_loss.backward()
            self.policy_optim.step()
        else:
            total_loss = torch.tensor(0.0)

        # ---- Temperature update ----
        if update_alpha:
            log_alpha_loss = -torch.mean(params.log_alpha * (-entropy + self.agent.target_entropy))
            self.alpha_optim.zero_grad()
            log_alpha_loss.backward()
            self.alpha_optim.step()

        # ---- Target-network soft update (same schedule as the policy) ----
        if update_policy:
            self._soft_update(params.target_q1, params.q1)
            self._soft_update(params.target_q2, params.q2)

        self.state = DACER2TrainState(
            params=params,
            opt_state=self._snapshot_opt_state(),
            step=step + 1,
            mean_q1_std=mean_q1_std,
            mean_q2_std=mean_q2_std,
            entropy=entropy,
        )

        info = {
            "q1_loss": q1_loss.item(),
            "q1_mean": torch.mean(q1_mean).item(),
            "q1_std": torch.mean(q1_std).item(),
            "q2_loss": q2_loss.item(),
            "q2_mean": torch.mean(q2_mean).item(),
            "q2_std": torch.mean(q2_std).item(),
            "policy_loss": total_loss.item(),
            "alpha": torch.exp(params.log_alpha).item(),
            "mean_q1_std": mean_q1_std,
            "mean_q2_std": mean_q2_std,
            "entropy": entropy,
        }

        return self.state, info

    def get_policy_params(self):
        """Return ``(policy, log_alpha, q1, q2)`` from the current training state."""
        current = self.state.params
        return (current.policy, current.log_alpha, current.q1, current.q2)
