from torch import nn
import torch
import math
from diffusers.schedulers import DDPMScheduler
import torch.nn.functional as F

def fourier_embedding(timesteps: torch.Tensor, dim, max_period=10000):
    r"""Build sinusoidal (cosine, then sine) embeddings for a batch of timesteps.

    Args:
        timesteps: a 1-D Tensor holding one timestep per batch element.
                      Fractional values are allowed; the input is cast to float.
        dim (int): number of features in each returned embedding.
        max_period (int): sets the lowest frequency; frequencies decay
                      geometrically from 1 towards 1 / max_period.
    Returns:
        embedding (torch.Tensor): [N $\times$ dim] Tensor. The first ``dim // 2``
        columns are cosines and the next ``dim // 2`` are sines. For an odd
        ``dim`` of at least 3, one trailing zero column is appended.
    """
    n_freqs = dim // 2
    # The frequency table is computed in float32 on the CPU, then moved to the
    # device of the incoming timesteps.
    exponents = -math.log(max_period) * torch.arange(start=0, end=n_freqs, dtype=torch.float32) / n_freqs
    freqs = torch.exp(exponents).to(device=timesteps.device)

    # Outer product: one row per timestep, one column per frequency.
    phases = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)

    parts = [torch.cos(phases), torch.sin(phases)]
    if dim % 2:
        # Pad so the feature count matches an odd ``dim``.
        parts.append(torch.zeros_like(phases[:, :1]))
    return torch.cat(parts, dim=-1)

class ConvBlock_cond(nn.Module):
    """Residual 1-D convolution block modulated by a conditioning embedding.

    Pipeline: GroupNorm -> GELU -> conv -> GroupNorm -> scale/shift modulation
    -> GELU -> dropout -> conv, added to a shortcut of the input. Both 3-wide
    convolutions use circular padding, so the sequence length is preserved.

    Args:
        in_channels (int): channels of the input feature map.
        out_channels (int): channels of the output feature map.
        time_emb_dim (int): size of the conditioning embedding ``t``.
        dropout (bool | float): ``False`` (default) disables dropout and ``True``
            enables it with the historical rate of 0.1. A ``float`` is used
            directly as the dropout probability (``0.0`` disables it). Any
            other value is interpreted by truthiness, exactly like a bool.
    """

    # Rate applied when dropout is requested with a plain truthy flag.
    DEFAULT_DROPOUT_RATE = 0.1

    def __init__(self, in_channels, out_channels, time_emb_dim, dropout=False):
        super().__init__()
        # A single group means normalisation over all channels and positions.
        self.norm1 = nn.GroupNorm(1, in_channels)
        self.norm2 = nn.GroupNorm(1, out_channels)

        self.activation = nn.GELU()

        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1, padding_mode='circular')
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1, padding_mode='circular')

        # Only genuine floats are treated as a rate, so existing callers that
        # pass True/False (or another truthy flag) keep the original behaviour.
        if isinstance(dropout, float):
            rate = dropout
        else:
            rate = self.DEFAULT_DROPOUT_RATE if dropout else 0.0
        self.dropout = nn.Dropout(rate) if rate else nn.Identity()

        # A 1x1 convolution matches channel counts on the residual path when needed.
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)
        else:
            self.shortcut = nn.Identity()

        # Produces per-channel (scale, shift) pairs from the conditioning embedding.
        self.cond_emb = nn.Linear(time_emb_dim, out_channels * 2)

    def forward(self, x, t):
        """Apply the block.

        Args:
            x: feature map laid out as (batch, in_channels, length).
            t: conditioning embedding laid out as (batch, time_emb_dim).

        Returns:
            Tensor laid out as (batch, out_channels, length).
        """
        h = self.norm1(x)
        h = self.activation(h)
        h = self.conv1(h)
        h = self.norm2(h)
        # Split the projected embedding into a scale half and a shift half, then
        # broadcast both along the length axis. The "+ 1" makes a zero scale the identity.
        scale, shift = torch.chunk(self.cond_emb(t), 2, dim=1)
        h = h * (scale.unsqueeze(-1) + 1) + shift.unsqueeze(-1)
        h = self.activation(h)
        h = self.dropout(h)
        h = self.conv2(h)
        return h + self.shortcut(x)
    
