from argparse import ArgumentParser
from pathlib import Path
import random

import pytorch_lightning as pl
import torch
import numpy as np
from PIL import Image
from pl_modules.utils import load_config_from_yaml

DATA_CONFIG_PATH = 'configs/data_configs/dataset_config.yaml'


def load_ids_from_txt(path):
    """Read image ids from a text file that holds one id per line.

    Surrounding whitespace (including the line terminator) is stripped from
    every line and blank lines are skipped, so a stray space or an empty line
    in the split file does not turn into a bogus id.

    Args:
        path: Path of the text file to read.

    Returns:
        List of id strings, in file order.
    """
    with open(path) as f:
        stripped = (line.strip() for line in f)
        return [line for line in stripped if line]

def get_fname(full_path):
    """Return a short name for a file: ``<sep><parent dir name><sep><stem>``.

    For ``/data/celeba/000001.jpg`` this yields ``/celeba/000001`` (using the
    platform path separator). The name is assembled from the path components
    rather than by substring replacement, so directory names that repeat, a
    grandparent of ``.`` or ``/``, or a suffix that also occurs inside a
    directory name no longer corrupt the result.

    Args:
        full_path: ``str`` or ``Path`` pointing at the file.

    Returns:
        The short name as a string; only the final suffix is removed.
    """
    import os  # local import so this helper does not depend on file-level imports

    full_path = Path(full_path)
    return os.sep + os.path.join(full_path.parent.name, full_path.stem)
    
class CelebaDataset(torch.utils.data.Dataset):
    """CelebA images restricted to the ids listed for one split.

    The ids are read from the text file named by
    ``data_config['celeba256'][split + '_split']`` in the YAML file at
    ``DATA_CONFIG_PATH``. Every id must exist in ``root`` as
    ``<id><suffix>``, where ``suffix`` is the extension of the first JPEG
    file encountered while scanning ``root``.
    """

    # Extensions accepted as JPEG when probing ``root`` for the file suffix.
    _JPEG_SUFFIXES = ('.JPG', '.JPEG', '.jpg', '.jpeg')

    def __init__(
        self,
        root,
        split,
        transform,
        sample_rate=None,
        train_val_seed=0
    ):
        """Collect the image paths belonging to ``split``.

        Args:
            root: Directory that directly contains the image files.
            split: Split name; ``split + '_split'`` is the config key looked up.
            transform: Callable invoked as ``transform(image, fname)`` in
                ``__getitem__``. May be ``None`` here, but indexing then raises.
            sample_rate: Fraction of the split to keep. ``None`` means 1.0.
            train_val_seed: Accepted for interface compatibility; this class
                does not read it.

        Raises:
            ValueError: If ids are requested but ``root`` holds no JPEG file,
                or if the image for any listed id is missing.
        """
        self.transform = transform
        self.examples = []

        # set default sampling mode if none given
        if sample_rate is None:
            sample_rate = 1.0

        root_dir = Path(root)
        suffix = self._detect_suffix(root_dir)

        data_config = load_config_from_yaml(DATA_CONFIG_PATH)
        img_ids = load_ids_from_txt(data_config['celeba256'][split+'_split'])

        # Without a detected suffix no file name can be built; fail with a
        # clear message instead of an unbound-variable error.
        if img_ids and suffix is None:
            raise ValueError("No JPEG images found in {}.".format(str(root)))

        for i in img_ids:
            file = root_dir/(i + suffix)
            if not file.is_file():
                raise ValueError("CelebA image {} is missing.".format(i))
            self.examples.append(file)

        # subsample if desired
        # NOTE: the shuffle draws from the module-level `random` state, so the
        # chosen subset is only reproducible if the caller seeds `random`.
        if sample_rate < 1.0:
            random.shuffle(self.examples)
            num_examples = round(len(self.examples) * sample_rate)
            self.examples = self.examples[:num_examples]

        print('{} images loaded from {} as {} split.'.format(len(self.examples), str(root), split))

    @staticmethod
    def _detect_suffix(root):
        """Return the suffix of the first JPEG found in ``root``, or ``None``."""
        # Iterate lazily so the scan stops at the first match.
        for file in Path(root).iterdir():
            if file.suffix in CelebaDataset._JPEG_SUFFIXES:
                return file.suffix
        return None

    def __len__(self):
        """Return the number of images kept after optional subsampling."""
        return len(self.examples)

    def __getitem__(self, i: int):
        """Load image ``i`` as RGB and return ``transform(image, fname)``.

        Raises:
            ValueError: If no transform was supplied at construction time.
        """
        # Check the configuration before doing any file I/O.
        if self.transform is None:
            raise ValueError('Must define forward model and pass in DataTransform.')

        file = self.examples[i]
        fname = get_fname(file)
        # convert() returns a new image, so the source file can be closed
        # right away. Will load grayscale images as RGB!
        with Image.open(file) as img:
            im = img.convert("RGB")

        return self.transform(im, fname)

    def get_filenames(self):
        """Return the short name of every example, in dataset order."""
        return [get_fname(file) for file in self.examples]
    
