import torch
import torch.nn as nn

from torch import Tensor

from tqdm import tqdm
from typing import Tuple

def adv_1d_loss(pred: Tensor, dx: Tensor, dt: Tensor, beta: float = 0.1) -> Tensor:
    """Squared residual of the 1-D advection equation ``u_t + beta * u_x = 0``.

    Both derivatives use second-order central differences: time along dim 2
    of ``pred`` and space along dim 1. The residual is evaluated only on
    interior points (one point trimmed from each end of both axes), so at
    least 3 points are needed along dims 1 and 2.

    Args:
        pred: predicted field; presumably ``(batch, n_space, n_time, 1)`` as
            noted in ``Diffusion.get_loss`` — only dims 0, 1 and 2 are used
            here.
        dx: spatial grid spacing (tensor or scalar).
        dt: temporal grid spacing (tensor or scalar).
        beta: advection speed. Defaults to 0.1.

    Returns:
        Tensor of shape ``(batch, -1)``: the squared residual at every interior
        grid point, flattened per sample.
    """
    batch_size = pred.size(0)

    # central differences; du_t keeps all space points, du_x keeps all time points
    du_t = (pred[:, :, 2:] - pred[:, :, :-2]) / (2 * dt)
    du_x = (pred[:, 2:] - pred[:, :-2]) / (2 * dx)

    # trim the untouched axis of each derivative so both cover the interior grid
    residual = du_t[:, 1:-1] + beta * du_x[:, :, 1:-1]
    return residual.square().reshape(batch_size, -1)


