from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn

import accelerate

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers import DiffusionPipeline, ImagePipelineOutput
import dnnlib
import math
import torchmetrics

class SequentialSampler(torch.utils.data.Sampler):
    """Yield ``num_samples`` dataset indices, cycling through the dataset.

    Indices are taken from ``range(len(dataset))`` in order, wrapping around
    when ``num_samples`` exceeds ``len(dataset)``. With ``shuffle=True`` a new
    ``torch.randperm`` ordering is drawn at the start of every pass over the
    dataset.

    NOTE(review): ``__iter__`` builds a fresh, unseeded ``torch.Generator()``
    on every call and only passes it through
    ``accelerate.synchronize_rng_states``. If that leaves its state unchanged
    between calls, every call replays the same shuffle sequence. Confirm this
    is intended.
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        num_samples: Optional[int] = None,
        shuffle: bool = False,
    ):
        """
        Args:
            dataset: Sized dataset whose indices are sampled.
            num_samples: Total number of indices yielded per iteration;
                ``None`` means ``len(dataset)`` (exactly one pass).
            shuffle: If true, reshuffle at the start of every pass.
        """
        super().__init__(dataset)
        self.dataset = dataset
        self.num_samples = len(dataset) if num_samples is None else num_samples
        self.shuffle = shuffle

    def __iter__(self):
        """Yield ``self.num_samples`` integer indices.

        Raises:
            ValueError: If the dataset is empty but samples were requested.
        """
        n = len(self.dataset)
        if n == 0 and self.num_samples > 0:
            # Without this guard the modulo below fails with ZeroDivisionError.
            raise ValueError(
                f"cannot draw {self.num_samples} samples from an empty dataset"
            )

        # Synchronize Random Generator across devices
        generator = torch.Generator()
        accelerate.synchronize_rng_states(["generator"], generator=generator)

        order = torch.arange(n)
        idx = 0
        while idx < self.num_samples:
            pos = idx % n
            # Start of a new pass over the dataset: draw a fresh permutation.
            if pos == 0 and self.shuffle:
                order = torch.randperm(n, generator=generator)
            yield order[pos].item()
            idx += 1

    def __len__(self):
        """Number of indices yielded by one iteration."""
        return self.num_samples


class CMPipeline(DiffusionPipeline):
    """Single-step consistency-model sampler.

    Gaussian noise is scaled by ``scheduler.sigma_max`` and mapped to an image
    with a single ``scheduler.step`` call on ``unet``.
    """

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """Generate ``batch_size`` images in one denoising step.

        Args:
            batch_size: Number of images to generate.
            generator: Optional RNG forwarded to ``torch.randn``.
            output_type: ``"tensor"`` returns the clamped image tensor directly
                (``return_dict`` is then ignored). ``"pil"`` converts to PIL
                images. Any other value yields a numpy array with the channel
                axis moved last.
            return_dict: Wrap the result in ``ImagePipelineOutput`` if true,
                otherwise return a 1-tuple.
        """
        shape = (batch_size, ) + self.unet.img_shape
        device = self.device

        sigma = torch.full((batch_size,), self.scheduler.sigma_max, dtype=torch.float, device=device)
        noise = torch.randn(shape, generator=generator, device=device)
        x_init = noise.mul_(sigma.view(-1, 1, 1, 1))

        sample = self.scheduler.step(self.unet, x_init, sigma)
        # Rescale to [0, 1]; assumes the model output lives roughly in [-1, 1].
        sample = (sample * 0.5 + 0.5).clamp(0, 1)
        if output_type == "tensor":
            return sample

        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            sample = self.numpy_to_pil(sample)

        return ImagePipelineOutput(images=sample) if return_dict else (sample, )


class CMPipeline4(DiffusionPipeline):
    """Consistency-model sampler with optional multistep refinement.

    Noise scaled to the last entry of ``scheduler.discretize_timesteps`` is
    denoised with one ``scheduler.step`` call. Each optional refinement step
    then re-noises the current estimate to an earlier schedule entry ``t``,
    using noise standard deviation ``sqrt(t**2 - t_0**2)`` where ``t_0`` is the
    first schedule entry, and denoises it again.

    With the defaults (18 timesteps, 0 refinement steps) this is a single-step
    sampler, identical to the previous hard-coded behaviour.
    """

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)
        # Not read by __call__; kept because external code may set or read it.
        self.number = 0

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        num_timesteps: int = 18,
        num_refinement_steps: int = 0,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """Generate ``batch_size`` images.

        Args:
            batch_size: Number of images to generate.
            generator: Optional RNG forwarded to every ``torch.randn`` call.
            output_type: ``"tensor"`` returns the clamped image tensor directly
                (``return_dict`` is then ignored). ``"pil"`` converts to PIL
                images. Any other value yields a numpy array with the channel
                axis moved last.
            return_dict: Wrap the result in ``ImagePipelineOutput`` if true,
                otherwise return a 1-tuple.
            num_timesteps: Length of the schedule requested from
                ``scheduler.discretize_timesteps``. Sampling starts at its last
                entry (``sigma_max`` for ``CMScheduler``).
            num_refinement_steps: Number of re-noise/denoise rounds after the
                first step, walking backwards through the schedule one entry
                at a time. Must lie in ``[0, num_timesteps - 1]``. The default
                of 0 reproduces the former behaviour, where the refinement
                loop was disabled.

        Raises:
            ValueError: If ``num_timesteps`` or ``num_refinement_steps`` is out
                of range.
        """
        if num_timesteps < 1:
            raise ValueError(f"num_timesteps must be >= 1, got {num_timesteps}")
        if not 0 <= num_refinement_steps <= num_timesteps - 1:
            raise ValueError(
                f"num_refinement_steps must be in [0, {num_timesteps - 1}], "
                f"got {num_refinement_steps}"
            )

        device = self.device
        image_shape = (batch_size, ) + self.unet.img_shape
        timesteps = self.scheduler.discretize_timesteps(num_timesteps, device=device)

        def batch_sigma(value):
            # Broadcast one schedule value to a per-sample float tensor.
            return torch.full((batch_size,), value, dtype=torch.float, device=device)

        # Initial one-step sample from the last schedule entry.
        t_start = batch_sigma(timesteps[num_timesteps - 1])
        noise = torch.randn(image_shape, generator=generator, device=device)
        image = self.scheduler.step(self.unet, noise * t_start.view(-1, 1, 1, 1), t_start)

        # Optional multistep refinement: re-noise to an earlier schedule entry and denoise again.
        t_floor = batch_sigma(timesteps[0])
        for i in range(num_refinement_steps):
            t = batch_sigma(timesteps[num_timesteps - 2 - i])
            noise_std = (t ** 2 - t_floor ** 2).sqrt()
            noise = torch.randn(image_shape, generator=generator, device=device)
            image = self.scheduler.step(self.unet, image + noise * noise_std.view(-1, 1, 1, 1), t)

        image = (image * 0.5 + 0.5).clamp(0, 1)
        if output_type == "tensor":
            return image

        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, )
        return ImagePipelineOutput(images=image)


class CMScheduler(SchedulerMixin, ConfigMixin):
    """Noise-level utilities and network preconditioning for consistency models.

    Stores ``sigma_min``, ``sigma_max`` and ``sigma_data``. They are also
    recorded in the config through ``register_to_config``.

    * ``discretize_timesteps``: a schedule linear in ``sigma ** (1 / rho)``.
    * ``add_noise``: ``x + z * t`` with ``t`` broadcast per batch element.
    * ``step_dm``: EDM-style denoiser ``c_skip * x + c_out * F(c_in * x, c_noise)``.
    * ``step``: the same, but ``c_skip``/``c_out`` use ``t - sigma_min``. This
      gives ``c_skip == 1`` and ``c_out == 0`` when ``t`` equals ``sigma_min``.
    * ``make_latent``: calls ``model.latent`` on the preconditioned input.
    """

    @register_to_config
    def __init__(
        self,
        sigma_min:  float = 0.002,
        sigma_max:  float = 80,
        sigma_data: float = 0.5,
    ):
        self.sigma_min  = sigma_min
        self.sigma_max  = sigma_max
        self.sigma_data = sigma_data

    def discretize_timesteps(
        self,
        num_timesteps: int,
        rho: float = 7.,
        device: Union[str, torch.device] = None
    ) -> torch.FloatTensor:
        """Return ``num_timesteps`` values from ``sigma_min`` to ``sigma_max``.

        Points are spaced uniformly in ``sigma ** (1 / rho)`` and then raised
        back to the power ``rho``.
        """
        inv_rho = 1 / rho
        lo = self.sigma_min ** inv_rho
        hi = self.sigma_max ** inv_rho
        ramp = torch.linspace(lo, hi, steps=num_timesteps, device=device)
        return ramp ** rho

    def add_noise(self, x: torch.FloatTensor, z: torch.FloatTensor, t: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``x + z * t`` with ``t`` reshaped to ``(-1, 1, 1, 1)``."""
        scaled_noise = z * t.view(-1, 1, 1, 1)
        return x + scaled_noise

    def _input_scalings(self, t):
        """Compute the preconditioning terms shared by the denoising methods.

        Returns ``(sigma, norm, c_in, c_noise)``:

        * ``sigma`` is ``t`` viewed as ``(-1, 1, 1, 1)``.
        * ``norm`` is ``sqrt(sigma**2 + sigma_data**2)``.
        * ``c_in`` is ``1 / norm``.
        * ``c_noise`` is ``log(sigma) / 4``, flattened to 1-D.
        """
        sigma = t.view(-1, 1, 1, 1)
        norm = (sigma ** 2 + self.sigma_data ** 2).sqrt()
        c_in = 1. / norm
        c_noise = (sigma.log() / 4.).flatten()
        return sigma, norm, c_in, c_noise

    def step_dm(self, model, x, t):
        """Apply the EDM-preconditioned denoiser ``model`` to ``x`` at noise level ``t``."""
        sigma, norm, c_in, c_noise = self._input_scalings(t)
        sd2 = self.sigma_data ** 2
        c_skip = sd2 / (sigma ** 2 + sd2)
        c_out = self.sigma_data * sigma / norm
        return c_skip * x + c_out * model(c_in * x, c_noise, None)

    def step(self, model: nn.Module, x: torch.FloatTensor, t: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the consistency-model parameterisation of ``model`` to ``x`` at ``t``.

        The skip/output weights use ``t - sigma_min``, so at ``t == sigma_min``
        the result reduces to ``x``. The model is still evaluated.
        """
        sigma, norm, c_in, c_noise = self._input_scalings(t)
        sd2 = self.sigma_data ** 2
        shifted = sigma - self.sigma_min
        c_skip = sd2 / (shifted ** 2 + sd2)
        c_out = self.sigma_data * shifted / norm
        return c_skip * x + c_out * model(c_in * x, c_noise, None)

    def make_latent(self, model: nn.Module, x: torch.FloatTensor, t: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``model.latent`` evaluated on the ``c_in``-scaled input and ``c_noise``."""
        _, _, c_in, c_noise = self._input_scalings(t)
        return model.latent(c_in * x, c_noise, None)
    

def ema_init(online):
    """Build a frozen copy of ``online``, e.g. to serve as an EMA target.

    The copy is constructed with ``online.__class__.from_config(online.config)``.
    Gradients are disabled on all of its parameters, and every parameter and
    buffer value is copied from ``online`` by name.

    Args:
        online: Module exposing ``config`` and a ``from_config`` classmethod
            (diffusers-style).

    Returns:
        The newly constructed target module.

    Raises:
        KeyError: If the target has a parameter or buffer name that ``online``
            lacks.
    """
    target = online.__class__.from_config(online.config)
    target.requires_grad_(False)

    source_params = dict(online.named_parameters())
    source_buffers = dict(online.named_buffers())

    # Write through .data so the copies are not tracked by autograd.
    for name, param in target.named_parameters():
        param.data.copy_(source_params[name].data)
    for name, buf in target.named_buffers():
        buf.data.copy_(source_buffers[name].data)
    return target


def ema_update(online, target, decay_rate):
    """Update ``target`` in place as an exponential moving average of ``online``.

    Each target parameter becomes
    ``decay_rate * target + (1 - decay_rate) * online``, matched by name.
    Each target buffer is overwritten with the matching online buffer.
    All writes go through ``.data``, so autograd does not record them.

    Raises:
        KeyError: If ``target`` has a parameter or buffer name that ``online``
            lacks.
    """
    source_params = dict(online.named_parameters())
    source_buffers = dict(online.named_buffers())
    online_weight = 1 - decay_rate

    for name, param in target.named_parameters():
        param.data.mul_(decay_rate).add_(source_params[name].data, alpha=online_weight)
    for name, buf in target.named_buffers():
        buf.data.copy_(source_buffers[name].data)



@torch.no_grad()
def copy_params_and_buffers(src_module, dst_module, require_all=False):
    """Copy parameter and buffer values from ``src_module`` into ``dst_module`` by name.

    Names present in ``dst_module`` but missing from ``src_module`` are skipped.
    When ``require_all`` is true they instead trigger an ``AssertionError``.
    That check uses ``assert``, so it is skipped under ``python -O``.
    Runs under ``torch.no_grad()``, so in-place copies into leaf parameters
    are permitted.
    """
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    lookup = dict(named_params_and_buffers(src_module))
    for key, dst in named_params_and_buffers(dst_module):
        present = key in lookup
        assert present or not require_all
        if not present:
            continue
        dst.copy_(lookup[key])

def params_and_buffers(module):
    """Return all parameters of ``module``, followed by all of its buffers, as a list."""
    assert isinstance(module, torch.nn.Module)
    tensors = list(module.parameters())
    tensors.extend(module.buffers())
    return tensors

def named_params_and_buffers(module):
    """Return ``(name, tensor)`` pairs for all parameters, then all buffers, of ``module``."""
    assert isinstance(module, torch.nn.Module)
    named = list(module.named_parameters())
    named.extend(module.named_buffers())
    return named



