import torch
import torch.nn as nn

from torch import Tensor

from tqdm import tqdm
from typing import Tuple

def adv_1d_loss(
    pred: Tensor, u: Tensor, dx: Tensor, dt: Tensor, beta: float = 0.1
) -> Tuple[Tensor, Tensor]:
    """Initial-condition and PDE residuals for ``u_t + beta * u_x = 0``.

    Axis 0 is the batch, axis 1 is differenced with ``dx`` (space) and axis 2
    is differenced with ``dt`` (time). Derivatives are second-order central
    differences, so the PDE residual is only evaluated on interior points
    (one point is dropped at each end of axes 1 and 2).

    Args:
        pred: Predicted field; needs at least 3 points along axes 1 and 2.
        u: Reference field; only its first time slice ``u[:, :, 0]`` is read.
        dx: Grid spacing along axis 1.
        dt: Grid spacing along axis 2.
        beta: Advection speed. Defaults to ``0.1``, the value that used to be
            hard-coded here.

    Returns:
        ``(ic_loss, de_loss)``: squared initial-condition error and squared
        PDE residual, each flattened to ``(batch_size, -1)`` and left
        un-reduced so that callers can apply per-sample weights.
    """
    batch_size = pred.size(0)

    # Squared mismatch on the first time slice.
    ic_loss = (pred[:, :, 0] - u[:, :, 0]).square().reshape(batch_size, -1)

    # Central differences: du_t loses the two ends of axis 2, du_x of axis 1.
    du_t = (pred[:, :, 2:] - pred[:, :, :-2]) / (2 * dt)
    du_x = (pred[:, 2:] - pred[:, :-2]) / (2 * dx)

    # Crop each derivative along the *other* axis so both cover the interior.
    residual = du_t[:, 1:-1] + beta * du_x[:, :, 1:-1]
    de_loss = residual.square().reshape(batch_size, -1)
    return ic_loss, de_loss