class FFHQDataset(torch.utils.data.Dataset):
    """FFHQ images restricted to the ids listed for one split.

    The ids are read from the text file named by
    ``data_config['ffhq'][split + '_split']`` in the YAML file at
    ``DATA_CONFIG_PATH``. The image for id ``n`` is expected at
    ``<root>/<n // 1000, zero-padded to 5>/img<n, zero-padded to 8>.png``.
    """

    def __init__(
        self,
        root,
        split,
        transform,
        sample_rate=None,
    ):
        """Collect the image paths belonging to ``split``.

        Args:
            root: Directory containing the per-thousand image sub-folders.
            split: Split name; ``split + '_split'`` is the config key looked up.
            transform: Callable invoked as ``transform(image, fname)`` in
                ``__getitem__``. May be ``None`` here, but indexing then raises.
            sample_rate: Fraction of the split to keep. ``None`` means 1.0.

        Raises:
            ValueError: If the image for any listed id is missing (also raised
                by ``int()`` when an id is not an integer string).
        """
        self.transform = transform
        self.examples = []

        # set default sampling mode if none given
        if sample_rate is None:
            sample_rate = 1.0

        suffix = '.png'
        root_dir = Path(root)

        data_config = load_config_from_yaml(DATA_CONFIG_PATH)
        img_ids = load_ids_from_txt(data_config['ffhq'][split+'_split'])

        for i in img_ids:
            # Images are grouped in folders of 1000 consecutive ids.
            folder = str(int(i) // 1000).rjust(5, '0')
            file = root_dir/folder/('img' + i.rjust(8, '0') + suffix)
            if not file.is_file():
                raise ValueError("FFHQ image {} is missing.".format(i))
            self.examples.append(file)

        # subsample if desired
        # NOTE: the shuffle draws from the module-level `random` state, so the
        # chosen subset is only reproducible if the caller seeds `random`.
        if sample_rate < 1.0:
            random.shuffle(self.examples)
            num_examples = round(len(self.examples) * sample_rate)
            self.examples = self.examples[:num_examples]

        print('{} images loaded from {} as {} split.'.format(len(self.examples), str(root), split))

    def __len__(self):
        """Return the number of images kept after optional subsampling."""
        return len(self.examples)

    def __getitem__(self, i: int):
        """Load image ``i`` as RGB and return ``transform(image, fname)``.

        Raises:
            ValueError: If no transform was supplied at construction time.
        """
        # Check the configuration before doing any file I/O.
        if self.transform is None:
            raise ValueError('Must define forward model and pass in DataTransform.')

        file = self.examples[i]
        # Use the shared helper so both dataset classes name files identically.
        fname = get_fname(file)
        # convert() returns a new image, so the source file can be closed
        # right away. Will load grayscale images as RGB!
        with Image.open(file) as img:
            im = img.convert("RGB")

        return self.transform(im, fname)

    def get_filenames(self):
        """Return the short name of every example, in dataset order."""
        return [get_fname(file) for file in self.examples]