import importlib
import os
import random

import numpy as np
import torch
from PIL import Image
    
def set_random_seed(seed):
    """Seed every RNG this project relies on with the same ``seed``.

    Covers Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG and the
    RNGs of all CUDA devices (in that order).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    
    
def mkdir_ifnotexists(directory):
    """Create ``directory`` (a single level) unless something already exists there.

    Uses EAFP instead of an ``os.path.exists`` pre-check so that concurrent
    callers (e.g. several training processes creating the same output dir)
    cannot race between the check and the ``mkdir``.

    Raises:
        FileNotFoundError: if the parent directory does not exist.
    """
    try:
        os.mkdir(directory)
    except FileExistsError:
        # Already present (possibly created by another process) -- nothing to do.
        pass



def concat_home_dir(path):
    """Return ``<home>/data/<path>``.

    ``<home>`` is ``$HOME`` when that variable is set (the original behavior);
    otherwise it falls back to ``os.path.expanduser('~')`` instead of raising
    ``KeyError`` (e.g. on Windows or in environments without HOME).
    """
    home = os.environ.get('HOME')
    if home is None:
        home = os.path.expanduser('~')
    return os.path.join(home, 'data', path)


def get_class(kls):
    """Resolve a dotted path such as ``"package.module.ClassName"`` to the object.

    The part before the last dot is imported as a module and the last part is
    looked up as an attribute of that module.

    Raises:
        ValueError: if ``kls`` has no module part (no dot, or a leading dot).
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no attribute with that name.
    """
    module_path, _, attr_name = kls.rpartition('.')
    if not module_path:
        raise ValueError(f"Expected a dotted path 'module.Name', got {kls!r}")
    # import_module returns the leaf module directly, unlike __import__, which
    # returns the top-level package and forces a manual getattr walk.
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)


def to_cuda(torch_obj):
    """Move ``torch_obj`` to the default CUDA device when one is available.

    Without CUDA the object is handed back unchanged (same object).
    """
    if not torch.cuda.is_available():
        return torch_obj
    return torch_obj.cuda()


def load_data(file_name):
    """Load an array from a ``.npy`` or ``.npz`` file as a float32 tensor.

    Args:
        file_name: path whose extension is ``.npy`` or ``.npz``
            (case-insensitive).

    Returns:
        A float32 ``torch.Tensor`` with the same shape as the stored array.

    Raises:
        ValueError: if the extension is neither npy nor npz, or if an ``.npz``
            archive does not contain exactly one array.
    """
    ext = os.path.splitext(file_name)[1].lower()
    if ext not in (".npz", ".npy"):
        raise ValueError("Data is not npz or npy!")

    loaded = np.load(file_name)
    if isinstance(loaded, np.ndarray):
        array = loaded
    else:
        # .npz -> NpzFile: a lazy, file-backed mapping of arrays. torch.tensor()
        # cannot consume it directly, and it holds the file open until closed.
        with loaded:
            if len(loaded.files) != 1:
                raise ValueError(
                    f"Expected exactly one array in {file_name!r}, "
                    f"found {list(loaded.files)}"
                )
            array = loaded[loaded.files[0]]

    return torch.tensor(array).float()
        

class LearningRateSchedule:
    """Base interface for epoch-indexed learning-rate schedules.

    Concrete schedules override :meth:`get_learning_rate`.
    """

    def get_learning_rate(self, epoch):
        """Return the learning rate to use at ``epoch``.

        Raises:
            NotImplementedError: always; a concrete schedule must override this.
                Raising here (instead of silently returning ``None``) surfaces a
                missing override immediately rather than inside the optimizer.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_learning_rate()"
        )


class StepLearningRateSchedule(LearningRateSchedule):
    """Step decay: scale the initial rate by ``factor`` every ``interval`` epochs.

    ``lr(epoch) = max(initial * factor ** (epoch // interval), minimum)``

    Args:
        initial: learning rate at epoch 0.
        interval: number of epochs between successive decays.
        factor: multiplicative decay applied at each step.
        minimum: floor the rate never drops below (default 5e-6, the value
            that was previously hard-coded).
    """

    # Class-level fallback so instances created without a ``minimum``
    # attribute (e.g. unpickled from an older version) still resolve it.
    minimum = 5.0e-6

    def __init__(self, initial, interval, factor, minimum=5.0e-6):
        self.initial = initial
        self.interval = interval
        self.factor = factor
        self.minimum = minimum

    def get_learning_rate(self, epoch):
        """Return the decayed rate for ``epoch``, clamped below at ``minimum``.

        ``np.maximum`` is kept, so array-valued ``epoch`` broadcasts and the
        result is a NumPy scalar/array, as before.
        """
        decayed = self.initial * (self.factor ** (epoch // self.interval))
        return np.maximum(decayed, self.minimum)


def load_rgb(img_path):
    """Load an image as a flat array of RGB values rescaled to [0, 1].

    Args:
        img_path: path to an image file readable by PIL.

    Returns:
        ``(rgb_values, width, height)``. ``rgb_values`` is a float64 array of
        shape ``(height * width, 3)`` in row-major pixel order: rows top to
        bottom, pixels left to right within a row. ``width`` and ``height`` are
        the image size in pixels.
    """
    # The context manager closes the file handle deterministically; convert()
    # loads the pixel data, so the converted copy stays valid after closing.
    with Image.open(img_path) as raw_image:
        image = raw_image.convert("RGB")  # channel values in [0, 255]
    width, height = image.size
    # Vectorised replacement for calling getpixel((x, y)) over y, then x:
    # np.asarray yields (height, width, 3), and reshape keeps that row-major order.
    rgb_values = np.asarray(image).reshape(-1, 3)
    return rgb_values / 255, width, height