from torch.utils import data
from typing import Tuple, Union, List
import numpy as np
import json
import math
import cv2
import h5py
import os
import pickle
from einops import rearrange


class RamImage:
    """An image file that is either held in memory as encoded bytes or read lazily from disk."""

    def __init__(self, path, load):
        """
        Args:
            path: path of an image file readable by OpenCV.
            load: if true, read the (still encoded) file bytes into memory now;
                otherwise only remember ``path`` and read the file in :meth:`to_numpy`.
        """
        # Always keep the path so error messages (and debugging) can refer to it.
        self.path = path
        if load:
            # The context manager closes the file even if the read fails.
            with open(path, 'rb') as fd:
                self.img_raw = np.frombuffer(fd.read(), np.uint8)
        else:
            self.img_raw = None

    def to_numpy(self):
        """Decode the image into a color array via OpenCV (``cv2.IMREAD_COLOR``).

        Raises:
            ValueError: if OpenCV cannot read or decode the image (OpenCV itself
                signals this by returning ``None``).
        """
        if self.img_raw is not None:
            image = cv2.imdecode(self.img_raw, cv2.IMREAD_COLOR)
        else:
            image = cv2.imread(self.path)

        if image is None:
            # getattr: instances pickled by older code lack ``path`` in the in-memory mode.
            raise ValueError(f"could not read/decode image {getattr(self, 'path', '<in-memory>')}")
        return image

class AsteroidsSample(data.Dataset):
    """One clip of the asteroids dataset: paired RGB and depth frames.

    Frames are discovered in ``<root_path>/<data_path>/<size[0]>x<size[1]>`` as
    files named ``rgb*.jpg`` and ``depth*.jpg`` and paired by sorted filename order.
    """

    def __init__(self, root_path: str, data_path: str, size: Tuple[int, int], length: int):
        """
        Args:
            root_path: directory containing the clip directories.
            data_path: this clip's directory name under ``root_path``.
            size: frame size; the returned arrays are shaped ``(..., size[1], size[0])``,
                i.e. ``size`` is used as (width, height).
            length: number of consecutive frames returned by :meth:`get_data`.
        """
        frame_dir = os.path.join(root_path, data_path, f'{size[0]}x{size[1]}')

        self.size = size
        self.length = length

        # All files share the same directory prefix, so sorting bare names gives the
        # same order as sorting the full paths.
        names = sorted(os.listdir(frame_dir))
        # Only the 64-wide variant keeps the encoded bytes in RAM; others load lazily.
        in_ram = size[0] == 64

        self.rgb = [
            RamImage(os.path.join(frame_dir, name), in_ram)
            for name in names
            if name.startswith("rgb") and name.endswith(".jpg")
        ]
        self.depth = [
            RamImage(os.path.join(frame_dir, name), in_ram)
            for name in names
            if name.startswith("depth") and name.endswith(".jpg")
        ]

    def get_data(self):
        """Return a uniformly random window of ``self.length`` consecutive frames.

        Returns:
            ``(rgb, depth)`` float32 arrays shaped ``(length, 3, size[1], size[0])`` and
            ``(length, 1, size[1], size[0])``, both in ``[0, 1]``. ``depth`` is the channel
            mean of the vertically flipped depth image, inverted as ``1 - depth``.

        Raises:
            ValueError: if the clip has fewer than ``self.length`` paired frames.
        """
        # Use the frames actually present rather than assuming a fixed clip length.
        num_frames = min(len(self.rgb), len(self.depth))
        if self.length > num_frames:
            raise ValueError(
                f'requested {self.length} frames but the clip only has {num_frames}'
            )

        height, width = self.size[1], self.size[0]
        rgb = np.zeros((self.length, 3, height, width), dtype=np.float32)
        depth = np.zeros((self.length, 1, height, width), dtype=np.float32)

        start = np.random.randint(num_frames - self.length + 1)
        for offset in range(self.length):
            frame = start + offset
            color = self.rgb[frame].to_numpy()
            # Flip rows (axis 0) of the depth image, presumably to align it with the RGB frame.
            dimg = np.flip(self.depth[frame].to_numpy(), axis=0)
            rgb[offset] = color.transpose(2, 0, 1).astype(np.float32) / 255.0
            depth[offset] = np.mean(
                dimg.transpose(2, 0, 1).astype(np.float32) / 255.0, axis=0, keepdims=True
            )

        return rgb, 1.0 - depth


