import torch
import torch.nn as nn

from torch import Tensor

from tqdm import tqdm
from torch import Tensor
from typing import Tuple

from models.Unet.unet import Unet3D

# Public names exported by a star-import of this module. Only the Diffusion
# class is listed; sw_pde_loss is not part of the star-import surface.
__all__ = [
    'Diffusion'
]

def sw_pde_loss(
        data: Tensor,
        dx: float, dy: float, dt: float,
        gdr: Tensor, coriolis: Tensor
    ):
    """Compute clamped squared finite-difference residuals for (h, u, v) fields.

    The three residuals have the form of a linearised rotating
    shallow-water system (momentum in x, momentum in y, continuity);
    NOTE(review): that physical reading is inferred from the variable
    names and the structure of the stencils -- confirm against the data
    generator.

    Args:
        data: 5-D tensor. Axis 0 is the batch, axis 1 is the axis that is
            differenced and divided by ``dt``, axis 2 must hold exactly
            three entries which are unpacked as ``(h, u, v)``, and the last
            two axes form the spatial grid (second-to-last is differenced
            with ``dy``, last with ``dx``).
        dx: Grid spacing used for differences along the last axis.
        dy: Grid spacing used for differences along the second-to-last axis.
        dt: Step used for differences along axis 1.
        gdr: 2-D tensor with one row per sample. Column 0 is used as gravity
            and column 1 as depth; any further columns are ignored here.
        coriolis: Factor multiplied with the four-point velocity averages.
            It has to broadcast against interior-grid tensors both with and
            without the last time step.

    Returns:
        Tensor of shape ``(batch, 3 * (T - 1) * (H - 2) * (W - 2))`` holding
        ``min(1e5 * residual ** 2, 1.0)`` for the three equations, flattened
        per sample.
    """
    n_batch = data.size(0)

    # One gravity and one depth value per sample, reshaped so that they
    # broadcast over (time, y, x).
    gravity, depth = torch.unbind(gdr[:, :2].permute(1, 0), dim=0)
    gravity = gravity.reshape(-1, 1, 1, 1)
    depth = depth.reshape(-1, 1, 1, 1)

    # Each field has shape (batch, time, y, x) after splitting the channel axis.
    h, u, v = torch.unbind(data.permute(2, 0, 1, 3, 4), dim=0)
    h = h + depth

    # Spatial index windows: interior points and their -1 / +1 neighbours.
    mid = slice(1, -1)
    lo = slice(None, -2)
    hi = slice(2, None)

    def interior_time_rate(field: Tensor) -> Tensor:
        """Forward difference along the time axis, restricted to the interior."""
        return torch.diff(field, dim=1)[..., mid, mid] / dt

    # Four-point averages that bring one velocity component onto the
    # location where the other component's equation is evaluated.
    v_avg = 0.25 * (v[..., mid, mid] + v[..., lo, mid] + v[..., mid, hi] + v[..., lo, hi])
    u_avg = 0.25 * (u[..., mid, mid] + u[..., mid, lo] + u[..., hi, mid] + u[..., hi, lo])

    dudt = interior_time_rate(u)
    dvdt = interior_time_rate(v)
    dhdt = interior_time_rate(h)

    # Forward differences for h, backward differences for u and v.
    dhdx = (h[..., mid, hi] - h[..., mid, mid]) / dx
    dhdy = (h[..., hi, mid] - h[..., mid, mid]) / dy
    dudx = (u[..., mid, mid] - u[..., mid, lo]) / dx
    dvdy = (v[..., mid, mid] - v[..., lo, mid]) / dy

    # The time windows differ per term: [:, :-1] pairs a term with the
    # earlier step of each time difference, [:, 1:] with the later one.
    u_forcing = (coriolis * v_avg - gravity * dhdx)[:, :-1]
    res_u = dudt - u_forcing
    res_v = dvdt + coriolis * u_avg[:, 1:] + gravity * dhdy[:, :-1]
    res_h = dhdt + depth * (dudx + dvdy)[:, 1:]

    residuals = torch.stack([res_u, res_v, res_h], dim=1)
    scaled = 1e5 * residuals.square()
    # Cap every entry so that a few large residuals cannot dominate.
    capped = torch.clamp(scaled, max=1.0)
    return capped.reshape(n_batch, -1)

