from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# from data import Batch
from .inner_model import InnerModel, InnerModelConfig, StateInnerModelConfig
from ..layers import mlp, SimNorm
from ..perceiver import PerceiverConfig
from utils import LossAndLogs
from einops import rearrange, repeat


def add_dims(input: Tensor, n: int) -> Tensor:
    """Append trailing singleton axes to ``input`` until it has ``n`` dimensions.

    If ``input`` already has ``n`` or more dimensions, its shape is left as is.
    """
    missing = n - input.ndim
    # A non-positive repeat count yields an empty tuple, so nothing is appended.
    padded_shape = tuple(input.shape) + (1,) * missing
    return input.reshape(padded_shape)


@dataclass
class Conditioners:
    """Preconditioning coefficients produced by ``Denoiser.compute_conditioners``.

    All four tensors are derived from the (offset-noise adjusted) noise level
    ``sigma``; see Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf.
    """
    # Multiplies the noisy input before it is passed to the inner model.
    c_in: Tensor
    # Multiplies the inner-model output when forming the denoised estimate.
    c_out: Tensor
    # Multiplies the noisy input when forming the denoised estimate.
    c_skip: Tensor
    # Noise-level conditioning given to the inner model (log of the adjusted sigma, divided by 4).
    c_noise: Tensor


@dataclass
class SigmaDistributionConfig:
    """Parameters of the noise-level distribution used during training.

    ``Denoiser.setup_training`` samples ``exp(randn * scale + loc)`` and clips
    the result to ``[sigma_min, sigma_max]``.
    """
    # Shift applied to the standard-normal sample, i.e. the mean of log(sigma) before clipping.
    loc: float
    # Multiplier applied to the standard-normal sample, i.e. the std of log(sigma) before clipping.
    scale: float
    # Lower clipping bound for the sampled sigma.
    sigma_min: float
    # Upper clipping bound for the sampled sigma.
    sigma_max: float


@dataclass
class DenoiserConfig:
    """Configuration consumed by ``Denoiser``."""
    # Passed to ``InnerModel``; its ``num_steps_conditioning`` is read in ``Denoiser.forward``.
    inner_model: StateInnerModelConfig
    # Passed to ``InnerModel`` alongside ``inner_model``.
    perceiver: PerceiverConfig
    # Data scale used in the ``c_in`` / ``c_skip`` coefficients and in ``Denoiser.encode`` / ``decode``.
    sigma_data: float
    # Scale of the extra noise added in ``Denoiser.apply_noise``; also folded into sigma in
    # ``Denoiser.compute_conditioners``.
    sigma_offset_noise: float


def zero_(params):
    """Overwrite every tensor yielded by ``params`` with zeros, in place."""
    for param in params:
        param.data.zero_()

