from torch import nn
import torch
from diffusers.schedulers import DDPMScheduler
import math

def fourier_embedding(timesteps: torch.Tensor, dim, max_period=10000):
    r"""Create sinusoidal timestep embeddings.

    With ``half = dim // 2`` and ``f_j = max_period ** (-j / half)``, column
    ``j`` holds ``cos(t * f_j)`` and column ``half + j`` holds ``sin(t * f_j)``.
    When ``dim`` is odd a final all-zero column is appended so the output
    always has exactly ``dim`` columns.

    Args:
        timesteps: a 1-D Tensor of N indices, one per batch element.
            These may be fractional.
        dim (int): the dimension of the output.
        max_period (int): controls the minimum frequency of the embeddings.

    Returns:
        embedding (torch.Tensor): float32 tensor of shape ``(N, dim)``.
    """
    half = dim // 2
    # Frequencies are computed on CPU in float32, then moved to the timesteps' device.
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        device=timesteps.device
    )
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Pad with an explicit (N, 1) zero column. Slicing `embedding[:, :1]`
        # would yield an empty (N, 0) tensor when dim == 1 (half == 0).
        embedding = torch.cat([embedding, embedding.new_zeros((embedding.shape[0], 1))], dim=-1)
    return embedding

class ConvBlock_cond(nn.Module):
    """Residual 1-D conv block whose hidden features are scaled/shifted by a conditioning vector.

    Layout of the main path:
    GroupNorm -> GELU -> Conv1d -> GroupNorm -> ``h * (1 + scale) + shift`` ->
    GELU -> Dropout (optional) -> Conv1d.
    It is added to a shortcut branch: a bias-free 1x1 conv when the channel
    count changes, identity otherwise. Both 3-wide convs use circular padding,
    so the sequence length is preserved.

    Args:
        in_channels: channels of the input ``x``.
        out_channels: channels of the output.
        time_emb_dim: width of the conditioning vector ``t`` given to ``forward``.
        dropout: if truthy, apply dropout before the second convolution.
        dropout_rate: dropout probability used when ``dropout`` is truthy.
            Defaults to 0.1, the rate that was previously hard-coded.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim, dropout=False, *, dropout_rate=0.1):
        super().__init__()
        # GroupNorm with a single group normalizes over all channels and positions of a sample.
        self.norm1 = nn.GroupNorm(1, in_channels)
        self.norm2 = nn.GroupNorm(1, out_channels)

        self.activation = nn.GELU()

        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1, padding_mode='circular')
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1, padding_mode='circular')

        # Identity when dropout is disabled keeps forward() branch-free.
        self.dropout = nn.Dropout(dropout_rate) if dropout else nn.Identity()

        # 1x1 projection only needed when the residual add would mismatch channels.
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)
        else:
            self.shortcut = nn.Identity()

        # Projects the conditioning vector to per-channel (scale, shift) pairs.
        self.cond_emb = nn.Linear(time_emb_dim, out_channels * 2)

    def forward(self, x, t):
        """Apply the block to ``x`` conditioned on ``t``.

        Args:
            x: tensor of shape ``(B, in_channels, L)``.
            t: conditioning tensor. It must be 2-D ``(B, time_emb_dim)``, since
                the projection is split along dim 1 and broadcast over length.

        Returns:
            Tensor of shape ``(B, out_channels, L)``.
        """
        h = self.norm1(x)
        h = self.activation(h)
        h = self.conv1(h)
        h = self.norm2(h)
        scale, shift = torch.chunk(self.cond_emb(t), 2, dim=1)
        # "+ 1" makes a zero scale the identity modulation.
        h = h * (scale.unsqueeze(-1) + 1) + shift.unsqueeze(-1)
        h = self.activation(h)
        h = self.dropout(h)
        h = self.conv2(h)
        return h + self.shortcut(x)
    
class Down(nn.Module):
    """Downsampling stage: a strided circular conv followed by a conditioned conv block.

    The strided conv (kernel 3, stride 2, padding 1) keeps the channel count and
    maps a length ``L`` to ``ceil(L / 2)``. The ``ConvBlock_cond`` then maps
    ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()
        self.down = nn.Conv1d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            padding_mode='circular',
        )
        self.conv = ConvBlock_cond(in_channels, out_channels, time_emb_dim)

    def forward(self, x, t):
        """Halve the length of ``x``, then apply the block conditioned on ``t``."""
        return self.conv(self.down(x), t)