class Diffusion(nn.Module):
    def __init__(
        self,   
        config,
        dx: Tensor, dt: Tensor,
    ):
        super().__init__()
        assert config.diffusion in ['ve', 'vp'], 'diffusion noise scheduler not implemented'

        if config.loss == 'naive':
            self.loss = config.loss
        elif config.loss.startswith('pde'):
            self.loss = 'pde'
            self.loss_weight = float(config.loss.split('_')[-1])
        else:
            raise NotImplementedError
            
        self.steps = 1000
        self.config = config
        self.device = config.device

        if config.diffusion == 've':
            self.alpha_t = torch.ones(size=(self.steps, ), device=self.device)
            self.sigma_t = torch.sigmoid(torch.linspace(-5, 5, self.steps, device=self.device))
            self.sigma_t = 5 * (self.sigma_t - self.sigma_t.min() + 1e-5)
            self.lambda_t = (self.alpha_t / self.sigma_t).log()
            self.diffusion_weight = self.sigma_t.square()
        
        else:
            self.alpha_org = 1 - torch.linspace(1e-4, 1e-2, self.steps, device=self.device)
            self.alpha_t = self.alpha_org.cumprod(dim=0).sqrt() 
            self.sigma_t = (1 - self.alpha_t.square()).sqrt()
            self.lambda_t = (self.alpha_t / self.sigma_t).log()
            self.diffusion_weight = (-2 * self.alpha_t.log()).diff() * self.steps
            self.diffusion_weight = torch.concatenate([self.diffusion_weight, self.diffusion_weight[-1].reshape(-1)])

        self.sampling_cfg = config.sampling

        backbone = config.model.lower()

        if backbone == 'gru':
            from models.GRU.gru import GRU
            self.network = GRU

        else:
            raise NotImplementedError('backbone not found')

        self.network = self.network(
            n_diff_time=self.steps,
            **config.network
        ).to(self.device)

        print(f"{self.network._get_name()} #Params: {sum(p.numel() for p in self.network.parameters())}")

        self.dx, self.dt = dx.to(self.device), dt.to(self.device)
        
        assert config.model_out in ['noise', 'x0']
        self.model_out = config.model_out

    def diff_forward(self, x0: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        extend_dim = [-1] + [1] * (x0.ndim - 1)
        t = torch.randint(0, self.steps, (x0.size(0), ), device=self.device)

        noise = torch.randn_like(x0)
        alpha = self.alpha_t[t].reshape(extend_dim)
        sigma = self.sigma_t[t].reshape(extend_dim)
        return alpha * x0 + sigma * noise, t, noise

    def get_loss(self, x0: Tensor, grid: Tensor) -> Tuple[dict, int]:
        # x0.shape == (batch_size, input_size, n_time, 1)
        loss_dict = {}
        xt, t, eps_noise = self.diff_forward(x0)
        
        if self.config.diffusion == 've':
            t = torch.ones_like(t)
        eps_model: Tensor = self.network(xt, t, grid[0])
        assert eps_noise.shape == eps_model.shape

        batch_size = xt.size(0)
        loss_w = self.diffusion_weight[t].reshape(-1, 1)
        
        if self.model_out == 'noise':
            score_loss = (eps_noise - eps_model).square().reshape(batch_size, -1)
            conditon_x0_xt = xt - self.sigma_t[t].reshape(-1, *[1] * (xt.ndim - 1)) * eps_model
            conditon_x0_xt = conditon_x0_xt / self.alpha_t[t].reshape(-1, *[1] * (xt.ndim - 1))
            de_loss = adv_1d_loss(conditon_x0_xt, self.dx, self.dt)
        else:
            score_loss = (x0 - eps_model).square().reshape(batch_size, -1)
            de_loss = adv_1d_loss(eps_model, self.dx, self.dt)

            loss_dict['pde_de'] = (loss_w * de_loss).mean() 

        loss_dict['pde_de'] = (loss_w * de_loss).mean() 
        loss_dict['score'] = (loss_w * score_loss).mean()

        if self.loss == 'naive':
            loss = score_loss
        elif self.loss == 'pde':
            loss = torch.concatenate([
                score_loss,
                self.loss_weight * de_loss
            ], dim=-1)

        else:
            raise NotImplementedError()

        loss_dict['total'] = (loss_w * loss).mean()
        return loss_dict, loss.size(0)

    @torch.no_grad()
    def network_predict(self, xt: Tensor, t: Tensor, grid: Tensor) -> Tensor:
        model_predict = self.network(xt, t, grid)
        if self.model_out == 'noise':
            return model_predict
        else:
            alpha, sigma = self.alpha_t[t].reshape(-1, *[1] * (xt.ndim - 1)), self.sigma_t[t].reshape(-1, *[1] * (xt.ndim - 1))
            return (xt - alpha * model_predict) / sigma
        
    @torch.no_grad()
    def ode_reverse(self, s: Tensor, t: Tensor, xt: Tensor, grid: Tensor) -> Tensor:
        eps_model = self.network_predict(xt, s.repeat(xt.size(0)), grid)

        model_out = self.sigma_t[t] * torch.expm1(self.lambda_t[t] - self.lambda_t[s]) * eps_model
        return self.alpha_t[t] / self.alpha_t[s] * xt - model_out
    
    @torch.no_grad()
    def ddpm_reverse(self, s: Tensor, t: Tensor, xt: Tensor, grid: Tensor) -> Tensor:
        eps_model = self.network_predict(xt, s.repeat(xt.size(0)), grid)
        eps_model = (1 - self.alpha_org[s]) / (1 - self.alpha_t[s]).sqrt() * eps_model

        mu = (xt - eps_model) / self.alpha_org[s].sqrt()
        sigma2 = (1 - self.alpha_t[t]) / (1 - self.alpha_t[s]) * (1 - self.alpha_org[s])
        # sigma2 = 1 - self.alpha_org[s]
        return mu + sigma2.sqrt() * torch.randn_like(mu)

    @torch.no_grad()
    def dpm2_reverse(self, s: Tensor, t: Tensor, xt: Tensor, grid: Tensor) -> Tensor:
        mid_t = torch.argmin(((self.lambda_t[s] + self.lambda_t[t]) / 2 - self.lambda_t).abs())
        h = self.lambda_t[t] - self.lambda_t[s]

        u = self.alpha_t[mid_t] / self.alpha_t[s] * xt \
            - self.sigma_t[mid_t] * torch.expm1(h/2) * self.network_predict(xt, s.repeat(xt.size(0)), grid)
        
        return self.alpha_t[t] / self.alpha_t[s] * xt \
            - self.sigma_t[t] * torch.expm1(h) * self.network_predict(u, mid_t.repeat(xt.size(0)), grid)

    @torch.no_grad()
    def dpm3_reverse(self, s: Tensor, t: Tensor, xt: Tensor, grid: Tensor) -> Tensor:
        r1, r2, h = 1/3, 2/3, self.lambda_t[t] - self.lambda_t[s]
        mid_1 = torch.argmin((self.lambda_t[s] + r1 * h - self.lambda_t).abs())
        mid_2 = torch.argmin((self.lambda_t[s] + r2 * h - self.lambda_t).abs())

        eps_s = self.network_predict(xt, s.repeat(xt.size(0)), grid)
        u1 = self.alpha_t[mid_1] / self.alpha_t[s] * xt \
            - self.sigma_t[mid_1] * torch.expm1(r1 * h) * eps_s
        
        d1 = self.network_predict(u1, mid_1.repeat(xt.size(0)), grid) - eps_s

        u2 = self.alpha_t[mid_2] / self.alpha_t[s] * xt \
            - self.sigma_t[mid_2] * torch.expm1(r2 * h) * eps_s \
            - self.sigma_t[mid_2] * r2 / r1 * (torch.expm1(r2 * h) / (r2 * h) - 1) * d1
        
        d2 = self.network_predict(u2, mid_2.repeat(xt.size(0)), grid) - eps_s

        return self.alpha_t[t] / self.alpha_t[s] * xt \
            - self.sigma_t[t] * torch.expm1(h) * eps_s \
            - self.sigma_t[t] / r2 * (torch.expm1(h) / h - 1) * d2
    
    @torch.no_grad()
    def ld_reverse(self, s: Tensor, t: Tensor, xt: Tensor, grid: Tensor) -> Tensor:
        sigma_t = self.sigma_t[s]
        eps_model = self.network_predict(
            xt,
            torch.ones(size=(xt.size(0), ), device=xt.device).to(torch.long),
            grid
        )
        step_size = 0.001 * sigma_t.square()
        xt = xt + step_size * eps_model / sigma_t + torch.sqrt(step_size*2) * torch.rand_like(xt)
        return xt
    
    @torch.no_grad()
    def sampling(self, data_shape: Tuple, grid: Tensor, method: str | None = None) -> Tensor:
        if method is None and self.config.diffusion == 'vp':
            method = self.sampling_cfg.method

        if self.config.diffusion == 've':
            method = 'ld'
            itertor = self.ld_reverse
            timestamps = torch.arange(0, self.steps)
        else:
            if method == 'ddpm':
                itertor = self.ddpm_reverse
                timestamps = torch.arange(-1, self.steps)
            elif method == 'ode':
                itertor = self.ode_reverse
                timestamps = torch.arange(0, self.steps)
            elif method == 'dpm3':
                itertor = self.dpm3_reverse
                timestamps = torch.linspace(0, self.steps - 1, 30)
            else:
                raise NotImplementedError('sampling method not found')
        
        timestamps = timestamps.__reversed__().int().to(self.device)

        if self.config.diffusion == 'vp':
            xt = torch.randn(size=data_shape, device=self.device)
        else:
            xt = self.sigma_t[-1] * torch.randn(size=data_shape, device=self.device)

        pbar = tqdm(zip(timestamps[:-1], timestamps[1:]), leave=False, total=len(timestamps)-1, dynamic_ncols=True)
        for s, t in pbar:
            xt = itertor(s, t, xt, grid)
            
        return xt



        