import copy
from dataclasses import dataclass
from typing import Callable, NamedTuple, Optional, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from gops.network.blocks import Activation, DistributionalQNet2, DACER2PolicyNet
from gops.utils.diffusion import GaussianDiffusion
from gops.utils.torch_utils import random_key_from_data

class DACER2Params:
    """Plain holder for the DACER2 learnable components.

    Stores references (no copies) to the two critics, their target copies,
    the diffusion policy network and the log-temperature.
    """

    def __init__(self, q1, q2, target_q1, target_q2, policy, log_alpha):
        # Online critics.
        self.q1, self.q2 = q1, q2
        # Target critics.
        self.target_q1, self.target_q2 = target_q1, target_q2
        # Actor network and log-temperature.
        self.policy, self.log_alpha = policy, log_alpha

@dataclass
class DACER2Net:
    """DACER2 networks bundled with the Gaussian-diffusion action sampler.

    The networks are stored as fields, but every inference method takes the
    networks it uses as explicit arguments, so callers may evaluate with a
    different set of modules than the ones stored here.

    Fields:
        q1, q2: critic networks; ``q_evaluate`` expects ``q_net(obs, act)``
            to return a ``(mean, std)`` pair.
        target_q1, target_q2: target critics.
        policy: noise-prediction network, called as ``policy(obs, x, t)``.
        num_timesteps: number of diffusion steps given to ``GaussianDiffusion``.
        act_dim: action dimensionality; sampled actions have shape
            ``(batch, act_dim)``.
        target_entropy: entropy target; not read by any method of this class
            (presumably consumed by the training loop — confirm).
        device: device every network is moved to in ``__post_init__``.
    """

    q1: nn.Module
    q2: nn.Module
    target_q1: nn.Module
    target_q2: nn.Module
    policy: nn.Module
    num_timesteps: int
    act_dim: int
    target_entropy: float
    device: str = 'cuda'

    def __post_init__(self):
        """Build the diffusion sampler and move every network to ``device``."""
        self._diffusion = GaussianDiffusion(self.num_timesteps, device=self.device)

        for name in ('q1', 'q2', 'target_q1', 'target_q2', 'policy'):
            setattr(self, name, getattr(self, name).to(self.device))

    @property
    def diffusion(self) -> "GaussianDiffusion":
        """The ``GaussianDiffusion`` sampler created in ``__post_init__``."""
        return self._diffusion

    def get_eps_pred(self, policy_net, log_alpha, q1_net, q2_net, t, x, obs):
        """Return ``policy_net``'s noise prediction for ``x`` at step(s) ``t``.

        ``log_alpha``, ``q1_net`` and ``q2_net`` are unused; they keep the
        signature parallel to the other inference methods.
        """
        return policy_net(obs, x, t)

    def get_action(self, policy_net, log_alpha, q1_net, q2_net, obs, *, noise_scale: float = 0.1):
        """Sample a batch of actions by running the reverse diffusion process.

        Args:
            policy_net: noise-prediction network, called as
                ``policy_net(obs, x, t)`` with ``t`` a long tensor of shape
                ``(batch,)``.
            log_alpha: log-temperature tensor. Gaussian noise with std
                ``exp(log_alpha) * noise_scale`` is added to the sampled action.
            q1_net, q2_net: unused; kept for signature parity.
            obs: observation batch; ``obs.shape[0]`` is the batch size.
            noise_scale: multiplier on ``exp(log_alpha)`` for the extra
                exploration noise (default 0.1).

        Returns:
            Tensor of shape ``(obs.shape[0], act_dim)`` clamped to ``[-1, 1]``.
            It is computed under ``torch.no_grad()``, so it carries no autograd
            history.
        """
        with torch.no_grad():
            def model_fn(t, x):
                # `t` must be a scalar here: torch.full broadcasts it to one
                # step index per sample, as the policy network expects.
                t_tensor = torch.full((x.shape[0],), t, device=self.device, dtype=torch.long)
                return policy_net(obs, x, t_tensor)

            action_shape = (obs.shape[0], self.act_dim)
            action = self.diffusion.p_sample(model_fn, action_shape)

            # Temperature-scaled exploration noise on top of the diffusion sample.
            noise = torch.randn_like(action)
            action = action + noise * torch.exp(log_alpha) * noise_scale

            action = torch.clamp(action, -1, 1)

        return action

    def get_deterministic_action(self, policy_net, log_alpha, q1_net, q2_net, obs):
        """Sample actions without the extra exploration noise of ``get_action``.

        A log-temperature of ``-inf`` makes ``exp(log_alpha)`` zero, so the
        added Gaussian noise vanishes. ``log_alpha`` only supplies the
        shape/dtype/device of that replacement.

        NOTE(review): ``p_sample`` may itself inject noise during reverse
        diffusion. If so, this result is not fully deterministic — confirm
        against ``GaussianDiffusion``.
        """
        log_alpha_det = torch.full_like(log_alpha, -float('inf'))
        return self.get_action(policy_net, log_alpha_det, q1_net, q2_net, obs)

    def q_evaluate(self, q_net, obs, act, *, z_clip: float = 3.0):
        """Evaluate a distributional critic and draw one clipped sample.

        Args:
            q_net: critic returning ``(q_mean, q_std)`` for ``(obs, act)``.
            obs, act: inputs forwarded unchanged to ``q_net``.
            z_clip: standard-normal draws are clipped to ``[-z_clip, z_clip]``
                (default 3.0) to bound outliers in the sampled value.

        Returns:
            ``(q_mean, q_std, q_value)`` where
            ``q_value = q_mean + q_std * clip(z)`` and ``z ~ N(0, 1)``.
        """
        q_mean, q_std = q_net(obs, act)

        z = torch.clamp(torch.randn_like(q_mean), -z_clip, z_clip)
        q_value = q_mean + q_std * z

        return q_mean, q_std, q_value