class Diffusion(nn.Module):
    """Conditional diffusion model with optional PDE-residual training losses.

    The module owns the noise schedule (``alpha_t``, ``sigma_t``,
    ``lambda_t = log(alpha_t / sigma_t)``), the per-step loss weights
    (``diffusion_weight``) and the denoising backbone (``network``). It
    provides the training loss (:meth:`get_loss`) and several reverse-process
    samplers (:meth:`sampling`, :meth:`sampling_on_xt_vp_x0`).

    Two schedules are available, selected by ``config.diffusion``:

    * ``'ve'``: ``alpha_t == 1``; ``sigma_t`` is a shifted and scaled sigmoid.
    * ``'vp'``: ``alpha_t = sqrt(cumprod(alpha_org))`` where
      ``alpha_org = 1 - linspace(1e-4, 1e-2, steps)`` and
      ``sigma_t = sqrt(1 - alpha_t ** 2)``.
    """

    def __init__(
        self,
        config,
        dx: Tensor, dt: Tensor,
    ):
        """Build the noise schedule and instantiate the backbone.

        Args:
            config: Configuration object. The attributes read here are
                ``diffusion`` (``'ve'`` or ``'vp'``), ``loss`` (``'naive'`` or
                ``'pde_<w_ic>_<w_de>'``), ``device``, ``sampling``, ``model``
                (``'gru'`` or ``'paragru'``, case-insensitive), ``network``
                (keyword arguments for the backbone) and ``model_out``
                (``'noise'`` or ``'x0'``).
            dx: Spatial grid spacing forwarded to the PDE loss.
            dt: Temporal grid spacing forwarded to the PDE loss.

        Raises:
            AssertionError: Unknown ``config.diffusion`` or ``config.model_out``.
            NotImplementedError: Unknown ``config.loss`` or ``config.model``.
        """
        super().__init__()
        assert config.diffusion in ['ve', 'vp'], 'diffusion noise scheduler not implemented'

        if config.loss == 'naive':
            self.loss = config.loss
        elif config.loss.startswith('pde'):
            # 'pde_<w_ic>_<w_de>': the numbers after 'pde' are the PDE loss weights.
            self.loss = 'pde'
            self.loss_weight = list(map(float, config.loss.split('_')[1:]))
        else:
            raise NotImplementedError

        self.steps = 1000
        self.config = config
        self.device = config.device

        if config.diffusion == 've':
            self.alpha_t = torch.ones(size=(self.steps, ), device=self.device)
            self.sigma_t = torch.sigmoid(torch.linspace(-5, 5, self.steps, device=self.device))
            # Shift so the smallest sigma is 5e-5 (strictly positive), then scale by 5.
            self.sigma_t = 5 * (self.sigma_t - self.sigma_t.min() + 1e-5)
            self.lambda_t = (self.alpha_t / self.sigma_t).log()
            self.diffusion_weight = self.sigma_t.square()
        else:
            self.alpha_org = 1 - torch.linspace(1e-4, 1e-2, self.steps, device=self.device)
            # alpha_t is the *square root* of the cumulative product.
            self.alpha_t = self.alpha_org.cumprod(dim=0).sqrt()
            self.sigma_t = (1 - self.alpha_t.square()).sqrt()
            self.lambda_t = (self.alpha_t / self.sigma_t).log()
            # Finite difference of -2 * log(alpha_t), scaled by the step count;
            # the last value is repeated so there is one weight per step.
            weight = (-2 * self.alpha_t.log()).diff() * self.steps
            self.diffusion_weight = torch.cat([weight, weight[-1].reshape(-1)])

        self.sampling_cfg = config.sampling

        backbone = config.model.lower()

        if backbone == 'gru':
            from models.GRU.gru import GRU
            self.network = GRU
        elif backbone == 'paragru':
            from models.ParaGRU.paragru import ParaGRU
            self.network = ParaGRU
        # elif backbone == 'deeponet':
        #     from models.DeepONet.deeponet import DeepONetCartesianProd1D
        #     self.network = DeepONetCartesianProd1D
        else:
            raise NotImplementedError('backbone not found')

        self.network = self.network(
            n_diff_time=self.steps,
            **config.network
        ).to(self.device)

        print(f"{self.network._get_name()} #Params: {sum(p.numel() for p in self.network.parameters())}")

        self.dx, self.dt = dx.to(self.device), dt.to(self.device)

        assert config.model_out in ['noise', 'x0']
        self.model_out = config.model_out

    @staticmethod
    def _schedule_at(table: Tensor, t: Tensor, ndim: int) -> Tensor:
        """Look up ``table[t]`` and reshape it to broadcast over an ``ndim``-D batch."""
        return table[t].reshape(-1, *[1] * (ndim - 1))

    def diff_forward(self, x0: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Noise ``x0`` at a random step drawn independently per sample.

        Returns:
            ``(xt, t, noise)`` where ``xt = alpha_t[t] * x0 + sigma_t[t] * noise``,
            ``t`` holds the sampled step indices (one per batch element) and
            ``noise`` is the standard-normal draw that was used.
        """
        t = torch.randint(0, self.steps, (x0.size(0), ), device=self.device)
        noise = torch.randn_like(x0)

        alpha = self._schedule_at(self.alpha_t, t, x0.ndim)
        sigma = self._schedule_at(self.sigma_t, t, x0.ndim)
        return alpha * x0 + sigma * noise, t, noise

    def get_loss(self, x0: Tensor, condition: Tensor, grid: Tensor) -> Tuple[dict, int]:
        """Compute the weighted training losses for one batch.

        Args:
            x0: Clean samples; per the original note,
                ``x0.shape == (batch_size, input_size, n_time, 1)``.
            condition: Conditioning input; per the original note,
                ``condition.shape == (batch_size, input_size, 1, 1)``.
            grid: Only ``grid[0]`` is forwarded to the network.

        Returns:
            ``(loss_dict, batch_size)``. ``loss_dict`` holds the scalar entries
            ``'pde_ic'``, ``'pde_de'``, ``'score'`` and ``'total'``, each a mean
            of per-element losses scaled by ``diffusion_weight`` at the sampled
            step.
        """
        loss_dict = {}
        xt, t, eps_noise = self.diff_forward(x0)

        # In 've' mode the network is always queried with time index 1. Keep
        # the sampled ``t`` for every schedule lookup below: ``xt`` was built
        # with ``sigma_t[t]``, so the loss weight and the x0 reconstruction
        # must use that same step rather than index 1.
        net_t = torch.ones_like(t) if self.config.diffusion == 've' else t
        eps_model: Tensor = self.network(xt, net_t, condition, grid[0])
        assert eps_noise.shape == eps_model.shape

        batch_size = xt.size(0)
        loss_w = self.diffusion_weight[t].reshape(-1, 1)

        if self.model_out == 'noise':
            score_loss = (eps_noise - eps_model).square().reshape(batch_size, -1)
            # Invert the forward process to get the x0 implied by the predicted noise.
            sigma = self._schedule_at(self.sigma_t, t, xt.ndim)
            alpha = self._schedule_at(self.alpha_t, t, xt.ndim)
            x0_pred = (xt - sigma * eps_model) / alpha
        else:
            # The network output is itself the x0 prediction.
            score_loss = (x0 - eps_model).square().reshape(batch_size, -1)
            x0_pred = eps_model

        ic_loss, de_loss = adv_1d_loss(x0_pred, x0, self.dx, self.dt)

        loss_dict['pde_ic'] = (loss_w * ic_loss).mean()
        loss_dict['pde_de'] = (loss_w * de_loss).mean()
        loss_dict['score'] = (loss_w * score_loss).mean()

        if self.loss == 'naive':
            loss = score_loss
        elif self.loss == 'pde':
            # Concatenate per-element terms so the final mean runs over all of them.
            loss = torch.cat([
                score_loss,
                self.loss_weight[0] * ic_loss,
                self.loss_weight[1] * de_loss
            ], dim=-1)
        else:
            raise NotImplementedError()

        loss_dict['total'] = (loss_w * loss).mean()
        return loss_dict, loss.size(0)

    @torch.no_grad()
    def network_predict(self, xt: Tensor, t: Tensor, condition: Tensor, grid: Tensor) -> Tensor:
        """Run the backbone and return a noise prediction in both output modes.

        With ``model_out == 'noise'`` the raw output is returned. Otherwise the
        output is treated as an x0 prediction and converted with
        ``(xt - alpha_t[t] * x0) / sigma_t[t]``.
        """
        model_predict = self.network(xt, t, condition, grid)
        if self.model_out == 'noise':
            return model_predict

        alpha = self._schedule_at(self.alpha_t, t, xt.ndim)
        sigma = self._schedule_at(self.sigma_t, t, xt.ndim)
        return (xt - alpha * model_predict) / sigma

    @torch.no_grad()
    def ode_reverse(self, s: Tensor, t: Tensor, xt: Tensor, condition: Tensor, grid: Tensor) -> Tensor:
        """One first-order exponential-integrator step from index ``s`` to ``t``."""
        eps_model = self.network_predict(xt, s.repeat(xt.size(0)), condition, grid)

        h = self.lambda_t[t] - self.lambda_t[s]
        return self.alpha_t[t] / self.alpha_t[s] * xt - self.sigma_t[t] * torch.expm1(h) * eps_model

    @torch.no_grad()
    def ddpm_reverse(self, s: Tensor, t: Tensor, xt: Tensor, condition: Tensor, grid: Tensor) -> Tensor:
        """One ancestral (DDPM) step from index ``s`` to ``t``.

        ``t < 0`` marks the final step onto the clean sample: the posterior
        mean is returned without added noise. (A negative ``t`` must not be
        used as an index, since it would wrap around to the noisiest entry.)
        """
        eps_model = self.network_predict(xt, s.repeat(xt.size(0)), condition, grid)

        beta_s = 1 - self.alpha_org[s]
        # ``alpha_t`` stores sqrt(alpha_bar), so sqrt(1 - alpha_bar_s) is
        # ``sigma_t[s]`` and NOT ``(1 - alpha_t[s]).sqrt()``.
        mu = (xt - beta_s / self.sigma_t[s] * eps_model) / self.alpha_org[s].sqrt()

        if t < 0:
            return mu

        # Posterior variance: (1 - alpha_bar_t) / (1 - alpha_bar_s) * beta_s.
        sigma2 = self.sigma_t[t].square() / self.sigma_t[s].square() * beta_s
        return mu + sigma2.sqrt() * torch.randn_like(mu)

    @torch.no_grad()
    def dpm2_reverse(self, s: Tensor, t: Tensor, xt: Tensor, condition: Tensor, grid: Tensor) -> Tensor:
        """One second-order step from ``s`` to ``t`` via a midpoint evaluation.

        The midpoint is the schedule index whose ``lambda_t`` is closest to the
        average of ``lambda_t[s]`` and ``lambda_t[t]``.
        """
        lambda_mid = (self.lambda_t[s] + self.lambda_t[t]) / 2
        mid_t = torch.argmin((lambda_mid - self.lambda_t).abs())
        h = self.lambda_t[t] - self.lambda_t[s]

        eps_s = self.network_predict(xt, s.repeat(xt.size(0)), condition, grid)
        u = self.alpha_t[mid_t] / self.alpha_t[s] * xt \
            - self.sigma_t[mid_t] * torch.expm1(h / 2) * eps_s

        eps_mid = self.network_predict(u, mid_t.repeat(xt.size(0)), condition, grid)
        return self.alpha_t[t] / self.alpha_t[s] * xt \
            - self.sigma_t[t] * torch.expm1(h) * eps_mid

    @torch.no_grad()
    def dpm3_reverse(self, s: Tensor, t: Tensor, xt: Tensor, condition: Tensor, grid: Tensor) -> Tensor:
        """One third-order step from ``s`` to ``t`` with two intermediate evaluations.

        The intermediate indices are those whose ``lambda_t`` is closest to
        ``lambda_t[s] + r * h`` for ``r = 1/3`` and ``r = 2/3``.
        """
        r1, r2, h = 1/3, 2/3, self.lambda_t[t] - self.lambda_t[s]
        mid_1 = torch.argmin((self.lambda_t[s] + r1 * h - self.lambda_t).abs())
        mid_2 = torch.argmin((self.lambda_t[s] + r2 * h - self.lambda_t).abs())
        batch = xt.size(0)

        eps_s = self.network_predict(xt, s.repeat(batch), condition, grid)
        u1 = self.alpha_t[mid_1] / self.alpha_t[s] * xt \
            - self.sigma_t[mid_1] * torch.expm1(r1 * h) * eps_s

        d1 = self.network_predict(u1, mid_1.repeat(batch), condition, grid) - eps_s

        u2 = self.alpha_t[mid_2] / self.alpha_t[s] * xt \
            - self.sigma_t[mid_2] * torch.expm1(r2 * h) * eps_s \
            - self.sigma_t[mid_2] * r2 / r1 * (torch.expm1(r2 * h) / (r2 * h) - 1) * d1

        d2 = self.network_predict(u2, mid_2.repeat(batch), condition, grid) - eps_s

        return self.alpha_t[t] / self.alpha_t[s] * xt \
            - self.sigma_t[t] * torch.expm1(h) * eps_s \
            - self.sigma_t[t] / r2 * (torch.expm1(h) / h - 1) * d2

    @torch.no_grad()
    def ld_reverse(self, s: Tensor, t: Tensor, xt: Tensor, condition: Tensor, grid: Tensor) -> Tensor:
        """One Langevin-style update at noise level ``sigma_t[s]`` (``t`` is unused).

        The network is queried with the constant time index 1, matching what
        :meth:`get_loss` feeds it in ``'ve'`` mode.
        """
        sigma_s = self.sigma_t[s]
        net_t = torch.ones(size=(xt.size(0), ), device=xt.device).to(torch.long)
        eps_model = self.network_predict(xt, net_t, condition, grid)

        step_size = 0.001 * sigma_s.square()
        # NOTE(review): if eps_model approximates the injected noise, the score
        # is -eps_model / sigma_s, so this drift sign looks inverted. It is kept
        # as in the original implementation; please confirm.
        # NOTE(review): with model_out == 'x0', network_predict converts using
        # the schedule at index 1 rather than at ``s``; please confirm.
        drift = step_size * eps_model / sigma_s
        # The injected noise must be zero-mean Gaussian (randn), not uniform (rand).
        return xt + drift + torch.sqrt(step_size * 2) * torch.randn_like(xt)

    @torch.no_grad()
    def sampling(self, data_shape: Tuple, condition: Tensor, grid: Tensor, method: str = None) -> Tensor:
        """Draw samples by running a reverse process from pure noise.

        Args:
            data_shape: Shape of the batch to generate.
            condition: Conditioning input forwarded to the network.
            grid: Grid input forwarded to the network.
            method: ``'ddpm'``, ``'ode'`` or ``'dpm3'`` for the ``'vp'``
                schedule; defaults to ``config.sampling.method``. Ignored for
                the ``'ve'`` schedule, which always uses :meth:`ld_reverse`.

        Returns:
            The generated batch of shape ``data_shape``.

        Raises:
            NotImplementedError: Unknown ``method`` for the ``'vp'`` schedule.
        """
        if self.config.diffusion == 've':
            step_fn = self.ld_reverse
            timestamps = torch.arange(0, self.steps)
        else:
            if method is None:
                method = self.sampling_cfg.method

            if method == 'ddpm':
                step_fn = self.ddpm_reverse
                # The extra -1 yields a last (s=0, t=-1) step onto the clean sample.
                timestamps = torch.arange(-1, self.steps)
            elif method == 'ode':
                step_fn = self.ode_reverse
                timestamps = torch.arange(0, self.steps)
            elif method == 'dpm3':
                step_fn = self.dpm3_reverse
                timestamps = torch.linspace(0, self.steps - 1, 30)
            else:
                raise NotImplementedError('sampling method not found')

        # Walk from the noisiest index down; int64 matches the dtype of the
        # step indices used during training.
        timestamps = timestamps.flip(0).long().to(self.device)

        xt = torch.randn(size=data_shape, device=self.device)
        if self.config.diffusion != 'vp':
            # The 've' prior has standard deviation equal to the largest sigma.
            xt = self.sigma_t[-1] * xt

        pbar = tqdm(zip(timestamps[:-1], timestamps[1:]), leave=False, total=len(timestamps)-1, dynamic_ncols=True)
        for s, t in pbar:
            xt = step_fn(s, t, xt, condition, grid)

        return xt

    @torch.no_grad()
    def sampling_on_xt_vp_x0(self, data_shape: Tuple, condition: Tensor, grid: Tensor) -> Tuple[Tensor, Tensor]:
        """Run the first-order sampler while recording intermediate x0 predictions.

        The raw network output is treated as an x0 prediction here, regardless
        of ``self.model_out``. Every 10th step's prediction is stored, and the
        final sample is appended last.

        Returns:
            ``(x0_states, xt)``: the stored predictions stacked and permuted so
            the batch axis comes first (the 5-axis permute implies 4-D
            samples), and the final sample.
        """
        timestamps = torch.arange(0, self.steps).flip(0).long().to(self.device)
        xt = torch.randn(size=data_shape, device=self.device)

        pbar = tqdm(zip(timestamps[:-1], timestamps[1:]), leave=False, total=len(timestamps)-1, dynamic_ncols=True)
        x0_states_on_t = []
        for i, (s, t) in enumerate(pbar):
            batch_time = s.repeat(xt.size(0))

            model_predict_x0 = self.network(xt, batch_time, condition, grid)

            if i % 10 == 0:
                x0_states_on_t.append(model_predict_x0)

            alpha = self._schedule_at(self.alpha_t, batch_time, xt.ndim)
            sigma = self._schedule_at(self.sigma_t, batch_time, xt.ndim)

            # Convert the x0 prediction to a noise prediction, then step as in ode_reverse.
            eps_model = (xt - alpha * model_predict_x0) / sigma

            model_out = self.sigma_t[t] * torch.expm1(self.lambda_t[t] - self.lambda_t[s]) * eps_model
            xt = self.alpha_t[t] / self.alpha_t[s] * xt - model_out

        x0_states_on_t.append(xt)
        return torch.stack(x0_states_on_t).permute(1, 0, 2, 3, 4), xt




        