import math
import torch
from torch import nn
import torch.nn.functional as F
from inspect import isfunction
from thop import profile

def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Build a 2-D convolution padded by half the kernel size.

    With the default stride of 1 and an odd ``kernel_size`` this padding
    keeps the spatial size of the input unchanged.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Integer side length of the square kernel.
        bias: Whether the convolution has a learnable bias.

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=half_kernel,
        bias=bias,
    )

class ResBlock(nn.Module):
    """Two-convolution residual block: ``conv -> [bn] -> act -> conv -> [bn]`` plus a skip.

    Args:
        conv: Factory called as ``conv(n_feats, n_feats, kernel_size, bias=bias)``
            that returns the convolution layer to use.
        n_feats: Channel count of both the input and the output.
        kernel_size: Kernel size forwarded to ``conv``.
        bias: Forwarded to ``conv``.
        bn: When true, a ``BatchNorm2d`` follows each convolution.
        act: Activation module placed after the first convolution. ``None``
            (the default) creates a fresh ``nn.LeakyReLU(0.1, inplace=True)``
            for this block.
        res_scale: Accepted for interface compatibility; it is not applied
            to the residual branch.
    """

    def __init__(
        self, conv, n_feats, kernel_size,
        bias=True, bn=False, act=None, res_scale=1):

        super(ResBlock, self).__init__()
        # Build the default activation per instance: a module used as a
        # default argument is created once and shared by every block.
        if act is None:
            act = nn.LeakyReLU(0.1, inplace=True)

        m = []
        for i in range(2):
            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            # Only the first convolution is followed by the activation.
            if i == 0:
                m.append(act)

        self.body = nn.Sequential(*m)

    def forward(self, x):
        """Return ``body(x) + x``."""
        return self.body(x) + x
    