def create_dacer2_net(
    obs_dim: int,
    act_dim: int,
    hidden_sizes: Sequence[int],
    diffusion_hidden_sizes: Sequence[int],
    activation: str = 'relu',
    num_timesteps: int = 20,
    device: str = 'cuda',
    *,
    init_log_alpha: float = 1.0,
    target_entropy: Optional[float] = None,
) -> Tuple[DACER2Net, DACER2Params]:
    """Build the DACER2 critics, target critics, diffusion policy and temperature.

    Args:
        obs_dim: observation dimensionality.
        act_dim: action dimensionality.
        hidden_sizes: hidden layer sizes for both critics.
        diffusion_hidden_sizes: hidden layer sizes for the diffusion policy.
        activation: activation name forwarded to the network constructors.
        num_timesteps: number of diffusion steps, given to both the policy
            network and the sampler inside ``DACER2Net``.
        device: device for every module and for ``log_alpha``.
        init_log_alpha: initial value of the learnable log-temperature
            (default 1.0, i.e. alpha = e).
        target_entropy: entropy target stored on the returned ``DACER2Net``.
            Defaults to ``-act_dim``.

    Returns:
        ``(net, params)``. Both hold the same module instances, because
        ``nn.Module.to`` moves in place and returns the module. The
        learnable ``log_alpha`` is held only by ``params``.
    """
    q1_net = DistributionalQNet2(obs_dim + act_dim, hidden_sizes, activation).to(device)
    q2_net = DistributionalQNet2(obs_dim + act_dim, hidden_sizes, activation).to(device)

    # Target critics start as exact copies of the online critics.
    target_q1_net = copy.deepcopy(q1_net)
    target_q2_net = copy.deepcopy(q2_net)

    policy_net = DACER2PolicyNet(
        obs_dim,
        act_dim,
        diffusion_hidden_sizes,
        activation,
        num_timesteps,
    ).to(device)

    log_alpha = nn.Parameter(torch.tensor(float(init_log_alpha), device=device))

    if target_entropy is None:
        target_entropy = -float(act_dim)

    params = DACER2Params(
        q1=q1_net,
        q2=q2_net,
        target_q1=target_q1_net,
        target_q2=target_q2_net,
        policy=policy_net,
        log_alpha=log_alpha,
    )

    net = DACER2Net(
        q1=q1_net,
        q2=q2_net,
        target_q1=target_q1_net,
        target_q2=target_q2_net,
        policy=policy_net,
        num_timesteps=num_timesteps,
        act_dim=act_dim,
        target_entropy=target_entropy,
        device=device,
    )

    return net, params
