import torch
import torch.nn as nn

from torch import Tensor

from tqdm import tqdm
from typing import Tuple


def pde_loss(u: Tensor, gt_u: Tensor, dx: Tensor, dt: Tensor) -> Tensor:
    """Return the squared pointwise residual of ``u_t + u * u_x``.

    Both derivatives are forward differences of ``gt_u``: along dim 1
    (divided by ``dt``) and along dim 2 (divided by ``dx``).  ``u`` only
    enters as the factor multiplying the dim-2 derivative.

    Args:
        u: 3-D tensor supplying the multiplicative factor.
        gt_u: Tensor the finite differences are taken from; it must be
            sliceable to the same cropped shape as ``u``.
        dx: Step used to scale the dim-2 difference.
        dt: Step used to scale the dim-1 difference.

    Returns:
        Element-wise squared residual.  Dims 1 and 2 are each one shorter
        than the input because of the differencing and cropping.
    """
    assert u.ndim == 3

    # Each forward difference shortens its own axis by one element.
    d_dim1 = gt_u.diff(dim=1).div(dt)
    d_dim2 = gt_u.diff(dim=2).div(dx)

    # Crop the remaining axis of every term so all three share one grid.
    advection = u[:, :-1, :-1] * d_dim2[:, :-1]
    residual = d_dim1[:, :, :-1] + advection
    return residual.square()

class Diffusion(nn.Module):
    def __init__(
        self,   
        config,
    ):
        super().__init__()
        assert config.diffusion in ['ve', 'vp'], 'diffusion noise scheduler not implemented'

        if config.loss == 'naive':
            self.loss = config.loss
        elif config.loss.startswith('pde'):
            self.loss = 'pde'
            self.loss_weight = float(config.loss.split('_')[-1])
        else:
            raise NotImplementedError
            
        self.steps = 1000
        self.config = config
        self.device = config.device

        if config.diffusion == 've':
            self.alpha_t = torch.ones(size=(self.steps, ), device=self.device)
            self.sigma_t = torch.sigmoid(torch.linspace(-5, 5, self.steps, device=self.device))
            self.sigma_t = 5 * (self.sigma_t - self.sigma_t.min() + 1e-5)
            self.lambda_t = (self.alpha_t / self.sigma_t).log()
            self.diffusion_weight = self.sigma_t.square()
        
        else:
            self.alpha_org = 1 - torch.linspace(1e-4, 1e-2, self.steps, device=self.device)
            self.alpha_t = self.alpha_org.cumprod(dim=0).sqrt() 
            self.sigma_t = (1 - self.alpha_t.square()).sqrt()
            self.lambda_t = (self.alpha_t / self.sigma_t).log()
            self.diffusion_weight = (-2 * self.alpha_t.log()).diff() * self.steps
            self.diffusion_weight = torch.concatenate([self.diffusion_weight, self.diffusion_weight[-1].reshape(-1)])

        self.sampling_cfg = config.sampling

        backbone = config.model.lower()

        if backbone == 'karras':
            from models.KarrasUnet.unet import Unet
            self.network = Unet
        else:
            raise NotImplementedError('backbone not found')

        self.network = self.network(
            n_diff_time=self.steps,
            **config.network
        ).to(self.device)

        assert config.model_out in ['x0']
        self.model_out = config.model_out
        
        n = 128
        ts = torch.linspace(0, 3, n)
        self.dt = (ts[1] - ts[0]).to(torch.float32).to(self.device)
        xs = torch.linspace(-4, 4, n)
        self.dx = (xs[1] - xs[0]).to(torch.float32).to(self.device)

        print(f"{self.network._get_name()} #Params: {sum(p.numel() for p in self.network.parameters())}")

    def diff_forward(self, x0: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        extend_dim = [-1] + [1] * (x0.ndim - 1)
        t = torch.randint(0, self.steps, (x0.size(0), ), device=self.device)

        noise = torch.randn_like(x0)
        alpha = self.alpha_t[t].reshape(extend_dim)
        sigma = self.sigma_t[t].reshape(extend_dim)
        return alpha * x0 + sigma * noise, t, noise

    def get_loss(self, x0: Tensor) -> Tuple[dict, int]:
        # x0.shape == (batch_size, input_size, input_size)
        loss_dict = {}
        xt, t, eps_noise = self.diff_forward(x0)
        
        if self.config.diffusion == 've':
            t = torch.ones_like(t)
        model_out: Tensor = self.network(xt, t)

        assert eps_noise.shape == model_out.shape

        batch_size = xt.size(0)
        loss_w = self.diffusion_weight[t].reshape(-1, 1)

        if self.model_out == 'x0':
            score_loss = (x0 - model_out).square().reshape(batch_size, -1)
            if self.loss == 'pde':
                de_loss = pde_loss(model_out, x0, dx=self.dx, dt=self.dt).reshape(batch_size, -1)
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError

        loss_dict['pde_de'] = (loss_w * de_loss).mean() 
        loss_dict['score'] = (loss_w * score_loss).mean()

        # print(loss_dict['pde_de'], loss_dict['score'])

        if self.loss == 'naive':
            loss = score_loss
        elif self.loss == 'pde' or self.loss == 'jensen':
            loss = torch.concatenate([
                score_loss,
                self.loss_weight * de_loss
            ], dim=-1)

        else:
            raise NotImplementedError()

        loss_dict['total'] = (loss_w * loss).mean()
        return loss_dict, loss.size(0)

    @torch.no_grad()
    def network_predict(self, xt: Tensor, t: Tensor) -> Tensor:
        model_predict = self.network(xt, t)
        if self.model_out == 'noise':
            return model_predict
        else:
            alpha, sigma = self.alpha_t[t].reshape(-1, *[1] * (xt.ndim - 1)), self.sigma_t[t].reshape(-1, *[1] * (xt.ndim - 1))
            return (xt - alpha * model_predict) / sigma
        
    @torch.no_grad()
    def ode_reverse(self, s: Tensor, t: Tensor, xt: Tensor) -> Tensor:
        eps_model = self.network_predict(xt, s.repeat(xt.size(0)))

        model_out = self.sigma_t[t] * torch.expm1(self.lambda_t[t] - self.lambda_t[s]) * eps_model
        return self.alpha_t[t] / self.alpha_t[s] * xt - model_out
    
    def sampling(self, data_shape: Tuple, method: str | None = None) -> Tensor:
        if method is None and self.config.diffusion == 'vp':
            method = self.sampling_cfg.method

        if self.config.diffusion == 've':
            method = 'ld'
            itertor = self.ld_reverse
            timestamps = torch.arange(0, self.steps)
        else:
            if method == 'ddpm':
                itertor = self.ddpm_reverse
                timestamps = torch.arange(-1, self.steps)
            elif method == 'ode':
                itertor = self.ode_reverse
                timestamps = torch.arange(0, self.steps)
            elif method == 'dpm3':
                itertor = self.dpm3_reverse
                timestamps = torch.linspace(0, self.steps - 1, 30)
            else:
                raise NotImplementedError('sampling method not found')
        
        timestamps = timestamps.__reversed__().int().to(self.device)

        if self.config.diffusion == 'vp':
            xt = torch.randn(size=data_shape, device=self.device)
        else:
            xt = self.sigma_t[-1] * torch.randn(size=data_shape, device=self.device)

        pbar = tqdm(zip(timestamps[:-1], timestamps[1:]), leave=False, total=len(timestamps)-1, dynamic_ncols=True)
        for s, t in pbar:
            xt = itertor(s, t, xt)
            
        return xt





        