class LE_arch(nn.Module):
    """Convolutional encoder that maps an image pair to one feature vector.

    Both inputs go through ``PixelUnshuffle(4)``, are concatenated along the
    channel axis, encoded by a convolutional stack ending in global average
    pooling, and finally passed through a two-layer MLP.

    Args:
        n_feats: Base channel width; the output vector has ``4 * n_feats``
            entries.
        n_encoder_res: Number of ``ResBlock`` stages in the encoder.
        bn: Forwarded to every ``ResBlock``.
    """

    def __init__(self, n_feats=64, n_encoder_res=6, bn=False):
        super().__init__()
        double_width = n_feats * 2
        quad_width = n_feats * 4

        # 96 input channels = 2 inputs x 16 (PixelUnshuffle(4)) x 3, i.e. the
        # stem expects both inputs to have 3 channels.
        layers = [
            nn.Conv2d(96, n_feats, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1, True),
        ]
        layers.extend(
            ResBlock(default_conv, n_feats, kernel_size=3, bn=bn)
            for _ in range(n_encoder_res)
        )
        # Widen to 4 * n_feats channels, then collapse the spatial dims to 1x1.
        layers.extend([
            nn.Conv2d(n_feats, double_width, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(double_width, double_width, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(double_width, quad_width, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1, True),
            nn.AdaptiveAvgPool2d(1),
        ])
        self.E = nn.Sequential(*layers)

        self.mlp = nn.Sequential(
            nn.Linear(quad_width, quad_width),
            nn.LeakyReLU(0.1, True),
            nn.Linear(quad_width, quad_width),
            nn.LeakyReLU(0.1, True),
        )

        self.pixel_unshuffle = nn.PixelUnshuffle(4)

    def forward(self, x, gt):
        """Encode ``x`` and ``gt`` jointly.

        Args:
            x: First image batch.
            gt: Second image batch; concatenated after ``x`` on the channel axis.

        Returns:
            Tensor with one ``4 * n_feats`` vector per batch element.
        """
        stacked = torch.cat(
            [self.pixel_unshuffle(x), self.pixel_unshuffle(gt)],
            dim=1,
        )
        # The pooled map is (B, C, 1, 1); drop the two singleton spatial dims.
        pooled = self.E(stacked)
        embedding = pooled.squeeze(-1).squeeze(-1)
        return self.mlp(embedding)
    
class ResMLP(nn.Module):
    """A single ``Linear(n_feats, n_feats)`` followed by ``LeakyReLU(0.1)``.

    Despite the name, ``forward`` applies no skip connection.

    Args:
        n_feats: Size of both the input and the output vectors.
    """

    def __init__(self, n_feats=512):
        super().__init__()
        stage = [
            nn.Linear(n_feats, n_feats),
            nn.LeakyReLU(0.1, True),
        ]
        self.resmlp = nn.Sequential(*stage)

    def forward(self, x):
        """Apply the linear layer and activation to ``x``."""
        return self.resmlp(x)
    
class denoise(nn.Module):
    """MLP that predicts a vector from a latent, a timestep and a condition.

    The network input is the concatenation ``[c, t / timesteps, x]``, so both
    ``x`` and ``c`` are expected to carry ``4 * n_feats`` features per sample.

    Args:
        n_feats: Base width; the hidden and output size is ``4 * n_feats``.
        n_denoise_res: Number of ``ResMLP`` stages after the input layer.
        timesteps: Divisor used to scale the timestep before it is fed in.
    """

    def __init__(self, n_feats=64, n_denoise_res=5, timesteps=5):
        super().__init__()
        self.max_period = timesteps
        width = 4 * n_feats

        # Input layer takes condition + latent (2 * width) plus one timestep scalar.
        stages = [
            nn.Linear(2 * width + 1, width),
            nn.LeakyReLU(0.1, True),
        ]
        stages += [ResMLP(width) for _ in range(n_denoise_res)]
        self.resmlp = nn.Sequential(*stages)

    def forward(self, x, t, c):
        """Run the MLP on ``[c, t / max_period, x]``.

        Args:
            x: Latent batch, concatenated last.
            t: One timestep per batch element; cast to float and scaled.
            c: Condition batch, concatenated first.

        Returns:
            Tensor with ``4 * n_feats`` features per batch element.
        """
        t_scaled = (t.float() / self.max_period).view(-1, 1)
        joined = torch.cat([c, t_scaled, x], dim=1)
        return self.resmlp(joined)

class LatentExposureDiffusion(nn.Module):
    """Iteratively refines a noised latent vector over a fixed number of steps.

    ``forward`` encodes the input with ``LE_arch``, adds Gaussian noise scaled
    by the last ``betas_bar`` entry, then walks the timesteps from
    ``total_timestamps`` down to 1. At each step two ``denoise`` networks
    predict a residual term (``lcr_model``) and a noise term (``noise_model``)
    that are combined with the ``alphas`` / ``betas_bar`` schedules.

    Args:
        total_timestamps: Number of refinement steps (must be >= 1).
        spvised_mid_out: When true, ``forward`` also returns the per-step
            predictions and the sampled noise.

    Raises:
        ValueError: If ``total_timestamps`` is smaller than 1.
    """

    def __init__(self, total_timestamps=5, spvised_mid_out=False):
        super().__init__()
        if total_timestamps < 1:
            raise ValueError(
                f"total_timestamps must be >= 1, got {total_timestamps}"
            )

        self.lcr_model = denoise(timesteps=total_timestamps)
        self.noise_model = denoise(timesteps=total_timestamps)
        self.condition_encoder = LE_arch()

        self.total_timestamps = total_timestamps
        self.spvised_mid_out = spvised_mid_out

        beta_start = 0.0
        beta_end = 0.02
        alpha_start = 1.0
        alpha_end = 2.0
        betas = torch.linspace(
            beta_start, beta_end, self.total_timestamps, dtype=torch.float32
        )
        # alphas has one more entry than betas: forward reads alphas[i] and
        # alphas[i - 1] for i in 1..total_timestamps.
        alphas = torch.linspace(
            alpha_start, alpha_end, self.total_timestamps + 1, dtype=torch.float32
        )
        betas_bar_list = self.get_beta_bar(alphas, betas)
        # Descending steps T, T-1, ..., 1. forward no longer reads this buffer;
        # it stays registered so existing checkpoints keep loading.
        time_stamps_list = torch.arange(self.total_timestamps, 0, -1)

        self.register_buffer("alphas", alphas)
        self.register_buffer("betas_bar", betas_bar_list)
        self.register_buffer("time_stamps_list", time_stamps_list)

    def get_beta_bar(self, alphas, betas):
        """Compute the accumulated noise scale for every timestep.

        For ``t`` in ``1..total_timestamps`` the entry is
        ``sqrt(sum_{i=1..t} (alphas[i-1] / alphas[t])**2 * betas[i-1])``.

        Args:
            alphas: 1-D tensor with at least ``total_timestamps + 1`` entries.
            betas: 1-D tensor with at least ``total_timestamps`` entries.

        Returns:
            1-D tensor of length ``total_timestamps``.
        """
        betas_bar_list = []
        for t in range(1, self.total_timestamps + 1):
            # Take the first t beta values, each weighted by (alpha_{i-1} / alpha_t)^2.
            weights = (alphas[:t] / alphas[t]) ** 2
            betas_bar_list.append(torch.sum(weights * betas[:t]).sqrt())
        return torch.stack(betas_bar_list).detach()

    def q_sample_d(self, img):
        """Add Gaussian noise scaled by the final ``betas_bar`` entry.

        Returns:
            Tuple ``(noised, noise)`` where ``noise`` is the sampled
            standard-normal tensor with the shape of ``img``.
        """
        noise = torch.randn_like(img)
        return img + self.betas_bar[self.total_timestamps - 1] * noise, noise

    def forward(self, blur):
        """Encode ``blur``, noise the latent and run the refinement loop.

        Args:
            blur: Input batch; passed to the condition encoder as both of
                its arguments.

        Returns:
            The refined latent. When ``spvised_mid_out`` is true, the tuple
            ``(latent, pred_lcr_list, pred_noise_list, noise)`` instead.
        """
        pred_lcr_list = []
        pred_noise_list = []
        device = self.alphas.device
        b = blur.shape[0]
        T_z = self.condition_encoder(blur, blur)
        noise_img, noise = self.q_sample_d(T_z)

        # Iterate plain Python ints so the branch and the buffer indexing
        # below need no tensor-to-scalar conversion at every step.
        for i in range(self.total_timestamps, 0, -1):
            t = torch.full((b,), i, device=device, dtype=torch.long)
            pred_noise = self.noise_model(noise_img, t, T_z)
            pred_lcr = self.lcr_model(noise_img, t, T_z)

            if self.spvised_mid_out:
                pred_lcr_list.append(pred_lcr)
                pred_noise_list.append(pred_noise)

            if i == 1:
                noise_cof = self.betas_bar[i - 1]
            else:
                beta_t_bar = self.betas_bar[i - 1]
                beta_t_minus1_bar = self.betas_bar[i - 2]
                noise_cof = (self.alphas[i] * beta_t_bar) / self.alphas[i - 1] - beta_t_minus1_bar

            noise_img = ((self.alphas[i] * noise_img - pred_lcr) / self.alphas[i - 1]) - noise_cof * pred_noise

        if self.spvised_mid_out:
            return noise_img, pred_lcr_list, pred_noise_list, noise
        return noise_img

if __name__ == '__main__':
    # Smoke test: build the model, profile one forward pass with thop, then
    # run it once more and print the output shape.
    # Use CUDA when present, otherwise fall back to CPU so the script still runs.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = LatentExposureDiffusion().to(device)

    sample = torch.randn((1, 3, 256, 256)).to(device)
    # NOTE(review): thop's README names the first return value "macs";
    # confirm that the "FLOPs" label below is the intended one.
    flops, params = profile(net, (sample,))
    print('FLOPs = ' + str(flops / 1000 ** 3) + 'G')
    print('Params = ' + str(params / 1000 ** 2) + 'M')

    out = net(sample)
    print(out.shape)