from pathlib import Path

import pandas as pd
import numpy as np

from torch.utils.data import Dataset, random_split
from .image_utils import IndexedDataset


class CaliforniaHousing(Dataset):
    """California housing regression data, split into two domains by location.

    ``<cali_path>/housing.csv`` is loaded and every row containing a missing
    value is dropped. Rows whose ``ocean_proximity`` is ``'NEAR BAY'`` make up
    the "near bay" (target) domain; all remaining rows make up the "far bay"
    (source) domain. The ``ocean_proximity`` column is then discarded.

    Args:
        cali_path: Directory that contains ``housing.csv``.
        domain: ``'far_bay'`` selects the non-NEAR-BAY rows; any other value
            selects the NEAR BAY rows.
        standardize: When true, every remaining column (the label column
            ``median_house_value`` included) of BOTH domains is z-scored using
            the mean and population std (``ddof=0``) of the far-bay rows.

    Attributes:
        data: Feature matrix (all columns except the label), cast to float32.
        labels: 1-D array of ``median_house_value`` values (standardized when
            ``standardize`` is true); dtype is whatever pandas inferred for
            that column (presumably float64 -- confirm against the CSV).
    """

    def __init__(self, cali_path: str, domain: str, standardize: bool = True):
        domain_col = "ocean_proximity"
        label_col = "median_house_value"

        csv_file = Path(cali_path, "housing.csv")
        frame = pd.read_csv(str(csv_file)).dropna()

        # dropna() above guarantees the domain column has no NaNs, so the
        # mask and its negation partition the rows exactly.
        is_near_bay = frame[domain_col] == 'NEAR BAY'
        far_rows = frame.loc[~is_near_bay].drop(columns=[domain_col])
        near_rows = frame.loc[is_near_bay].drop(columns=[domain_col])

        if standardize:
            # Statistics come from the far-bay (source) split only, so the
            # near-bay split is expressed in source units. The small epsilon
            # keeps constant columns from dividing by zero.
            centre = far_rows.mean()
            scale = far_rows.std(ddof=0) + 1e-8
            far_rows = (far_rows - centre) / scale
            near_rows = (near_rows - centre) / scale

        if domain == 'far_bay':
            chosen = far_rows
        else:
            chosen = near_rows

        self.labels = chosen[label_col].to_numpy()
        feature_frame = chosen.drop(columns=[label_col])
        self.data = feature_frame.to_numpy().astype(np.float32)

    def __len__(self) -> int:
        """Return the number of rows in the selected domain."""
        return len(self.labels)

    def __getitem__(self, i: int) -> tuple[np.ndarray, float]:
        """Return ``(features, label)`` for row ``i``."""
        features = self.data[i]
        label = self.labels[i]
        return features, label



def get_california_housing(fetch_dset: dict, domain: str, cali_path: str) -> tuple[Dataset, Dataset, Dataset, Dataset]:
    """Build (optionally split) train/val datasets for one housing domain.

    Args:
        fetch_dset: Config dict. Only ``"tr_val_split"`` is read (default
            ``1000``). A value strictly between 0 and 1 is the fraction of
            rows randomly assigned to the training split, with the remainder
            going to validation. Any other value disables splitting, and the
            full dataset is used for both train and val. ``"aug_type"`` is
            ignored because no augmentation is applied to this tabular data.
        domain: Forwarded to ``CaliforniaHousing``; ``'far_bay'`` selects the
            non-NEAR-BAY rows, anything else selects the NEAR BAY rows.
        cali_path: Directory containing ``housing.csv``.

    Returns:
        ``(train, train_aug, val, val_aug)``, each wrapped in
        ``IndexedDataset``. Since there is no augmentation, ``train_aug``
        wraps the same split as ``train`` and ``val_aug`` the same as ``val``.
    """
    tr_val_split = fetch_dset.get("tr_val_split", 1000)

    ds = CaliforniaHousing(cali_path, domain)

    # tr_val_split is used as a fraction of len(ds). Values >= 1 would make
    # train_num exceed len(ds) (and val_num negative), so only (0, 1) splits.
    if 0 < tr_val_split < 1:
        train_num = int(len(ds) * tr_val_split)
        val_num = len(ds) - train_num
        train_ds, val_ds = random_split(ds, [train_num, val_num])
    else:
        train_ds = ds
        val_ds = ds

    # No augmentation: the "aug" views reuse the un-augmented splits.
    return (
        IndexedDataset(train_ds),
        IndexedDataset(train_ds),
        IndexedDataset(val_ds),
        IndexedDataset(val_ds),
    )