class Denoiser(nn.Module):
    """EDM-style denoiser that predicts the next global state from past states and actions.

    The inner model is preconditioned following Eq. 7 of
    https://arxiv.org/pdf/2206.00364.pdf: it receives ``c_in * noisy_next_obs``
    together with ``c_noise``, and the denoised estimate is
    ``c_skip * noisy_next_obs + c_out * model_output``.

    ``setup_training`` must be called once before ``forward`` is used.
    """

    def __init__(self, cfg: DenoiserConfig,
                 num_agents: Optional[int] = None,
                 clip_denoised: bool = False,
                 is_continuous_act: bool = False,) -> None:
        """Build the denoiser and its inner model.

        Args:
            cfg: Denoiser configuration; ``cfg.inner_model`` and ``cfg.perceiver``
                are passed to ``InnerModel``.
            num_agents: Number of agents. Although it defaults to ``None``, the
                training and inference paths compare it against ``act.size(2)``,
                so it has to be set for those paths to work.
            clip_denoised: When true, denoised outputs are clamped to ``[-1, 1]``.
            is_continuous_act: Passed to ``InnerModel``; when false, ``forward``
                converts actions with ``argmax`` over their last axis.
        """
        super().__init__()
        self.cfg = cfg
        # The inner model takes the global state as input.
        self.inner_model = InnerModel(cfg.inner_model, cfg.perceiver, num_agents=num_agents, is_continuous_act=is_continuous_act)

        # Set by ``setup_training``; stays ``None`` until then.
        self.sample_sigma_training = None

        self.is_continuous_act = is_continuous_act
        self.num_agents = num_agents
        self.clip_denoised = clip_denoised

    @property
    def device(self) -> torch.device:
        """Device holding the inner model's noise-embedding weight."""
        return self.inner_model.noise_emb.weight.device

    # The two helpers below relate to a latent encoder, but are believed to be
    # generally useful for raw data as well. They suppose the latent generated
    # by the encoder has been normalized feature-wise (mimics NVIDIA Cosmos).
    @torch.no_grad()
    def encode(self, latent):
        """Return ``latent`` multiplied by ``cfg.sigma_data`` (no gradient tracking)."""
        return latent * self.cfg.sigma_data

    @torch.no_grad()
    def decode(self, latent):
        """Return ``latent`` divided by ``cfg.sigma_data``; inverse of ``encode``."""
        return latent / self.cfg.sigma_data

    def setup_training(self, cfg: SigmaDistributionConfig) -> None:
        """Install the training-time sigma sampler; may only be called once.

        The sampler draws ``exp(randn * cfg.scale + cfg.loc)`` and clips it to
        ``[cfg.sigma_min, cfg.sigma_max]``.
        """
        assert self.sample_sigma_training is None

        def sample_sigma(n: int, device: torch.device):
            s = torch.randn(n, device=device) * cfg.scale + cfg.loc
            return s.exp().clip(cfg.sigma_min, cfg.sigma_max)

        self.sample_sigma_training = sample_sigma

    def apply_noise(self, x: Tensor, sigma: Tensor, sigma_offset_noise: float) -> Tensor:
        """Return ``x`` corrupted by offset noise plus per-sample Gaussian noise.

        Args:
            x: 3-D tensor to corrupt.
            sigma: Per-sample noise level, broadcast over the trailing axes of ``x``.
            sigma_offset_noise: Scale of the additional element-wise offset noise.

        Raises:
            ValueError: If ``x`` is not 3-D.
        """
        if x.ndim != 3:
            raise ValueError(f"apply_noise expects a 3-D tensor, got shape {tuple(x.shape)}")

        # State-based case: offset noise is drawn independently for every element.
        # Drawing it with ``randn_like(x)`` keeps it on the same device as ``x``.
        offset_noise = sigma_offset_noise * torch.randn_like(x)

        return x + offset_noise + torch.randn_like(x) * add_dims(sigma, x.ndim)

    def compute_conditioners(self, sigma: Tensor) -> Conditioners:
        """Compute the preconditioning coefficients for the given noise levels.

        ``sigma`` is first combined with ``cfg.sigma_offset_noise`` so that the
        coefficients account for the offset noise added in ``apply_noise``.
        ``c_in``, ``c_out`` and ``c_skip`` are expanded to 3 dimensions so they
        broadcast against 3-D observations; ``c_noise`` keeps the shape of ``sigma``.
        """
        sigma = (sigma**2 + self.cfg.sigma_offset_noise**2).sqrt()
        c_in = 1 / (sigma**2 + self.cfg.sigma_data**2).sqrt()
        c_skip = self.cfg.sigma_data**2 / (sigma**2 + self.cfg.sigma_data**2)
        c_out = sigma * c_skip.sqrt()
        c_noise = sigma.log() / 4
        return Conditioners(
            c_in=add_dims(c_in, 3),
            c_out=add_dims(c_out, 3),
            c_skip=add_dims(c_skip, 3),
            c_noise=add_dims(c_noise, 1),
        )

    def compute_model_output(self, noisy_next_obs: Tensor, obs: Tensor, act: Tensor, cs: Conditioners, act_mask: Tensor) -> Tensor:
        """Run the inner model (F_theta in Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf).

        This is the latent-diffusion adaptation: ``obs`` is passed to the inner
        model unscaled, and only the noisy input is scaled by ``cs.c_in``.
        """
        assert act.size(2) == self.num_agents

        innermodel_input = noisy_next_obs * cs.c_in
        return self.inner_model(innermodel_input, cs.c_noise, obs, act, act_mask)

    @torch.no_grad()
    def wrap_model_output(self, noisy_next_obs: Tensor, model_output: Tensor, cs: Conditioners) -> Tensor:
        """Turn the raw model output into a denoised estimate (no gradient tracking).

        The estimate is ``c_skip * noisy_next_obs + c_out * model_output``,
        clamped to ``[-1, 1]`` when ``clip_denoised`` is set.
        """
        d = cs.c_skip * noisy_next_obs + cs.c_out * model_output

        if self.clip_denoised:
            d = d.clamp(-1., 1.)

        return d

    @torch.no_grad()
    def denoise(self, noisy_next_obs: Tensor, sigma: Tensor, obs: Tensor, act: Tensor, act_mask: Tensor) -> Tensor:
        """Denoise ``noisy_next_obs`` at noise level ``sigma`` given the conditioning inputs."""
        cs = self.compute_conditioners(sigma)
        model_output = self.compute_model_output(noisy_next_obs, obs, act, cs, act_mask)
        denoised = self.wrap_model_output(noisy_next_obs, model_output, cs)
        return denoised

    def _sample_act_mask(self, b: int, t: int) -> Tensor:
        """Build a ``(b, t, num_agents)`` long mask: ones for the first ``t - 1`` steps, one random agent at the last."""
        # Sampled independently of the noise level.
        active_agent = torch.randint(0, self.num_agents, (b,), device=self.device)
        history_mask = torch.ones(b, t - 1, self.num_agents, device=self.device, dtype=torch.long)
        # F.one_hot already returns an int64 tensor, so no dtype cast is needed.
        last_step_mask = F.one_hot(active_agent, num_classes=self.num_agents).unsqueeze(1)
        return torch.cat((history_mask, last_step_mask), dim=1)

    def forward(self, batch) -> LossAndLogs:
        """Compute the autoregressive denoising loss over a batch of sequences.

        For each step after the first ``num_steps_conditioning`` ones, the next
        observation is noised, the inner model is asked to denoise it, and the
        denoised estimate (computed without gradient) replaces the ground truth
        in the conditioning window of later steps.

        Args:
            batch: Object exposing ``shared_obs`` (indexed as ``(b, time, state_dim)``),
                ``act`` (agent axis at dim 2) and ``mask_padding`` (indexed as
                ``(b, time)``; assumed boolean — TODO confirm).

        Returns:
            ``(loss, logs)`` where ``loss`` is the per-step MSE averaged over the
            predicted steps and ``logs`` holds its detached copy under
            ``"loss_denoising"``.

        Raises:
            RuntimeError: If ``setup_training`` has not been called.
            ValueError: If the sequence has no step left to predict after the
                conditioning window.
        """
        assert batch.act.size(2) == self.num_agents
        if self.sample_sigma_training is None:
            raise RuntimeError("setup_training() must be called before forward()")

        n = self.cfg.inner_model.num_steps_conditioning
        seq_length = batch.shared_obs.size(1) - n
        if seq_length <= 0:
            raise ValueError(
                f"shared_obs has {batch.shared_obs.size(1)} time steps but more than "
                f"{n} (num_steps_conditioning) are required"
            )

        # Work on a copy of the global state: predictions are written back into it below.
        all_obs = batch.shared_obs.clone()
        loss = 0
        for i in range(seq_length):
            obs = all_obs[:, i : n + i]                 # (b, n, state_dim)
            next_obs = all_obs[:, n + i].unsqueeze(1)   # (b, 1, state_dim)
            act = batch.act[:, i : n + i]               # (b, n, num_agents, ...)
            mask = batch.mask_padding[:, n + i]         # (b,)

            if not self.is_continuous_act:
                act = act.argmax(-1)

            b, t, _ = obs.shape

            sigma = self.sample_sigma_training(b, self.device)
            noisy_next_obs = self.apply_noise(next_obs, sigma, self.cfg.sigma_offset_noise)

            cs = self.compute_conditioners(sigma)

            act_mask = self._sample_act_mask(b, t)
            model_output = self.compute_model_output(noisy_next_obs, obs, act, cs, act_mask)

            target = (next_obs - cs.c_skip * noisy_next_obs) / cs.c_out
            # NOTE(review): if `mask` selects no rows, the mean over an empty selection is NaN —
            # confirm the data pipeline guarantees at least one valid row per step.
            loss += F.mse_loss(model_output[mask], target[mask])

            # Feed the (gradient-free) prediction back so later steps condition on it.
            denoised = self.wrap_model_output(noisy_next_obs, model_output, cs)
            all_obs[:, n + i] = denoised.squeeze(1)

        loss /= seq_length
        return loss, {"loss_denoising": loss.detach()}