class AsteroidsDataset(data.Dataset):
    """Asteroids video dataset: RGB/depth clips composited over a background image.

    Clips are the sub-directories of ``<root_path>/data/data/video/<dataset_name>``
    whose names start with ``"0"``. The parsed clip list is cached in a pickle file
    next to the data. The first 90% of clips form the train split and the following
    10% the evaluation split.
    """

    def save(self):
        """Write the parsed sample list to the cache file ``self.file``."""
        with open(self.file, "wb") as outfile:
            pickle.dump(self.samples, outfile)

    def load(self):
        """Read the sample list from the cache file ``self.file``.

        NOTE: ``pickle.load`` can execute arbitrary code; only load cache files this
        class wrote itself, never files from an untrusted source.
        """
        with open(self.file, "rb") as infile:
            self.samples = pickle.load(infile)

    def __init__(self, root_path: str, dataset_name: str, type: str, size: Tuple[int, int], length):
        """
        Args:
            root_path: project root; data is read from
                ``<root_path>/data/data/video/<dataset_name>``.
            dataset_name: name of the dataset directory.
            type: ``"train"`` selects the train split; any other value the evaluation split.
            size: frame size; selects ``background<size[0]>x<size[1]>.jpg``, the cache
                file and the per-clip frame subfolders.
            length: number of consecutive frames per returned clip.

        Raises:
            FileNotFoundError: if the background image cannot be read or the
                selected split is empty.
        """
        data_path = os.path.join(root_path, f'data/data/video/{dataset_name}')
        self.data_path = data_path
        self.file = os.path.join(data_path, f'dataset-{size[0]}x{size[1]}.pickle')
        self.train = (type == "train")

        self.background = self._load_background(data_path, size)

        if os.path.exists(self.file):
            self.load()
            # The cache file name encodes only the size, so apply the requested
            # clip length instead of whatever length the cache was built with.
            for sample in self.samples:
                sample.length = length
        else:
            self.samples = self._build_samples(data_path, size, length)
            self.save()

        self.length = len(self.samples)

        print(f"AsteroidsDataset: {self.length}")

        if len(self) == 0:
            raise FileNotFoundError(f'Found no dataset at {self.data_path}')

    @staticmethod
    def _load_background(data_path: str, size: Tuple[int, int]) -> np.ndarray:
        """Load the background as a ``(1, C, H, W)`` float32 array with values in ``[0, 0.5]``."""
        path = os.path.join(data_path, f'background{size[0]}x{size[1]}.jpg')
        image = cv2.imread(path)
        # cv2.imread reports a missing/unreadable file by returning None.
        if image is None:
            raise FileNotFoundError(f'Could not read background image {path}')
        image = (image / 255.0).astype(np.float32)
        # Halve the background intensity.
        return rearrange(image, 'h w c -> 1 c h w') * 0.5

    @staticmethod
    def _build_samples(data_path: str, size: Tuple[int, int], length) -> list:
        """Parse every clip directory (name starting with ``"0"``) under ``data_path``."""
        # Sort so the train/eval split does not depend on filesystem listing order.
        clip_dirs = sorted(
            name for name in next(os.walk(data_path))[1] if name.startswith("0")
        )
        num_clips = len(clip_dirs)

        samples = []
        for i, clip_dir in enumerate(clip_dirs):
            samples.append(AsteroidsSample(data_path, clip_dir, size, length))
            print(f"Loading ASTEROIDS [{(i + 1) * 100 / num_clips:.2f}]", flush=True)
        return samples

    def __len__(self):
        """Number of clips in the selected split (90% train / 10% evaluation)."""
        if self.train:
            return int(self.length * 0.9)

        return int(self.length * 0.1)

    def __getitem__(self, index: int):
        """Return ``(rgb, depth, background)`` for the ``index``-th clip of this split.

        The RGB frames are blended over the background with the element-wise soft
        mask ``sigmoid(10 * rgb - 5)``. Values near 0 therefore show the background,
        and values near 1 keep the rendered frame.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
        """
        # Bounds check keeps train indices from reaching into the evaluation split.
        if not 0 <= index < len(self):
            raise IndexError(f'index {index} out of range for split of size {len(self)}')

        if not self.train:
            # Evaluation clips follow the training clips in self.samples.
            index += int(self.length * 0.9)

        rgb, depth = self.samples[index].get_data()

        mask = 1 / (1 + np.exp(-rgb * 10 + 5))
        rgb = self.background * (1 - mask) + rgb * mask

        return rgb, depth, self.background