class Diffusion(nn.Module):
    """Diffusion model around a ``Unet3D`` network with an optional PDE loss term.

    The network output is trained against the clean sample ``x0`` (see
    ``get_loss``); ``network_predict`` converts that output into a noise
    estimate which ``ode_reverse`` uses for one deterministic reverse step.
    """

    def __init__(
        self,
        config,
        dx, dy, dt,
        coriolis_param,
        **kwargs
    ):
        """Build the noise schedule and the denoising network.

        Args:
            config: Object providing ``diffusion`` (must be ``'vp'``),
                ``loss`` (``'naive'`` or a string starting with ``'pde'``
                whose last ``'_'``-separated token parses as the PDE loss
                weight), ``device``, ``sampling`` and ``network`` (keyword
                arguments for ``Unet3D``).
            dx, dy, dt: Grid and time spacings forwarded to ``sw_pde_loss``.
            coriolis_param: Coriolis factor forwarded to ``sw_pde_loss``.
            **kwargs: Accepted but not used.

        Raises:
            NotImplementedError: If ``config.loss`` is neither ``'naive'``
                nor prefixed with ``'pde'``.
        """
        super().__init__()
        self.dx = dx
        self.dy = dy
        self.dt = dt
        self.coriolis = coriolis_param

        assert config.diffusion in ['vp'], 'diffusion noise scheduler not implemented'

        loss_spec = config.loss
        if loss_spec == 'naive':
            self.loss = loss_spec
        elif loss_spec.startswith('pde'):
            self.loss = 'pde'
            self.loss_weight = float(loss_spec.split('_')[-1])
        else:
            raise NotImplementedError

        self.steps = 1000
        self.config = config
        self.device = config.device

        # Schedule: per-step factors 1 - beta with beta linear in [1e-4, 1e-2].
        betas = torch.linspace(1e-4, 1e-2, self.steps, device=self.device)
        self.alpha_org = 1 - betas
        # alpha_t scales the signal, sigma_t the noise; alpha_t^2 + sigma_t^2 == 1.
        self.alpha_t = torch.cumprod(self.alpha_org, dim=0).sqrt()
        self.sigma_t = (1 - self.alpha_t.square()).sqrt()
        self.lambda_t = (self.alpha_t / self.sigma_t).log()

        # Per-step loss weight: scaled step-to-step change of -2 * log(alpha_t).
        # diff() yields one entry fewer than steps, so the last is repeated.
        step_weight = (-2 * self.alpha_t.log()).diff() * self.steps
        self.diffusion_weight = torch.concatenate(
            [step_weight, step_weight[-1].reshape(-1)]
        )

        self.sampling_cfg = config.sampling

        self.network = Unet3D(**config.network).to(self.device)

        n_params = sum(p.numel() for p in self.network.parameters())
        print(f"{self.network._get_name()} #Params: {n_params}")

    def diff_forward(self, x0: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Noise ``x0`` at one uniformly drawn step per sample.

        Returns:
            ``(xt, t, noise)`` with ``xt = alpha_t[t] * x0 + sigma_t[t] * noise``,
            ``t`` of shape ``(batch,)`` and ``noise`` shaped like ``x0``.
        """
        n_batch = x0.size(0)
        # Draw the steps before the noise; the order fixes the RNG stream.
        t = torch.randint(0, self.steps, (n_batch, ), device=self.device)
        noise = torch.randn_like(x0)

        # (-1, 1, ..., 1): broadcast per-sample scalars over trailing axes.
        per_sample = (-1,) + (1,) * (x0.ndim - 1)
        alpha = self.alpha_t[t].reshape(per_sample)
        sigma = self.sigma_t[t].reshape(per_sample)

        xt = alpha * x0 + sigma * noise
        return xt, t, noise

    def get_loss(self, gdr: Tensor, x0: Tensor) -> Tuple[dict[str, Tensor], int]:
        """Compute the weighted training losses for one batch.

        Args:
            gdr: Per-sample conditioning passed to the network and to
                ``sw_pde_loss``.
            x0: Clean samples. The network output has the same shape and is
                handed to ``sw_pde_loss``, which unpacks a 5-D tensor whose
                axis 2 has three entries.

        Returns:
            ``(loss_dict, batch_size)`` where ``loss_dict`` has the keys
            ``'score'``, ``'pde'`` and ``'total'``.

        Raises:
            NotImplementedError: If ``self.loss`` is not ``'naive'`` or ``'pde'``.
        """
        xt, t, eps_noise = self.diff_forward(x0)

        model_out: Tensor = self.network(xt, t, gdr)
        assert eps_noise.shape == model_out.shape

        n_batch = xt.size(0)
        loss_w = self.diffusion_weight[t].reshape(-1, 1)

        # The network output is compared with the clean sample, not the noise.
        score_loss = (x0 - model_out).square().reshape(n_batch, -1)
        # The PDE term is always evaluated so that it can be reported.
        pde_loss = sw_pde_loss(model_out, self.dx, self.dy, self.dt, gdr, self.coriolis)

        loss_dict = {
            'score': (loss_w * score_loss).mean(),
            'pde': (loss_w * pde_loss).mean(),
        }

        if self.loss == 'naive':
            loss = score_loss
        elif self.loss == 'pde':
            # Rescale by the element-count ratio so that, after the joint
            # mean, the PDE part contributes relative to loss_weight rather
            # than to how many entries it has.
            pde_scale = self.loss_weight * score_loss.size(-1) / pde_loss.size(-1)
            loss = torch.concatenate([score_loss, pde_scale * pde_loss], dim=-1)
        else:
            raise NotImplementedError()

        loss_dict['total'] = (loss_w * loss).mean()

        return loss_dict, loss.size(0)

    @torch.no_grad()
    def network_predict(self, xt: Tensor, t: Tensor, gdr: Tensor) -> Tensor:
        """Turn the network output at steps ``t`` into a noise estimate.

        Solves ``xt = alpha * x0_hat + sigma * eps`` for ``eps`` with the
        network output as ``x0_hat``.
        """
        x0_hat = self.network(xt, t, gdr)

        per_sample = (-1,) + (1,) * (xt.ndim - 1)
        alpha = self.alpha_t[t].reshape(per_sample)
        sigma = self.sigma_t[t].reshape(per_sample)
        return (xt - alpha * x0_hat) / sigma

    @torch.no_grad()
    def ode_reverse(self, s: Tensor, t: Tensor, xt: Tensor, gdr: Tensor) -> Tensor:
        """Advance ``xt`` from step ``s`` to step ``t`` with one deterministic update.

        Computes ``alpha_t / alpha_s * xt - sigma_t * expm1(lambda_t - lambda_s) * eps``
        where ``eps`` is the noise estimate at step ``s``.

        Args:
            s: Scalar index tensor of the current step.
            t: Scalar index tensor of the target step.
            xt: Current state; axis 0 is the batch.
            gdr: Conditioning forwarded to the network.
        """
        # The network expects one step index per sample.
        batch_s = s.repeat(xt.size(0))
        eps_model = self.network_predict(xt, batch_s, gdr)

        log_ratio_gap = self.lambda_t[t] - self.lambda_t[s]
        noise_term = self.sigma_t[t] * torch.expm1(log_ratio_gap) * eps_model
        signal_scale = self.alpha_t[t] / self.alpha_t[s]
        return signal_scale * xt - noise_term

    def sampling(self, data_shape: Tuple, gdr: Tensor, **kwargs) -> Tensor:
        """Generate samples by stepping from the last schedule index down to 0.

        Args:
            data_shape: Shape of the Gaussian starting tensor.
            gdr: Conditioning forwarded to every reverse step.
            **kwargs: Accepted but not used.

        Returns:
            The state after the final update (the one that targets step 0).
        """
        # steps-1, steps-2, ..., 0 as int32 indices on the model device.
        timestamps = torch.arange(0, self.steps).flip(0).int().to(self.device)

        xt = torch.randn(size=data_shape, device=self.device)

        # Consecutive (current, next) pairs of the descending schedule.
        step_pairs = zip(timestamps[:-1], timestamps[1:])
        pbar = tqdm(
            step_pairs,
            leave=False,
            total=len(timestamps) - 1,
            dynamic_ncols=True,
        )
        for s, t in pbar:
            xt = self.ode_reverse(s, t, xt, gdr)

        return xt





        