class Down(nn.Module):
    """Encoder stage: strided downsampling followed by a conditioned conv block.

    The downsampling layer is a stride-2, 3-wide circular convolution that keeps
    the channel count; the conditioned block then maps ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()
        self.down = nn.Conv1d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            padding_mode='circular',
        )
        self.conv = ConvBlock_cond(in_channels, out_channels, time_emb_dim)

    def forward(self, x, t):
        """Downsample ``x`` along its last axis, then apply the block conditioned on ``t``."""
        downsampled = self.down(x)
        return self.conv(downsampled, t)

class Up(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, conditioned conv block.

    The transposed convolution (kernel 2, stride 2) maps ``in_channels`` to
    ``in_channels // 2`` channels. Its output is concatenated with the skip
    tensor along the channel axis and fed to a block expecting ``in_channels``
    inputs, so the skip tensor is expected to supply the remaining channels.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()
        self.up = nn.ConvTranspose1d(
            in_channels,
            in_channels // 2,
            kernel_size=2,
            stride=2,
        )
        self.conv = ConvBlock_cond(in_channels, out_channels, time_emb_dim)

    def forward(self, x1, x2, t):
        """Upsample ``x1``, join it with the skip tensor ``x2``, and apply the block.

        Args:
            x1: coarse feature map to upsample.
            x2: skip feature map; placed first in the channel concatenation.
                Its length must equal the upsampled length of ``x1``.
            t: conditioning embedding forwarded to the conv block.
        """
        upsampled = self.up(x1)
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged, t)

class UNet1D_cond(nn.Module):
    """1-D U-Net whose stages are all conditioned on a timestep embedding.

    The channel width doubles at each of the ``depth`` encoder stages and halves
    again at each decoder stage. The timestep is turned into a sinusoidal
    embedding of size ``hidden_channels`` and passed through a two-layer MLP of
    width ``4 * hidden_channels``.

    Args:
        in_channels (int): channels of the input signal.
        out_channels (int): channels of the prediction.
        hidden_channels (int): width after the lifting convolution.
        depth (int): number of down stages (and of up stages).
    """

    def __init__(self, in_channels, out_channels, hidden_channels, depth=4):
        super().__init__()
        # Submodules are created in a fixed order (lift, down, up, time_embed, proj).
        self.lift = nn.Conv1d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode='circular')
        self.time_embed_dim = hidden_channels*4
        self.hidden_channels = hidden_channels

        # Input width of each encoder stage: hidden, 2*hidden, 4*hidden, ...
        widths = [hidden_channels * 2 ** level for level in range(depth)]

        self.down = nn.ModuleList(
            [Down(width, width * 2, self.time_embed_dim) for width in widths]
        )

        # The decoder mirrors the encoder, going from the widest stage back to ``hidden_channels``.
        self.up = nn.ModuleList(
            [Up(width * 2, width, self.time_embed_dim) for width in reversed(widths)]
        )

        self.time_embed = nn.Sequential(
            nn.Linear(hidden_channels, self.time_embed_dim),
            nn.GELU(),
            nn.Linear(self.time_embed_dim, self.time_embed_dim),
        )
        self.fourier = fourier_embedding

        self.proj = nn.Conv1d(hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode='circular')

    def forward(self, x, t):
        """Run the network.

        Args:
            x: input signal; lifted to ``hidden_channels`` by a 1-D convolution.
            t: 1-D tensor of timesteps, one per batch element.

        Returns:
            The projection of the final decoder features to ``out_channels``.
        """
        features = self.lift(x)
        cond = self.time_embed(self.fourier(t, self.hidden_channels))

        # Each encoder stage stores its *input* as the skip for the mirrored decoder stage.
        skips = []
        for stage in self.down:
            skips.append(features)
            features = stage(features, cond)

        # Skips are consumed last-in, first-out so the resolutions line up.
        for stage in self.up:
            features = stage(features, skips.pop(), cond)

        return self.proj(features)



class DiffusionModel(nn.Module):
    """Conditional DDPM: a ``UNet1D_cond`` denoiser driven by a ``DDPMScheduler``.

    The denoiser receives the conditioning ``x`` concatenated with the noisy
    target along the channel axis and predicts the added noise (epsilon
    prediction). ``in_channels`` must therefore equal the channels of ``x`` plus
    the channels of the target, and ``out_channels`` must equal the channels of
    the target.

    Args:
        diffusionSteps (int): number of training timesteps of the noise schedule.
        in_channels (int): input channels of the U-Net.
        out_channels (int): output channels of the U-Net.
        hidden_channels (int): base width of the U-Net.
        depth (int): number of down/up stages of the U-Net.
    """

    # Number of sampling steps used when ``num_steps`` is not given.
    DEFAULT_SAMPLING_STEPS = 50

    def __init__(self, diffusionSteps:int, in_channels:int, out_channels:int, hidden_channels:int, depth:int):
        super().__init__()

        self.timesteps = diffusionSteps
        # Kept so that sampling can draw noise with the shape of the target.
        self.out_channels = out_channels
        self.scheduler = DDPMScheduler(
            beta_schedule="linear",
            num_train_timesteps=self.timesteps,
            prediction_type="epsilon",
            clip_sample=False,
        )
        # Timestep indices are rescaled to a 0..1000 range before being embedded,
        # whatever the number of diffusion steps.
        self.time_multiplier = 1000 / diffusionSteps
        # backbone model
        self.unet = UNet1D_cond(
            in_channels=in_channels,
            out_channels=out_channels,
            hidden_channels=hidden_channels,
            depth=depth,
        )

    def forward(self, x:torch.Tensor, y:torch.Tensor=None, num_steps:int=None) -> torch.Tensor:
        """Compute the training loss (train mode) or draw a sample (eval mode).

        Args:
            x: conditioning tensor, concatenated with the (noisy) target on dim 1.
            y: clean target; required in training mode, unused in eval mode.
            num_steps: number of sampling steps in eval mode. Defaults to 50,
                capped at the number of training timesteps.

        Returns:
            In training mode, the scalar MSE between predicted and true noise.
            In eval mode, the generated sample with ``out_channels`` channels.

        Raises:
            ValueError: if the module is in training mode and ``y`` is ``None``.

        Note:
            Gradient tracking is not disabled during sampling; wrap the call in
            ``torch.no_grad()`` when gradients are not needed.
        """
        if self.training:
            if y is None:
                raise ValueError("DiffusionModel.forward requires the target `y` in training mode")
            noise = torch.randn_like(y)
            # One random timestep per batch element.
            t = torch.randint(0, self.timesteps, (x.shape[0],), device=x.device)
            noisy_target = self.scheduler.add_noise(y, noise, t)
            predicted_noise = self.unet(torch.cat([x, noisy_target], dim=1), t * self.time_multiplier)
            return F.mse_loss(predicted_noise, noise)

        if num_steps is None:
            # The scheduler rejects more inference steps than training timesteps.
            num_steps = min(self.DEFAULT_SAMPLING_STEPS, self.timesteps)
        self.scheduler.set_timesteps(num_steps)

        # Start from pure noise shaped like the target: same batch and trailing
        # dimensions as ``x``, but ``out_channels`` channels.
        sample = torch.randn(
            (x.shape[0], self.out_channels, *x.shape[2:]),
            device=x.device,
            dtype=x.dtype,
        )
        for step_t in self.scheduler.timesteps:
            time_tensor = torch.full((x.shape[0],), step_t, device=x.device, dtype=torch.long)
            predicted_noise = self.unet(torch.cat([x, sample], dim=1), time_tensor * self.time_multiplier)
            sample = self.scheduler.step(predicted_noise, step_t, sample).prev_sample

        return sample