class Up(nn.Module):
    """Upsampling stage: transposed conv, skip concatenation, then a conditioned conv block.

    The transposed conv (kernel 2, stride 2) doubles the length and halves the
    channels to ``in_channels // 2``. The skip tensor is concatenated in front
    of it along channels. It therefore needs ``in_channels - in_channels // 2``
    channels and a length of exactly twice ``x1``'s.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()
        reduced = in_channels // 2
        self.up = nn.ConvTranspose1d(in_channels, reduced, kernel_size=2, stride=2)
        self.conv = ConvBlock_cond(in_channels, out_channels, time_emb_dim)

    def forward(self, x1, x2, t):
        """Upsample ``x1``, prepend skip ``x2`` on dim 1, and apply the block conditioned on ``t``."""
        upsampled = self.up(x1)
        merged = torch.cat((x2, upsampled), dim=1)
        return self.conv(merged, t)

class UNet1D_cond(nn.Module):
    """Time-conditioned 1-D U-Net built from circular-padded convolutions.

    Level ``i`` works at width ``hidden_channels * 2**i``. Each ``Down`` stage
    doubles the width and halves the length. Each ``Up`` stage undoes one
    ``Down`` and merges the matching skip connection. The scalar time input is
    embedded with ``fourier_embedding`` (width ``hidden_channels``) and an MLP
    (width ``4 * hidden_channels``), then fed to every stage.

    NOTE(review): skip concatenation in ``Up`` requires each upsampled length to
    equal the stored skip length. That presumably means the input length must be
    divisible by ``2**depth`` — confirm against callers.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, depth=4):
        super().__init__()
        base = hidden_channels
        self.lift = nn.Conv1d(in_channels, base, kernel_size=3, padding=1, padding_mode='circular')
        self.time_embed_dim = base * 4
        self.hidden_channels = base

        # Channel width at each resolution level: base, 2*base, ..., base * 2**depth.
        widths = [base * 2 ** level for level in range(depth + 1)]
        self.down = nn.ModuleList(
            Down(widths[level], widths[level + 1], self.time_embed_dim) for level in range(depth)
        )
        # Decoder walks the same widths back from deepest to shallowest.
        self.up = nn.ModuleList(
            Up(widths[level + 1], widths[level], self.time_embed_dim) for level in reversed(range(depth))
        )

        self.time_embed = nn.Sequential(
            nn.Linear(base, self.time_embed_dim),
            nn.GELU(),
            nn.Linear(self.time_embed_dim, self.time_embed_dim),
        )
        self.fourier = fourier_embedding

        self.proj = nn.Conv1d(base, out_channels, kernel_size=3, padding=1, padding_mode='circular')

    def forward(self, x, t):
        """Run the U-Net.

        Args:
            x: tensor of shape ``(B, in_channels, L)``.
            t: 1-D tensor with one (possibly fractional) time value per sample.

        Returns:
            Tensor of shape ``(B, out_channels, L)``.
        """
        feats = self.lift(x)
        cond = self.time_embed(self.fourier(t, self.hidden_channels))

        skips = []
        for stage in self.down:
            skips.append(feats)
            feats = stage(feats, cond)

        # Deepest skip pairs with the first Up stage.
        for stage, skip in zip(self.up, reversed(skips)):
            feats = stage(feats, skip, cond)

        return self.proj(feats)
    
    

