import torch
from torch import nn
import torch.nn.functional as F
from .dskd_modules import DiffusionModel, NoiseAdapter, AutoEncoder, DDIMPipeline
from .scheduling_ddim import DDIMScheduler

class DSKD(nn.Module):
    """Feature-distillation head built around a diffusion model on teacher features.

    When a ``dclassifier`` is supplied the module owns:

    * ``trans_1`` - a 1x1 convolution mapping student channels to teacher
      channels,
    * ``model_d`` / ``noise_adapter_d`` - the diffusion model and noise
      adapter, both sized with ``teacher_channels``,
    * ``pipeline_guidance`` - a ``DDIMPipeline`` built from the above plus the
      shared scheduler and ``dclassifier``.

    Without a ``dclassifier`` only the scheduler is created and ``forward``
    passes its inputs through unchanged.

    Args:
        student_channels: Channel count of the student feature map.
        teacher_channels: Channel count of the teacher feature map.
        kernel_size: Kernel size forwarded to ``NoiseAdapter`` and
            ``DiffusionModel``.
        inference_steps: Number of DDIM steps requested from the pipeline in
            ``forward``.
        num_train_timesteps: Number of training timesteps given to the
            scheduler.
        use_ae: Stored on the instance as ``self.use_ae``; not otherwise read
            by this class.
        ae_channels: Accepted for interface compatibility; not used by this
            class.
        dclassifier: Object handed to ``DDIMPipeline`` as ``dclassifier``.
            ``None`` disables the diffusion/guidance branch entirely.
    """

    def __init__(
            self,
            student_channels,
            teacher_channels,
            kernel_size=3,
            inference_steps=5,
            num_train_timesteps=1000,
            use_ae=False,
            ae_channels=None,
            dclassifier=None,
    ):
        super().__init__()
        self.use_ae = use_ae
        self.diffusion_inference_steps = inference_steps
        # ``cond`` records whether the guided-diffusion branch exists.
        self.cond = dclassifier is not None

        # NOTE: the construction order below is kept exactly as before so that
        # sub-module registration order and RNG consumption during parameter
        # initialisation are unchanged.
        if self.cond:
            self.noise_adapter_d = NoiseAdapter(teacher_channels, kernel_size)
            self.model_d = DiffusionModel(channels_in=teacher_channels, kernel_size=kernel_size)
            self.trans_1 = nn.Conv2d(student_channels, teacher_channels, 1)

        self.scheduler = DDIMScheduler(
            num_train_timesteps=num_train_timesteps,
            clip_sample=False,
            beta_schedule="linear",
        )

        if self.cond:
            self.pipeline_guidance = DDIMPipeline(
                self.model_d,
                self.scheduler,
                self.noise_adapter_d,
                dclassifier=dclassifier,
            )

    def forward(self, student_feat, teacher_feat, target_class=None):
        """Project the student feature and, if enabled, run the guided pipeline.

        Args:
            student_feat: Student feature map (fed to the 1x1 conv ``trans_1``
                when the diffusion branch is enabled).
            teacher_feat: Teacher feature map. A detached copy is used for the
                diffusion training loss; the tensor itself is returned as-is.
            target_class: Forwarded to ``pipeline_guidance``. When ``None`` the
                pipeline is skipped and the refined feature is ``None``.

        Returns:
            Without a ``dclassifier`` (``self.cond`` is false) a 3-tuple
            ``(student_feat, None, teacher_feat)``.

            Otherwise a 5-tuple ``(student_feat, student_feat_ori,
            refined_feat_guidance, teacher_feat, ddim_loss_d)`` where
            ``student_feat_ori`` is the projected student feature,
            ``refined_feat_guidance`` is the pipeline output (or ``None``) and
            ``ddim_loss_d`` is the noise-prediction loss of ``model_d`` on the
            detached teacher feature.
        """
        if not self.cond:
            # No diffusion branch: nothing to project, refine or train.
            return student_feat, None, teacher_feat

        # Project the student feature to the teacher's channel dimension.
        student_feat_ori = self.trans_1(student_feat)
        # The diffusion model is trained on teacher features only; detach so
        # this loss never back-propagates into the teacher.
        ddim_loss_d = self.ddim_loss_d(teacher_feat.detach())

        refined_feat_guidance = None
        if target_class is not None:
            # presumably denoises the projected student feature under
            # classifier guidance - confirm against DDIMPipeline.
            refined_feat_guidance = self.pipeline_guidance(
                batch_size=student_feat_ori.shape[0],
                device=student_feat_ori.device,
                dtype=student_feat_ori.dtype,
                shape=student_feat_ori.shape[1:],
                feat=student_feat_ori,
                num_inference_steps=self.diffusion_inference_steps,
                target_class=target_class,
            )

        return student_feat, student_feat_ori, refined_feat_guidance, teacher_feat, ddim_loss_d

    def _noise_prediction_loss(self, model, gt_feat):
        """Return the MSE between sampled noise and ``model``'s noise prediction.

        Args:
            model: Callable invoked as ``model(noisy_feat, timesteps)``.
            gt_feat: Clean feature tensor; its first dimension is the batch.

        Returns:
            Scalar tensor from ``F.mse_loss(noise_pred, noise)``.
        """
        # Sample the noise in the feature's own dtype (not the global default)
        # so the noised sample and the regression target stay consistent under
        # half / double precision features.
        noise = torch.randn(gt_feat.shape, device=gt_feat.device, dtype=gt_feat.dtype)
        batch_size = gt_feat.shape[0]

        # One random training timestep per batch element.
        timesteps = torch.randint(
            0,
            self.scheduler.num_train_timesteps,
            (batch_size,),
            device=gt_feat.device,
            dtype=torch.long,
        )
        # Forward diffusion: add noise with the magnitude set by each timestep.
        noisy_feat = self.scheduler.add_noise(gt_feat, noise, timesteps)
        noise_pred = model(noisy_feat, timesteps)
        return F.mse_loss(noise_pred, noise)

    def ddim_loss_d(self, gt_feat):
        """Noise-prediction training loss of ``self.model_d`` on ``gt_feat``."""
        return self._noise_prediction_loss(self.model_d, gt_feat)

    def ddim_loss(self, gt_feat):
        """Noise-prediction training loss of ``self.model`` on ``gt_feat``.

        NOTE(review): ``__init__`` never creates ``self.model``; unless a
        subclass or caller attaches one, this raises ``AttributeError``.
        """
        return self._noise_prediction_loss(self.model, gt_feat)