
import logging
from debug import debug_print
import torch
from torch import nn
import torch.nn.functional as F
from torch.distributions import Normal, Independent
from onpolicy.algorithms.diffusion_ac.common.mlp import ResidualMLP

# Module-level logger named after this module.
log = logging.getLogger(__name__)

from onpolicy.algorithms.diffusion_ac.dppo.sampling import (
    extract,
    cosine_beta_schedule,
    make_timesteps,
)

from collections import namedtuple
# Lightweight record type "Sample" with two fields: ``trajectories`` and ``chains``.
Sample = namedtuple("Sample", ["trajectories", "chains"])


class MLP(nn.Module):

    def __init__(
        self,
        obs_dim,
        action_dim,
        # network,
        denoising_steps=20,
        predict_epsilon=True,
        network_path=None,
        device="cuda",
        horizon_steps=1,
        # Various clipping
        denoised_clip_value=1.0,
        randn_clip_value=10,
        final_action_clip_value=None,
        eps_clip_value=None,  # DDIM only
        # DDPM parameters
        # DDIM sampling
        use_ddim=False,
        ddim_discretize='uniform',
        ddim_steps=None,
        args=None,
        **kwargs,
    ):
        super().__init__()
        self.device = device
        self.horizon_steps = horizon_steps
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.denoising_steps = int(denoising_steps)
        self.predict_epsilon = predict_epsilon
        self.use_ddim = use_ddim
        self.ddim_steps = ddim_steps

        # Clip noise value at each denoising step
        self.denoised_clip_value = denoised_clip_value

        # Whether to clamp the final sampled action between [-1, 1]
        self.final_action_clip_value = final_action_clip_value

        # For each denoising step, we clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean
        self.randn_clip_value = randn_clip_value

        # Clip epsilon for numerical stability
        self.eps_clip_value = eps_clip_value

        # Set up models
        # self.network = network.to(device)
        self.network = ResidualMLP(
            [self.obs_dim,] + [args.unet_hidden_size] * args.unet_num_layer + [self.action_dim * self.horizon_steps,],
            activation_type="ReLU",
            out_activation_type="Identity",
            use_layernorm=False,
            # residual_style=True,
        )
        # debug_print(self.predict_epsilon, self.use_ddim)
        self.args = args
        # self.register_buffer('eta', torch.zeros(1))
        # self.register_buffer('base_betas', torch.ones_like(self.betas) * 0.7)
        # self.register_buffer('target_betas', self.betas.clone())
        # self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - self.alphas_cumprod))
    
    def update_eta(self, v):
        """Accept a new eta value and deliberately ignore it.

        This MLP head keeps no eta state, so the call has no effect and returns
        ``None``. Presumably it exists so callers can treat this head like the
        diffusion policies — TODO confirm.
        """
        return None

    # ---------- Sampling ----------#


    # @torch.no_grad()
    def p_sample_loop(self, cond, return_diffusion=False, return_noise=False, deterministic=False):
        """
        Predict actions with the MLP and lay them out like a diffusion sampling chain.

        The action ``x`` comes from a single forward pass of ``self.network`` and is
        never modified inside the step loop (there is no denoising update). The loop
        only records ``x``, its log-probability and a noise entry once per step, so
        the outputs have the same layout as a diffusion sampler's.

        Args:
            cond: observation batch fed directly to ``self.network``; the batch size
                is ``len(cond)`` (presumably a ``(B, obs_dim)`` tensor — TODO confirm).
            return_diffusion: also return the per-step chain and log-probabilities.
            return_noise: also return the per-step noise chain. The chain and
                log-probabilities are returned as well, because they precede it in
                the output tuple.
            deterministic: unused; kept for interface compatibility.

        Returns:
            ``actions`` of shape ``(B, horizon_steps * action_dim)``; or
            ``(actions, chain, log_probs)`` if ``return_diffusion``; or
            ``(actions, chain, log_probs, noises)`` if ``return_noise``. With ``K``
            steps, ``chain`` and ``noises`` are ``(B, K + 1, horizon_steps * action_dim)``
            and ``log_probs`` is ``(B, K, 1)``. Log-probabilities are those of the
            detached action under an independent ``N(0, exp(0.5))`` per dimension
            when ``args.use_latent_prob`` is set, and zeros otherwise. The noise entries
            are the action itself when ``args.use_latent_prob`` is set, and the clipped
            standard-normal draws otherwise.
        """
        B = len(cond)
        x = self.network(cond).reshape(B, self.horizon_steps, self.action_dim)
        flat_x = x.reshape(B, -1)

        if self.use_ddim:
            # NOTE(review): ``ddim_t`` is not set in ``__init__``; must be assigned externally.
            t_all = self.ddim_t
        else:
            t_all = list(reversed(range(self.denoising_steps)))

        # The return_noise output tuple also contains the chain and log-probs, so
        # track them whenever either flag is set (previously a NameError).
        track_chain = return_diffusion or return_noise
        use_latent_prob = track_chain and self.args.use_latent_prob

        if track_chain:
            diffusion = [flat_x]
            log_probs = []
            if use_latent_prob:
                # x is the same at every step, so its log-probability is too.
                # Under DDIM the scale is zero, which Normal rejects when argument
                # validation is enabled (NOTE(review): DDIM + latent prob looks unsupported).
                if self.use_ddim:
                    std = torch.zeros_like(flat_x)
                else:
                    std = torch.clip(torch.exp(0.5 * torch.ones_like(flat_x)), min=1e-3)
                dist = Independent(Normal(torch.zeros_like(std), std), 1)
                step_log_prob = dist.log_prob(flat_x.detach()).unsqueeze(-1)
            else:
                step_log_prob = torch.zeros(B, 1, device=x.device)
        if return_noise:
            noises = [flat_x]

        for _ in t_all:
            # Drawn every step, even when unused, so the global RNG stream matches
            # a true per-step sampler.
            noise = torch.randn_like(x).clamp_(
                -self.randn_clip_value, self.randn_clip_value
            )
            if track_chain:
                diffusion.append(flat_x)
                log_probs.append(step_log_prob)
            if return_noise:
                noises.append(flat_x if use_latent_prob else noise.reshape(B, -1))

        if return_noise:
            return flat_x, torch.stack(diffusion, dim=1), torch.stack(log_probs, dim=1), torch.stack(noises, dim=1)
        if return_diffusion:
            return flat_x, torch.stack(diffusion, dim=1), torch.stack(log_probs, dim=1)
        return flat_x
        
    def p_sample_loop_with_noise(self, cond, return_diffusion=False, return_noise=False, deterministic=False, noises=None, detach=False):
        """
        Re-evaluate the MLP action and its per-step log-probabilities, optionally echoing given noise.

        Counterpart of ``p_sample_loop``. The action ``x`` comes from one forward pass
        of ``self.network`` and is never updated inside the step loop. Unlike
        ``p_sample_loop``, the log-probabilities are taken on ``x`` itself (not
        detached), so gradients can reach the network unless ``detach`` is set.

        Args:
            cond: observation batch fed directly to ``self.network``; the batch size
                is ``len(cond)`` (presumably a ``(B, obs_dim)`` tensor — TODO confirm).
            return_diffusion: also return the per-step chain and log-probabilities.
            return_noise: also return a noise chain made of the flattened action
                followed by ``noises[:, i + 1]`` for each step ``i``. The chain and
                log-probabilities are returned as well.
            deterministic: unused; kept for interface compatibility.
            noises: previously recorded noise chain indexed as ``noises[:, k]``
                (e.g. the noise output of ``p_sample_loop``). It is only read when
                ``return_noise`` is set.
            detach: detach ``x`` at the start of every step. This cuts gradients through
                the recorded chain entries after the first, the log-probabilities, and
                the returned action.

        Returns:
            ``actions`` of shape ``(B, horizon_steps * action_dim)``; or
            ``(actions, chain, log_probs)`` if ``return_diffusion``; or
            ``(actions, chain, log_probs, noises)`` if ``return_noise``. With ``K``
            steps, ``chain`` and ``noises`` are ``(B, K + 1, horizon_steps * action_dim)``
            and ``log_probs`` is ``(B, K, 1)``. Log-probabilities are zeros unless
            ``args.use_latent_prob`` is set.

        Raises:
            ValueError: if ``return_noise`` is set but ``noises`` is ``None``.
        """
        if return_noise and noises is None:
            raise ValueError("`noises` is required when return_noise=True")

        B = len(cond)
        x = self.network(cond).reshape(B, self.horizon_steps, self.action_dim)

        if self.use_ddim:
            # NOTE(review): ``ddim_t`` is not set in ``__init__``; must be assigned externally.
            t_all = self.ddim_t
        else:
            t_all = list(reversed(range(self.denoising_steps)))

        # The return_noise output tuple also contains the chain and log-probs, so
        # track them whenever either flag is set (previously a NameError).
        track_chain = return_diffusion or return_noise
        use_latent_prob = track_chain and self.args.use_latent_prob

        if track_chain:
            diffusion = [x.reshape(B, -1)]
            log_probs = []
        if return_noise:
            # Kept separate from the `noises` argument. Rebinding that name to a list
            # used to break the `noises[:, i + 1]` lookup below.
            noise_chain = [x.reshape(B, -1)]
        if use_latent_prob:
            # The step-independent Gaussian the action is scored under. Under DDIM the
            # scale is zero, which Normal rejects when argument validation is enabled
            # (NOTE(review): DDIM + latent prob looks unsupported).
            flat_ref = x.reshape(B, -1)
            if self.use_ddim:
                std = torch.zeros_like(flat_ref)
            else:
                std = torch.clip(torch.exp(0.5 * torch.ones_like(flat_ref)), min=1e-3)
            dist = Independent(Normal(torch.zeros_like(std), std), 1)

        for i, _ in enumerate(t_all):
            if detach:
                x = x.detach()
            flat_x = x.reshape(B, -1)
            if track_chain:
                diffusion.append(flat_x)
                if use_latent_prob:
                    # Not detached: this is the differentiable log-probability.
                    log_probs.append(dist.log_prob(flat_x).unsqueeze(-1))
                else:
                    log_probs.append(torch.zeros(B, 1, device=x.device))
            if return_noise:
                noise_chain.append(noises[:, i + 1].reshape(B, -1))

        flat_x = x.reshape(B, -1)
        if return_noise:
            return flat_x, torch.stack(diffusion, dim=1), torch.stack(log_probs, dim=1), torch.stack(noise_chain, dim=1)
        if return_diffusion:
            return flat_x, torch.stack(diffusion, dim=1), torch.stack(log_probs, dim=1)
        return flat_x

    # ---------- Supervised training ----------#