class PDERefiner(nn.Module):
    """Diffusion PDE solver using a diffusers ``DDPMScheduler`` with v-prediction.

    In training mode ``forward`` returns the MSE between the U-Net output and
    the v-prediction target. In eval mode it runs the reverse diffusion chain,
    conditioned on ``x``, and returns the sampled solution.

    Args:
        in_channels: U-Net input channels. ``x`` and the noised solution are
            concatenated on dim 1, so this presumably equals
            ``channels(x) + out_channels`` — confirm against callers.
        out_channels: channels of the predicted solution.
        hidden_channels: base width of the U-Net.
        depth: number of U-Net resolution levels.
        num_steps: diffusion steps. The scheduler has ``num_steps + 1`` timesteps.
        minstd: base of the beta schedule
            ``beta_i = minstd ** ((num_steps - i) / num_steps)``.
        norm_res: stored but not used in this class.
        residual: stored but not used in this class.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, depth, num_steps, minstd,norm_res,residual):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.depth = depth
        self.num_steps = num_steps
        self.minstd = minstd
        self.unet = UNet1D_cond(in_channels, out_channels, hidden_channels, depth)
        self.norm_res = norm_res
        # Training-time schedule; build_scheduler reads self.minstd, which is set above.
        self.scheduler = self.build_scheduler(num_steps)
        self.residual = residual
        # Maps step indices 0..num_steps onto 0..1000 for the U-Net's time input.
        self.time_multiplier = 1000 / num_steps

    def build_scheduler(self,num_steps):
        """Return a fresh v-prediction ``DDPMScheduler`` with ``num_steps + 1`` timesteps."""
        betas = [self.minstd ** (k / num_steps) for k in reversed(range(num_steps + 1))]
        scheduler = DDPMScheduler(
            num_train_timesteps=num_steps + 1,
            trained_betas=betas,
            prediction_type="v_prediction",
            clip_sample=False,
        )
        return scheduler

    def forward(self, x, y=None, num_steps=None):
        """Compute the training loss (training mode) or sample a solution (eval mode).

        Args:
            x: conditioning input, shape ``(B, C_x, L)``.
            y: target solution, shape ``(B, out_channels, L)``. Required in
                training mode and ignored in eval mode.
            num_steps: eval-mode override for the number of diffusion steps.
                Defaults to ``self.num_steps``.

        Returns:
            Scalar MSE loss in training mode. Otherwise the sampled solution of
            shape ``(B, out_channels, *x.shape[2:])``.

        Raises:
            ValueError: if ``y`` is None in training mode.
        """
        if self.training:
            if y is None:
                raise ValueError("PDERefiner.forward requires the target `y` in training mode")
            k = torch.randint(0, self.scheduler.config.num_train_timesteps, (x.shape[0],), device=x.device)
            # Per-sample alpha_bar_k, reshaped to broadcast over y's non-batch dims.
            bshape = (-1,) + (1,) * (y.ndim - 1)
            alpha_bar = self.scheduler.alphas_cumprod.to(x.device)[k].view(bshape)
            noise = torch.randn_like(y)
            y_noised = self.scheduler.add_noise(y, noise, k)
            x_in = torch.cat([x, y_noised], dim=1)
            pred = self.unet(x_in, k * self.time_multiplier)
            # v-prediction target: sqrt(alpha_bar) * eps - sqrt(1 - alpha_bar) * y.
            target = (alpha_bar**0.5) * noise - ((1 - alpha_bar)**0.5) * y
            return nn.functional.mse_loss(pred, target)

        num_steps = self.num_steps if num_steps is None else num_steps
        # Local scheduler: overwriting self.scheduler here would make later
        # training steps silently use the inference-time schedule.
        scheduler = self.build_scheduler(num_steps)
        # NOTE(review): the time input still uses self.time_multiplier (1000 / self.num_steps).
        # When num_steps != self.num_steps, the U-Net therefore sees time values and
        # noise levels that differ from training — confirm this is intended.
        y_noised = torch.randn(
            size=(x.shape[0], self.out_channels, *x.shape[2:]), dtype=x.dtype, device=x.device
        )
        # Reverse diffusion: scheduler.timesteps runs from num_steps down to 0.
        for k in scheduler.timesteps:
            time = torch.zeros(size=(x.shape[0],), dtype=x.dtype, device=x.device) + k
            x_in = torch.cat([x, y_noised], dim=1)
            pred = self.unet(x_in, time * self.time_multiplier)
            y_noised = scheduler.step(pred, k, y_noised).prev_sample
        return y_noised

class PDERefiner_2(nn.Module):
    """PDE-Refiner variant with an explicit noise schedule (no scheduler stepping).

    Step ``k = 0`` predicts the solution from ``x`` alone; the noised-solution
    channels are all zeros. Steps ``k = 1..num_steps`` perturb the current
    prediction with Gaussian noise of std ``minstd ** (k / num_steps)``. The
    U-Net predicts that noise, and the prediction is then subtracted.

    Args:
        in_channels: U-Net input channels. ``x`` and the noised solution are
            concatenated on dim 1, so this presumably equals
            ``channels(x) + out_channels`` — confirm against callers.
        out_channels: channels of the predicted solution.
        hidden_channels: base width of the U-Net.
        depth: number of U-Net resolution levels.
        num_steps: number of refinement steps used in training.
        minstd: noise std at the final step ``k = num_steps``.
        norm_res: stored but not used in this class.
        residual: stored but not used in this class.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, depth, num_steps, minstd,norm_res,residual):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.depth = depth
        self.num_steps = num_steps
        self.minstd = minstd
        self.unet = UNet1D_cond(in_channels, out_channels, hidden_channels, depth)
        self.norm_res = norm_res
        # NOTE(review): this scheduler is built but never used by forward().
        betas = [minstd ** (k / num_steps) for k in reversed(range(num_steps + 1))]
        self.scheduler = DDPMScheduler(
            num_train_timesteps=num_steps + 1,
            trained_betas=betas,
            prediction_type="v_prediction",
            clip_sample=False,
        )
        self.residual = residual
        # Maps step indices 0..num_steps onto 0..1000 for the U-Net's time input.
        self.time_multiplier = 1000 / num_steps

    def forward(self, x, y=None, num_steps=None):
        """Compute the training loss (training mode) or refine a prediction (eval mode).

        Args:
            x: conditioning input, shape ``(B, C_x, L)``.
            y: target solution, shape ``(B, out_channels, L)``. Required in
                training mode and ignored in eval mode.
            num_steps: eval-mode number of refinement steps after step 0.
                Defaults to ``self.num_steps``.

        Returns:
            Scalar MSE loss in training mode. Otherwise the refined prediction
            of shape ``(B, out_channels, *x.shape[2:])``.

        Raises:
            ValueError: if ``y`` is None in training mode.
        """
        batch = x.shape[0]
        if self.training:
            if y is None:
                raise ValueError("PDERefiner_2.forward requires the target `y` in training mode")
            bshape = (-1,) + (1,) * (y.ndim - 1)
            k = torch.randint(0, self.num_steps + 1, (batch,), device=x.device, dtype=torch.long)
            std = (self.minstd ** (k / self.num_steps)).view(bshape)
            # Noise perturbs the solution y, so it must have y's shape, not x's.
            noise = torch.randn_like(y)
            is_first = (k == 0).view(bshape)
            # k == 0: input zeros and target y. k > 0: input y + std * eps and target eps.
            y_noised = torch.where(is_first, torch.zeros_like(y), y + noise * std)
            pred = self.unet(torch.cat([x, y_noised], dim=1), k * self.time_multiplier)
            target = torch.where(is_first, y, noise)
            return nn.functional.mse_loss(pred, target)

        num_steps = self.num_steps if num_steps is None else num_steps
        bshape = (-1,) + (1,) * (x.ndim - 1)
        k = torch.zeros((batch,), device=x.device, dtype=torch.long)
        # Step 0: the noised-solution channels are zeros shaped like the output (out_channels).
        y_noised = x.new_zeros((batch, self.out_channels, *x.shape[2:]))
        pred = self.unet(torch.cat([x, y_noised], dim=1), k * self.time_multiplier)
        for _ in range(num_steps):
            k = k + 1
            # std follows the training schedule (self.num_steps). A smaller num_steps
            # truncates refinement; a larger one goes below the trained minimum std.
            std = (self.minstd ** (k / self.num_steps)).view(bshape)
            noise = torch.randn(pred.shape, dtype=x.dtype, device=x.device)
            y_noised = pred + noise * std
            pred_noise = self.unet(torch.cat([x, y_noised], dim=1), k * self.time_multiplier)
            pred = y_noised - pred_noise * std
        return pred