import pandas as pd
from torch.utils.data import Dataset, DataLoader
import os
import torch
import astropy.io.fits as fits
import pdb

class RedShift(Dataset):
    """Dataset of FITS images paired with scalar redshift labels.

    The annotation CSV is shuffled and reduced to ``SUBSET_SIZE`` rows using
    ``seed``; the first ``TRAIN_SIZE`` rows of that subset form the training
    split and the remaining rows form the validation split. Because every
    random draw is seeded, a train instance and a validation instance built
    with the same ``seed`` see the same subset and never share a row.
    """

    # Default locations of the annotation table and of the FITS images.
    DEFAULT_CSV_FILE = "./Astro_DataSet/redshift/sdss_dr12_with_filenames_train_test.csv"
    DEFAULT_ROOT_DIR = "./Astro_DataSet/redshift/dr_12_redshift/"

    # Number of rows kept from the CSV, and how many of those are training rows.
    SUBSET_SIZE = 41650
    TRAIN_SIZE = 10100

    # Spatial window kept from every image (applied to both height and width).
    CROP = slice(16, 240)

    # Lower bound for the per-channel std so constant channels do not yield NaN.
    _STD_FLOOR = 1e-8

    def __init__(self, transform=None, val_split=None, train=True, seed=0,
                 csv_file=None, root_dir=None):
        """
        Args:
            transform (callable, optional): Optional transform applied to the
                normalized, cropped image tensor in ``__getitem__``.
            val_split (float, optional): Accepted for interface compatibility
                but currently unused; the split sizes are fixed by
                ``SUBSET_SIZE`` and ``TRAIN_SIZE``.
            train (bool): If True, loads training data; otherwise, validation data.
            seed (int): Seed for reproducibility of the shuffle and subset.
            csv_file (string, optional): Path to the csv file with annotations
                (needs ``filename`` and ``redshift`` columns). Defaults to
                ``DEFAULT_CSV_FILE``.
            root_dir (string, optional): Directory with all the images.
                Defaults to ``DEFAULT_ROOT_DIR``.

        Raises:
            ValueError: Raised by pandas if the CSV holds fewer than
                ``SUBSET_SIZE`` rows.
        """
        if csv_file is None:
            csv_file = self.DEFAULT_CSV_FILE
        if root_dir is None:
            root_dir = self.DEFAULT_ROOT_DIR

        self.root_dir = root_dir
        self.transform = transform
        self.train = train
        self.seed = seed

        table = pd.read_csv(csv_file)
        shuffled = table.sample(frac=1, random_state=self.seed).reset_index(drop=True)
        # The subset draw must be seeded as well: an unseeded draw gives each
        # instance a different subset/order, so separately constructed train
        # and validation datasets could overlap.
        subset = shuffled.sample(n=self.SUBSET_SIZE, random_state=self.seed)

        if self.train:
            split = subset.iloc[:self.TRAIN_SIZE]
        else:
            split = subset.iloc[self.TRAIN_SIZE:]

        # reset_index() (without drop) re-labels rows 0..n-1 for the .at
        # lookups below and keeps the previous labels in an "index" column.
        self.df = split.reset_index()

    def __len__(self):
        """Return the number of rows in this split."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx (int): Row position within this split.

        Returns:
            tuple: ``(image, label)`` where ``image`` is a float32 tensor that
            was standardized per channel and then cropped with ``CROP``, and
            ``label`` is a float32 tensor of shape ``(1,)`` holding the redshift.
        """
        img_path = os.path.join(self.root_dir, self.df.at[idx, "filename"])

        # Close the file deterministically. astype() copies into a native-endian
        # float32 array, so the data stays valid after the file is closed and
        # no byteswap()/newbyteorder() round trip is needed (ndarray.newbyteorder
        # was removed in NumPy 2.0).
        with fits.open(img_path) as hdul:
            pixels = hdul[0].data.astype("float32")

        label = self.df.at[idx, "redshift"]

        # c * h * w
        image = torch.tensor(pixels, dtype=torch.float32)

        # Per-channel standardization, computed over the full (uncropped) image.
        channel_mean = image.mean(dim=(1, 2), keepdim=True)
        channel_std = image.std(dim=(1, 2), keepdim=True).clamp_min(self._STD_FLOOR)
        image = (image - channel_mean) / channel_std

        # Keeps rows/columns 16..239, i.e. a 224 x 224 window when the input is
        # at least 240 pixels on each side.
        image = image[:, self.CROP, self.CROP]

        # The transform used to be stored but never applied.
        if self.transform is not None:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.float32).unsqueeze(0)


def redshift(val_split=0.2, seed=0):
    """Build the training and validation ``RedShift`` datasets.

    Args:
        val_split (float): Forwarded unchanged to both ``RedShift`` instances.
        seed (int): Forwarded unchanged to both ``RedShift`` instances.

    Returns:
        tuple: ``(train_dataset, val_dataset, output_size)`` where
        ``output_size`` is always 1.
    """
    shared_kwargs = {"val_split": val_split, "seed": seed}

    # Training split is constructed first, then the validation split.
    training = RedShift(train=True, **shared_kwargs)
    validation = RedShift(train=False, **shared_kwargs)

    n_outputs = 1
    return training, validation, n_outputs