import torch
import torch.nn as nn

from torch import Tensor

from tqdm import tqdm
from typing import Tuple


def darcy_loss(pred: Tensor, condition: Tensor, reciprocal_dx: Tensor) -> Tensor:
    beta = 1.0

    du_x_ = torch.diff(pred, dim=-2) * reciprocal_dx
    du_y_ = torch.diff(pred, dim=-1) * reciprocal_dx
    ax_ = (condition[:, :-1] + condition[:, 1:]) / 2
    ay_ = (condition[:, :, 1:] + condition[:, :, :-1]) / 2

    Df = - (
        (ax_ * du_x_)[:, 1:, 1:-1] - (ax_ * du_x_)[:, :-1, 1:-1] \
        + (ay_ * du_y_)[:, 1:-1, 1:] - (ay_ * du_y_)[:, 1:-1, :-1]
    ) * reciprocal_dx - beta
    return Df.square().log1p()


class Diffusion(nn.Module):
    def __init__(
        self,   
        config,
    ):
        super().__init__()
        assert config.diffusion in ['ve', 'vp'], 'diffusion noise scheduler not implemented'

        if config.loss == 'naive':
            self.loss = config.loss
        elif config.loss.startswith('pde'):
            self.loss = 'pde'
            self.loss_weight = float(config.loss.split('_')[-1])
        else:
            raise NotImplementedError
            
        self.steps = 1000
        self.config = config
        self.device = config.device

        if config.diffusion == 've':
            self.alpha_t = torch.ones(size=(self.steps, ), device=self.device)
            self.sigma_t = torch.sigmoid(torch.linspace(-5, 5, self.steps, device=self.device))
            self.sigma_t = 5 * (self.sigma_t - self.sigma_t.min() + 1e-5)
            self.lambda_t = (self.alpha_t / self.sigma_t).log()
            self.diffusion_weight = self.sigma_t.square()
        
        else:
            self.alpha_org = 1 - torch.linspace(1e-4, 1e-2, self.steps, device=self.device)
            self.alpha_t = self.alpha_org.cumprod(dim=0).sqrt() 
            self.sigma_t = (1 - self.alpha_t.square()).sqrt()
            self.lambda_t = (self.alpha_t / self.sigma_t).log()
            self.diffusion_weight = (-2 * self.alpha_t.log()).diff() * self.steps
            self.diffusion_weight = torch.concatenate([self.diffusion_weight, self.diffusion_weight[-1].reshape(-1)])

        self.sampling_cfg = config.sampling

        backbone = config.model.lower()

        if backbone == 'karras':
            from models.KarrasUnet.unet import Unet
            self.network = Unet
        elif backbone == 'unet':
            from models.Unet.unet import Unet
            self.network = Unet
        else:
            raise NotImplementedError('backbone not found')

        self.network = self.network(
            n_diff_time=self.steps,
            **config.network
        ).to(self.device)

        assert config.model_out in ['noise', 'x0']
        self.model_out = config.model_out
        
        print(f"{self.network._get_name()} #Params: {sum(p.numel() for p in self.network.parameters())}")

    def diff_forward(self, x0: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        extend_dim = [-1] + [1] * (x0.ndim - 1)
        t = torch.randint(0, self.steps, (x0.size(0), ), device=self.device)

        noise = torch.randn_like(x0)
        alpha = self.alpha_t[t].reshape(extend_dim)
        sigma = self.sigma_t[t].reshape(extend_dim)
        return alpha * x0 + sigma * noise, t, noise

    def get_loss(self, x0: Tensor, condition: Tensor) -> Tuple[dict, int]:
        # x0.shape == (batch_size, input_size, input_size)
        # condition.shape == (batch_size, input_size, input_size)
        loss_dict = {}
        xt, t, eps_noise = self.diff_forward(x0)
        

        if self.config.diffusion == 've':
            t = torch.ones_like(t)
        model_out: Tensor = self.network(xt, t, condition)

        assert eps_noise.shape == model_out.shape

        batch_size = xt.size(0)
        loss_w = self.diffusion_weight[t].reshape(-1, 1)

        rep_dx = xt.size(1)
        if self.model_out == 'noise':
            score_loss = (eps_noise - model_out).square().reshape(batch_size, -1)
            conditon_x0_xt = xt - self.sigma_t[t].reshape(-1, *[1] * (xt.ndim - 1)) * model_out
            conditon_x0_xt = conditon_x0_xt / self.alpha_t[t].reshape(-1, *[1] * (xt.ndim - 1))
            de_loss = darcy_loss(conditon_x0_xt, condition, rep_dx).reshape(batch_size, -1)
        else:
            score_loss = (x0 - model_out).square().reshape(batch_size, -1)
            de_loss = darcy_loss(model_out, condition, rep_dx).reshape(batch_size, -1)

        de_loss = de_loss / rep_dx
        loss_dict['pde_de'] = (loss_w * de_loss).mean() 
        loss_dict['score'] = (loss_w * score_loss).mean()

        # print(loss_dict['pde_de'], loss_dict['score'])

        if self.loss == 'naive':
            loss = score_loss
        elif self.loss == 'pde':
            loss = torch.concatenate([
                score_loss,
                self.loss_weight * de_loss
            ], dim=-1)

        else:
            raise NotImplementedError()

        loss_dict['total'] = (loss_w * loss).mean()
        return loss_dict, loss.size(0)

    @torch.no_grad()
    def network_predict(self, xt: Tensor, t: Tensor, condition: Tensor) -> Tensor:
        model_predict = self.network(xt, t, condition)
        if self.model_out == 'noise':
            return model_predict
        else:
            alpha, sigma = self.alpha_t[t].reshape(-1, *[1] * (xt.ndim - 1)), self.sigma_t[t].reshape(-1, *[1] * (xt.ndim - 1))
            return (xt - alpha * model_predict) / sigma
        
    @torch.no_grad()
    def ode_reverse(self, s: Tensor, t: Tensor, xt: Tensor, condition: Tensor) -> Tensor:
        eps_model = self.network_predict(xt, s.repeat(xt.size(0)), condition)

        model_out = self.sigma_t[t] * torch.expm1(self.lambda_t[t] - self.lambda_t[s]) * eps_model
        return self.alpha_t[t] / self.alpha_t[s] * xt - model_out
    
    @torch.no_grad()
    def ddpm_reverse(self, s: Tensor, t: Tensor, xt: Tensor, condition: Tensor) -> Tensor:
        eps_model = self.network_predict(xt, s.repeat(xt.size(0)), condition)
        eps_model = (1 - self.alpha_org[s]) / (1 - self.alpha_t[s]).sqrt() * eps_model

        mu = (xt - eps_model) / self.alpha_org[s].sqrt()
        sigma2 = (1 - self.alpha_t[t]) / (1 - self.alpha_t[s]) * (1 - self.alpha_org[s])
        # sigma2 = 1 - self.alpha_org[s]
        return mu + sigma2.sqrt() * torch.randn_like(mu)

    @torch.no_grad()
    def dpm2_reverse(self, s: Tensor, t: Tensor, xt: Tensor, condition: Tensor) -> Tensor:
        mid_t = torch.argmin(((self.lambda_t[s] + self.lambda_t[t]) / 2 - self.lambda_t).abs())
        h = self.lambda_t[t] - self.lambda_t[s]

        u = self.alpha_t[mid_t] / self.alpha_t[s] * xt \
            - self.sigma_t[mid_t] * torch.expm1(h/2) * self.network_predict(xt, s.repeat(xt.size(0)), condition)
        
        return self.alpha_t[t] / self.alpha_t[s] * xt \
            - self.sigma_t[t] * torch.expm1(h) * self.network_predict(u, mid_t.repeat(xt.size(0)), condition)

    @torch.no_grad()
    def dpm3_reverse(self, s: Tensor, t: Tensor, xt: Tensor, condition: Tensor) -> Tensor:
        r1, r2, h = 1/3, 2/3, self.lambda_t[t] - self.lambda_t[s]
        mid_1 = torch.argmin((self.lambda_t[s] + r1 * h - self.lambda_t).abs())
        mid_2 = torch.argmin((self.lambda_t[s] + r2 * h - self.lambda_t).abs())

        eps_s = self.network_predict(xt, s.repeat(xt.size(0)), condition)
        u1 = self.alpha_t[mid_1] / self.alpha_t[s] * xt \
            - self.sigma_t[mid_1] * torch.expm1(r1 * h) * eps_s
        
        d1 = self.network_predict(u1, mid_1.repeat(xt.size(0)), condition) - eps_s

        u2 = self.alpha_t[mid_2] / self.alpha_t[s] * xt \
            - self.sigma_t[mid_2] * torch.expm1(r2 * h) * eps_s \
            - self.sigma_t[mid_2] * r2 / r1 * (torch.expm1(r2 * h) / (r2 * h) - 1) * d1
        
        d2 = self.network_predict(u2, mid_2.repeat(xt.size(0)), condition) - eps_s

        return self.alpha_t[t] / self.alpha_t[s] * xt \
            - self.sigma_t[t] * torch.expm1(h) * eps_s \
            - self.sigma_t[t] / r2 * (torch.expm1(h) / h - 1) * d2
    
    @torch.no_grad()
    def ld_reverse(self, s: Tensor, t: Tensor, xt: Tensor, condition: Tensor) -> Tensor:
        sigma_t = self.sigma_t[s]
        eps_model = self.network_predict(
            xt,
            torch.ones(size=(xt.size(0), ), device=xt.device).to(torch.long),
            condition
        )
        step_size = 0.001 * sigma_t.square()
        xt = xt + step_size * eps_model / sigma_t + torch.sqrt(step_size*2) * torch.rand_like(xt)
        return xt
    
    def sampling(self, data_shape: Tuple, condition: Tensor, method: str | None = None) -> Tensor:
        if method is None and self.config.diffusion == 'vp':
            method = self.sampling_cfg.method

        if self.config.diffusion == 've':
            method = 'ld'
            itertor = self.ld_reverse
            timestamps = torch.arange(0, self.steps)
        else:
            if method == 'ddpm':
                itertor = self.ddpm_reverse
                timestamps = torch.arange(-1, self.steps)
            elif method == 'ode':
                itertor = self.ode_reverse
                timestamps = torch.arange(0, self.steps)
            elif method == 'dpm3':
                itertor = self.dpm3_reverse
                timestamps = torch.linspace(0, self.steps - 1, 30)
            else:
                raise NotImplementedError('sampling method not found')
        
        timestamps = timestamps.__reversed__().int().to(self.device)

        if self.config.diffusion == 'vp':
            xt = torch.randn(size=data_shape, device=self.device)
        else:
            xt = self.sigma_t[-1] * torch.randn(size=data_shape, device=self.device)

        pbar = tqdm(zip(timestamps[:-1], timestamps[1:]), leave=False, total=len(timestamps)-1, dynamic_ncols=True)
        for s, t in pbar:
            xt = itertor(s, t, xt, condition)
            
        return xt





        