import torch
import math
import os
import torch.nn as nn
import torch.nn.functional as F
from algorithms.SD import SD
import torchvision


class Sinkhorn(object):
    """Run an ``SD`` particle update from random particles toward a target support.

    ``forward`` builds an ``SD`` instance seeded with ``init_support`` /
    ``init_mass``, calls ``one_step_update`` ``opts.T`` times against
    ``tgt_support`` / ``tgt_mass``, and records two of the tensors returned by
    ``SD.get_state()`` after every step. ``get_state`` returns the recordings
    concatenated along dim 0; ``sinkhorn_clear_all`` resets them. Recordings
    accumulate across repeated ``forward`` calls until cleared.
    """

    def __init__(self, opts, tgt_support, **kw) -> None:
        """
        Args:
            opts: config object; this class reads ``device``, ``batch_size``,
                ``T`` and ``SD_lr`` from it and passes it unchanged to ``SD``.
            tgt_support: target points; must be at least 2-D. ``shape[0]`` is
                treated as the number of target points and ``shape[1]`` as the
                per-point feature dimension.
            **kw: accepted and ignored (kept for call-site compatibility).
        """
        super().__init__()
        self.opts = opts
        self.device = opts.device
        self.tgt_support = tgt_support
        # Uniform weights over the target points actually supplied. Sizing by
        # tgt_support.shape[0] rather than opts.batch_size keeps the mass vector
        # aligned with the support when a batch is short (e.g. the last batch).
        num_tgt = tgt_support.shape[0]
        self.tgt_mass = torch.ones(num_tgt, device = self.device) / num_tgt
        # Source particles: batch_size points drawn uniformly in [0, 1) with the
        # same feature dimension as the target, each carrying equal mass.
        self.init_support = torch.rand(self.opts.batch_size, self.tgt_support.shape[1], device = self.device)
        self.init_mass = torch.ones(opts.batch_size, device = self.device) / opts.batch_size
        # One entry appended per update step by forward().
        self.record_sinkdiv = []
        self.record_support = []

        # check diffusion
        # self.betas = torch.linspace(1e-4, 0.02, self.opts.T).double().to(self.device)
        # alphas = 1. - self.betas
        # alphas_bar = torch.cumprod(alphas, dim=0)
        # self.sqrt_alphas_bar = torch.sqrt(alphas_bar)
        # self.sqrt_one_minus_alphas_bar = torch.sqrt(1. - alphas_bar)

    def forward(self):
        """Run ``opts.T`` SD steps and append each step's state to the records.

        The first and third values of ``SD.get_state()`` are appended to
        ``record_support`` and ``record_sinkdiv`` respectively.
        ``SD_clear_all`` is invoked on the SD instance afterwards, even if a
        step raises.
        """
        algorithm = SD(opts = self.opts, init_particles = self.init_support, init_mass = self.init_mass)

        # Interpolation check
        # for step in range(self.opts.T):
        #     if step % 10 == 0 or (step + 1) == self.opts.T:
        #         t = step / self.opts.T
        #         image = t * self.tgt_support + (1 - t) * self.init_support
        #         image = convert2image(image)
        #         torchvision.utils.save_image(image, os.path.join('./check_interpolation', '%s_%d.png'%('ck_itp', step)))

        # Diffusion model check beta_1 = 1e-4 beta_T = 0.02
        # for step in range(self.opts.T):
        #     if step % 10 == 0 or (step + 1) == self.opts.T:
        #         noise = torch.randn_like(self.tgt_support).to(self.device)
        #         t = torch.ones(self.tgt_support.shape[0]) * step
        #         t = t.type(torch.int64).to(self.device)
        #         image = (
        #         extract(self.sqrt_alphas_bar, t, self.tgt_support.shape) * self.tgt_support +
        #         extract(self.sqrt_one_minus_alphas_bar, t, self.tgt_support.shape) * noise)
        #         image = convert2image(image)
        #         torchvision.utils.save_image(image, os.path.join('./check_diffusion', '%s_%d.png'%('ck_diff', step)))

        try:
            for step in range(self.opts.T):
                # Constant step size. Alternative decaying schedule:
                # lr = self.opts.SD_lr * math.exp((step - self.opts.T) / (self.opts.T / 4))
                lr = self.opts.SD_lr
                algorithm.one_step_update(
                    step_size = lr,
                    tgt_support = self.tgt_support,
                    tgt_mass = self.tgt_mass
                )
                support, _, vector = algorithm.get_state()
                # check image
                # if step % 10 == 0 or (step + 1) == self.opts.T:
                #     image = convert2image(support)
                #     torchvision.utils.save_image(image, os.path.join('./check_Sinkhorn', '%s_%d.png'%('ck_Sink', step)))
                # One tensor per step; get_state() concatenates them along dim 0.
                # NOTE(review): if SD returns views of buffers it mutates in place,
                # these entries would alias each other — confirm SD copies.
                self.record_sinkdiv.append(vector)
                self.record_support.append(support)
        finally:
            # Ensure SD's cleanup hook runs even if an update step raises.
            algorithm.SD_clear_all()

    @torch.no_grad()
    def get_state(self):
        """Return ``(sinkdiv, support)``: the recorded per-step tensors, each
        concatenated along dim 0 and detached.

        ``torch.cat`` raises on an empty list, so ``forward`` must have recorded
        at least one step since the last ``sinkhorn_clear_all``.
        """
        return torch.cat(self.record_sinkdiv).detach(), torch.cat(self.record_support).detach()

    @torch.no_grad()
    def sinkhorn_clear_all(self):
        """Discard all recorded per-step tensors."""
        self.record_sinkdiv = []
        self.record_support = []
    
@torch.no_grad()
def convert2image(support, num = 8, image_size = (32, 32), channels = 3):
    """Tile up to ``num ** 2`` flattened samples from ``support`` into one image grid.

    Args:
        support: tensor whose leading dim indexes samples; each sample must hold
            ``channels * image_size[0] * image_size[1]`` elements.
        num: grid width in images; at most ``num ** 2`` leading samples are used.
            If ``support`` has fewer samples, all of them are shown.
        image_size: ``(height, width)`` of each sample. Defaults to ``(32, 32)``.
        channels: channels per sample. Defaults to 3.

    Returns:
        The grid tensor produced by ``torchvision.utils.make_grid`` with
        ``normalize=True`` and ``value_range=(0, 1)``.
    """
    # Take only as many samples as exist; the original fixed view(num ** 2, ...)
    # failed whenever fewer than num ** 2 samples were supplied.
    count = min(num ** 2, support.shape[0])
    # reshape (unlike view) also accepts non-contiguous inputs.
    samples = support[:count].reshape(count, channels, image_size[0], image_size[1])
    return torchvision.utils.make_grid(samples, nrow = num, value_range = (0, 1), normalize = True)
        
