import torch
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import random
import torchvision.transforms as transforms
from typing import List
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import interpolate
import torchvision

# image_clinical_df: pandas DataFrame, one row per sample, with "image_path" and "age_precise" columns.
class adni_tadpole(Dataset):
    """Map-style dataset yielding a resized image and a rescaled age per row.

    Args:
        image_clinical_df: Table indexed positionally via ``.iloc``; each row
            must provide an ``"image_path"`` entry (a path PIL can open) and an
            ``"age_precise"`` entry (numeric).
        transform: Stored on the instance but currently NOT applied in
            ``__getitem__``; kept for API compatibility.
        image_size: Side length passed as ``size`` to ``interpolate``
            (default 92, the previously hard-coded value).
        age_offset: Subtracted from ``"age_precise"`` (default 54.4).
        age_scale: Divisor applied after the offset (default 38.2).
    """

    def __init__(self, image_clinical_df, transform=None, *, image_size: int = 92,
                 age_offset: float = 54.4, age_scale: float = 38.2):
        self.image_clinical_df = image_clinical_df
        self.transform = transform
        self.image_size = image_size
        self.age_offset = age_offset
        self.age_scale = age_scale

    def __len__(self):
        """Return the number of rows (samples) in the table."""
        return len(self.image_clinical_df)

    def __getitem__(self, idx):
        """Load the image at positional index ``idx`` and return a sample dict.

        Returns:
            ``{'image': tensor, 'age': value}``. The image tensor has a leading
            channel dim of 1 followed by two spatial dims of ``image_size``,
            assuming the file decodes to a 2-D (single-channel) array.
            TODO confirm for the dataset's image modes.
            ``'age'`` is ``(age_precise - age_offset) / age_scale``.
        """
        row = self.image_clinical_df.iloc[idx]
        # Image.open is lazy and keeps the file open until garbage collection;
        # copy the pixels out inside a context manager so the handle is closed
        # immediately (matters when many DataLoader workers open files).
        with Image.open(str(row["image_path"])) as im_pil:
            pixels = np.array(im_pil)
        # Prepend batch and channel dims: interpolate's bilinear mode needs 4-D input.
        im_tensor = torch.tensor(pixels)[None, None, ...]
        im_tensor = interpolate(im_tensor, size=self.image_size, mode='bilinear', align_corners=True)
        im_tensor = im_tensor.squeeze(0)  # drop the batch dim, keep the channel dim
        age = (row["age_precise"] - self.age_offset) / self.age_scale
        return {'image': im_tensor, 'age': age}
    

def get_data_loader(image_clinical_df, batch_size, split_set: str = 'train', which_label: str = "age", transform= None):
    assert split_set in ["train", "val", "test"]
    default_kwargs = {"shuffle": False , "num_workers": 4, "drop_last": True, "batch_size": batch_size}
    dataset = adni_tadpole(image_clinical_df, transform)
        
    if split_set != "test":
        val_ratio = 0.2
        split = torch.utils.data.random_split(dataset,
                                              [len(dataset)-int(len(dataset) * val_ratio), int(len(dataset) * val_ratio)],
                                              generator=torch.Generator().manual_seed(42))
        dataset = split[0] if split_set == "train" else split[1]
    
    
    loader = torch.utils.data.DataLoader(dataset, **default_kwargs)
    while True:
        yield from loader
        
  

        