from PIL import Image
import argparse
import numpy as np
import torch
import random
from math import log10, sqrt


def ensure_reproducibility(seed):
    """Seed every RNG this module relies on and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's global RNG and torch's RNG with
    the same ``seed``, then forces cuDNN into deterministic mode with its
    autotuner (benchmark mode) switched off.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    # Benchmark mode may pick different kernels from run to run.
    cudnn.benchmark = False



def dict2namespace(config):
    """Recursively convert a mapping into an ``argparse.Namespace``.

    Nested ``dict`` values become nested namespaces. Every other value is
    attached unchanged. Keys become attribute names via ``setattr``.
    """
    ns = argparse.Namespace()
    for name, item in config.items():
        converted = dict2namespace(item) if isinstance(item, dict) else item
        setattr(ns, name, converted)
    return ns

def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """Return a float64 array of ``num_diffusion_timesteps`` betas.

    Supported ``beta_schedule`` values:
      * ``"quad"``    - linearly spaced square roots between sqrt(beta_start)
                        and sqrt(beta_end), then squared.
      * ``"linear"``  - linearly spaced from beta_start to beta_end.
      * ``"const"``   - every entry equals beta_end.
      * ``"jsd"``     - 1/T, 1/(T-1), ..., 1 (beta_start/beta_end ignored).
      * ``"sigmoid"`` - sigmoid of linspace(-6, 6) rescaled into
                        [beta_start, beta_end].

    Raises:
        NotImplementedError: for any other schedule name.
    """
    n_steps = num_diffusion_timesteps

    if beta_schedule == "quad":
        sqrt_betas = np.linspace(
            beta_start ** 0.5, beta_end ** 0.5, n_steps, dtype=np.float64
        )
        betas = sqrt_betas ** 2
    elif beta_schedule == "linear":
        betas = np.linspace(beta_start, beta_end, n_steps, dtype=np.float64)
    elif beta_schedule == "const":
        betas = beta_end * np.ones(n_steps, dtype=np.float64)
    elif beta_schedule == "jsd":
        denominators = np.linspace(n_steps, 1, n_steps, dtype=np.float64)
        betas = 1.0 / denominators
    elif beta_schedule == "sigmoid":
        grid = np.linspace(-6, 6, n_steps)
        squashed = 1 / (np.exp(-grid) + 1)
        betas = squashed * (beta_end - beta_start) + beta_start
    else:
        raise NotImplementedError(beta_schedule)

    assert betas.shape == (num_diffusion_timesteps,)
    return betas


# Code from RePaint
def get_schedule_jump(T_sampling, travel_length, travel_repeat):
    """Build a RePaint-style "jump" time schedule.

    The walk moves t downward one step at a time, starting at
    ``T_sampling - 1``. Each t that is a key of ``jumps`` (the multiples of
    ``travel_length`` in ``range(0, T_sampling - travel_length)``) is a jump
    point. At a jump point the walk climbs back up by ``travel_length`` steps
    and then descends again. This happens ``travel_repeat - 1`` times per jump
    point, after which the walk continues downward. The list always ends with
    -1 and is validated by ``_check_times``.

    Args:
        T_sampling: number of sampling steps. The walk starts at
            ``T_sampling - 1``.
        travel_length: number of steps each jump climbs back up.
        travel_repeat: each jump point triggers ``travel_repeat - 1`` jumps.
            A value <= 1 disables jumping entirely.

    Returns:
        list of ints in which consecutive entries differ by exactly 1 and
        whose last entry is -1.

    Raises:
        AssertionError: if ``_check_times`` rejects the schedule (these
            checks are skipped under ``python -O``).
        IndexError: when ``T_sampling < 1``. The schedule is then just
            ``[-1]``, and ``_check_times`` indexes ``times[1]``.
    """
    # Map each jump point to the number of climbs it still has to trigger.
    jumps = {}
    for j in range(0, T_sampling - travel_length, travel_length):
        jumps[j] = travel_repeat - 1

    t = T_sampling
    ts = []

    while t >= 1:
        # Step down one time index and record it.
        t = t-1
        ts.append(t)

        if jumps.get(t, 0) > 0:
            # Consume one jump at this t and climb back up travel_length
            # steps, recording every intermediate index on the way.
            jumps[t] = jumps[t] - 1
            for _ in range(travel_length):
                t = t + 1
                ts.append(t)

    # Terminal sentinel; _check_times asserts the schedule ends with -1.
    ts.append(-1)

    _check_times(ts, -1, T_sampling)
    return ts

def _check_times(times, t_0, T_sampling):
    # Check end
    assert times[0] > times[1], (times[0], times[1])

    # Check beginning
    assert times[-1] == -1, times[-1]

    # Steplength = 1
    for t_last, t_cur in zip(times[:-1], times[1:]):
        assert abs(t_last - t_cur) == 1, (t_last, t_cur)

    # Value range
    for t in times:
        assert t >= t_0, (t, t_0)
        assert t <= T_sampling, (t, T_sampling)
        
def compute_alpha(beta, t):
    """Gather cumulative products of ``(1 - beta)`` at time indices ``t``.

    A zero is prepended to ``beta`` before the cumulative product, so index
    ``t = -1`` selects 1.0 (the empty product) and ``t = k`` selects
    ``prod(1 - beta[:k + 1])``.

    Args:
        beta: 1-D tensor of betas (concatenated along dim 0).
        t: integer index tensor accepted by ``index_select``. Its values are
            expected in ``[-1, len(beta) - 1]``.

    Returns:
        Tensor of shape ``(len(t), 1, 1, 1)``.
    """
    padded = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
    cumulative = torch.cumprod(1 - padded, dim=0)
    # Shift by one so that t = -1 lands on the prepended entry.
    picked = cumulative.index_select(0, t + 1)
    return picked.view(-1, 1, 1, 1)



def process(sample, i):
    """Convert an NCHW tensor with values in [-1, 1] into a uint8 HWC array.

    The tensor is detached, moved to the CPU, and permuted to NHWC. The batch
    dimension is squeezed out when it is 1. Values are mapped to [0, 255] via
    ``(x + 1) * 127.5``, clipped, and cast to ``uint8``; the cast truncates
    rather than rounds.

    ``i`` is unused. It is kept for call-site compatibility.
    """
    nhwc = sample.detach().cpu().permute(0, 2, 3, 1)
    hwc = nhwc.squeeze(0)
    pixels = ((hwc + 1.0) * 127.5).clamp(0., 255.)
    return pixels.numpy().astype(np.uint8)

def process_gray(sample, i):
    """Render ``sample`` as a 3-channel grayscale uint8 image.

    The sample is first converted exactly as ``process`` does (HWC uint8 in
    [0, 255]). It is then converted to luminance with PIL's ``'L'`` mode, and
    that single channel is replicated so the result has shape ``(H, W, 3)``.

    ``i`` is unused. It is kept for call-site compatibility.

    NOTE(review): PIL's ``'L'`` conversion expects a 3-channel HWC array here;
    confirm that callers never pass single-channel or batched (N > 1) samples.
    """
    # Reuse the shared tensor -> uint8 conversion instead of duplicating it.
    rgb = process(sample, i)
    gray = np.array(Image.fromarray(rgb).convert('L'))
    # Replicate luminance into 3 identical channels.
    return np.repeat(gray[:, :, np.newaxis], 3, axis=2)

def process_gray_thresh(sample, i, thresh=170):
    """Binarize the luminance of ``sample`` into a 3-channel uint8 image.

    The sample is converted as ``process`` does and then to PIL ``'L'``
    luminance. Pixels with luminance strictly greater than ``thresh`` become
    255 and all others become 0. The result is replicated to shape
    ``(H, W, 3)``.

    ``i`` is unused. It is kept for call-site compatibility.
    """
    # Reuse the shared tensor -> uint8 conversion instead of duplicating it.
    rgb = process(sample, i)
    gray = np.array(Image.fromarray(rgb).convert('L'))
    binary = np.where(gray > thresh, 255, 0).astype(np.uint8)
    return np.repeat(binary[:, :, np.newaxis], 3, axis=2)


def get_mask(sample, i, thresh=170):
    """Return a float32 ``(H, W)`` mask derived from the luminance of ``sample``.

    The sample is converted as ``process`` does and then to PIL ``'L'``
    luminance. Pixels with luminance strictly greater than ``thresh`` map to
    0.0 and all others to 1.0; bright regions are therefore masked out.

    ``i`` is unused. It is kept for call-site compatibility.
    """
    # Reuse the shared tensor -> uint8 conversion instead of duplicating it.
    rgb = process(sample, i)
    gray = np.array(Image.fromarray(rgb).convert('L'))
    return np.where(gray > thresh, 0., 1.).astype(np.float32)



def psnr_orig(original, compressed):
    """Compute the PSNR (in dB) between two images on a 0-255 scale.

    Both inputs are cast to float64 before differencing. Images produced by
    ``process`` are uint8, and uint8 subtraction and squaring wrap modulo 256,
    which silently corrupted the MSE.

    Returns:
        ``20 * log10(255 / sqrt(MSE))``, or the sentinel ``100`` when the
        images are identical (MSE == 0, where PSNR is unbounded).
    """
    diff = np.asarray(original, dtype=np.float64) - np.asarray(compressed, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        # No difference at all: PSNR is infinite, so report a fixed cap.
        return 100
    max_pixel = 255.0
    return 20 * log10(max_pixel / sqrt(mse))

def psnr_mask(original, compressed, mask):
    """Compute the PSNR (in dB) restricted to the region selected by ``mask``.

    The MSE is ``sum(((original - compressed) * mask) ** 2) / mask.sum()``.
    The images are cast to float64 first: with uint8 images and an integer or
    bool mask, the products and differences would otherwise stay uint8 and
    wrap modulo 256.

    NOTE(review): the numerator sums over every broadcast element, while the
    denominator is ``mask.sum()`` over the mask's own elements. A single-channel
    mask applied to a 3-channel image therefore counts pixels, not channel
    values. Confirm that this is the intended normalization.

    Returns:
        ``20 * log10(255 / sqrt(MSE))``, or the sentinel ``100`` when the
        masked region is identical (MSE == 0).
    """
    original_f = np.asarray(original, dtype=np.float64)
    compressed_f = np.asarray(compressed, dtype=np.float64)
    mse = ((original_f * mask - compressed_f * mask) ** 2).sum() / mask.sum()
    if mse == 0:
        # No difference inside the mask: PSNR is infinite, so report a cap.
        return 100
    max_pixel = 255.0
    return 20 * log10(max_pixel / sqrt(mse))

