import sys
from pathlib import Path

import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from pytorch_lightning import LightningDataModule

# Add the TorchSpatial "main" folder to the Python module search path so that
# its modules can be imported by bare name below.
# NOTE(review): the path is relative, so it is resolved against the current
# working directory at import time, not against this file's location — confirm
# that every entry point is launched from the repository root.
torchspatial_path = Path("./dependencies/TorchSpatial")
sys.path.append(str(torchspatial_path / "main"))
# NOTE(review): presumably this resolves to TorchSpatial's own `datasets`
# module through the entry appended above. Because the entry is appended (not
# prepended), any other importable module or package named `datasets` found
# earlier on sys.path would be picked up instead — confirm.
from datasets import load_dataset  # adjust import if needed

class Inat2018DirectDataset(Dataset):
    """
    Dataset for direct location and class data from iNat2018.

    Each item is a ``(lon_lat_time_vec, class_index)`` pair, where the feature
    vector is the row of ``{split}_locs`` followed by one rescaled date value.

    Args:
        data: Mapping that provides the array-likes ``{split}_locs``,
            ``{split}_classes`` and ``{split}_dates`` with one entry per
            observation.
        split: Key prefix selecting which arrays of ``data`` are read.
        sample_fraction: Fraction of observations to keep. Values below 1.0
            draw a random subset without replacement (using torch's global
            RNG); values of 1.0 or more keep every observation in its
            original order.

    Raises:
        ValueError: If ``sample_fraction`` is negative.
    """

    def __init__(self, data: dict, split: str = 'train', sample_fraction: float = 0.05):
        # A negative fraction would yield a negative length and fail much
        # later inside ``len()`` with an unrelated-looking error.
        if sample_fraction < 0:
            raise ValueError(
                f"sample_fraction must be non-negative, got {sample_fraction}"
            )

        self.data = data
        self.split = split

        # Same floating dtype the legacy ``torch.Tensor(...)`` constructor
        # would have produced.
        float_dtype = torch.get_default_dtype()
        self.locs = torch.as_tensor(data[f'{split}_locs'], dtype=float_dtype)
        # Build class indices directly as int64: a detour through float32
        # cannot represent integers above 2**24 exactly.
        self.classes = torch.as_tensor(data[f'{split}_classes'], dtype=torch.long)
        self.dates = torch.as_tensor(data[f'{split}_dates'], dtype=float_dtype)
        # Affine rescale of the date value; maps [0, 1] onto [-1, 1]
        # (assumes dates are stored as fractions in [0, 1] — TODO confirm).
        self.times = (self.dates - 0.5) * 2
        self.lon_lat_time_vec = torch.cat([self.locs, self.times.unsqueeze(1)], dim=1)

        total = len(self.locs)
        if sample_fraction < 1.0:
            self.num_samples = int(total * sample_fraction)
            indices = torch.randperm(total)[:self.num_samples]
            # Subset every per-observation tensor with the same indices so
            # that all attributes stay row-aligned with ``self[i]``.
            self.locs = self.locs[indices]
            self.dates = self.dates[indices]
            self.times = self.times[indices]
            self.classes = self.classes[indices]
            self.lon_lat_time_vec = self.lon_lat_time_vec[indices]
        else:
            self.num_samples = total

    def __len__(self):
        """Return the number of observations kept after subsampling."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the ``(lon_lat_time_vec, class_index)`` pair at ``idx``."""
        return self.lon_lat_time_vec[idx], self.classes[idx]

class Inat2018GeoPriorDataset(Dataset):
    """
    Dataset for location, CNN predictions, and class data from iNat2018.

    Each item is a ``(lon_lat_time_vec, cnn_preds, class_index)`` triple,
    where the feature vector is the row of ``{split}_locs`` followed by one
    rescaled date value.

    Args:
        data: Mapping that provides the array-likes ``{split}_locs``,
            ``{split}_classes``, ``{split}_dates`` and ``{split}_preds`` with
            one entry per observation.
        split: Key prefix selecting which arrays of ``data`` are read.
        sample_fraction: Fraction of observations to keep. Values below 1.0
            draw a random subset without replacement (using torch's global
            RNG); values of 1.0 or more keep every observation in its
            original order.

    Raises:
        ValueError: If ``sample_fraction`` is negative.
    """

    def __init__(self, data: dict, split: str = 'train', sample_fraction: float = 0.05):
        # A negative fraction would yield a negative length and fail much
        # later inside ``len()`` with an unrelated-looking error.
        if sample_fraction < 0:
            raise ValueError(
                f"sample_fraction must be non-negative, got {sample_fraction}"
            )

        self.data = data
        self.split = split

        # Same floating dtype the legacy ``torch.Tensor(...)`` constructor
        # would have produced.
        float_dtype = torch.get_default_dtype()
        self.locs = torch.as_tensor(data[f'{split}_locs'], dtype=float_dtype)
        # Build class indices directly as int64: a detour through float32
        # cannot represent integers above 2**24 exactly.
        self.classes = torch.as_tensor(data[f'{split}_classes'], dtype=torch.long)
        # Affine rescale of the date value; maps [0, 1] onto [-1, 1]
        # (assumes dates are stored as fractions in [0, 1] — TODO confirm).
        dates = torch.as_tensor(data[f'{split}_dates'], dtype=float_dtype)
        self.times = (dates - 0.5) * 2
        self.cnn_preds = torch.as_tensor(data[f'{split}_preds'], dtype=float_dtype)
        self.lon_lat_time_vec = torch.cat([self.locs, self.times.unsqueeze(1)], dim=1)

        total = len(self.locs)
        if sample_fraction < 1.0:
            self.num_samples = int(total * sample_fraction)
            indices = torch.randperm(total)[:self.num_samples]
            # Subset every per-observation tensor with the same indices so
            # that all attributes stay row-aligned with ``self[i]``.
            self.locs = self.locs[indices]
            self.times = self.times[indices]
            self.classes = self.classes[indices]
            self.cnn_preds = self.cnn_preds[indices]
            self.lon_lat_time_vec = self.lon_lat_time_vec[indices]
        else:
            self.num_samples = total

    def __len__(self):
        """Return the number of observations kept after subsampling."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the ``(lon_lat_time_vec, cnn_preds, class_index)`` triple at ``idx``."""
        return self.lon_lat_time_vec[idx], self.cnn_preds[idx], self.classes[idx]

class TorchSpatialDataModule(LightningDataModule):
    """
    PyTorch Lightning DataModule for spatial datasets.

    Data is fetched through ``load_dataset`` and exposed as in-memory
    ``TensorDataset`` objects for the train, val and test stages.

    Args:
        dataset: Dataset identifier forwarded as ``params["dataset"]``.
        batch_size: Batch size shared by every DataLoader.
        num_workers: Worker count shared by every DataLoader.
        subset_fraction: Fraction of the train and val data that is kept;
            the test data is always used in full.
    """

    def __init__(
        self,
        dataset: str,
        batch_size: int = 32,
        num_workers: int = 4,
        subset_fraction: float = 0.05
    ):
        super().__init__()
        self.dataset = dataset
        self.subset_fraction = subset_fraction
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _load_split(self, eval_split: str, load_img: bool):
        """Call ``load_dataset`` with the fixed configuration for ``eval_split``."""
        params = {
            "dataset": self.dataset,
            "inat2018_resolution": "standard",
            "load_img": load_img,
            "cnn_pred_type": "full",
            "train_sample_ratio": 1.0,
            "cnn_model": "inception_v3",
            "regress_dataset": [],
            "meta_type": "orig_meta",
        }
        return load_dataset(
            params=params,
            eval_split=eval_split,
            train_remove_invalid=True,
            eval_remove_invalid=True,
            load_cnn_predictions=True,
            load_cnn_features=False,
            load_cnn_features_train=False,
        )

    def _make_loader(self, source, shuffle: bool):
        """Build a DataLoader over ``source`` with the module-wide settings."""
        return DataLoader(
            source,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def setup(self, stage: str = None) -> None:
        """
        Load the data and build all datasets; ``stage`` is not consulted,
        so everything is (re)loaded on each call.
        """
        # NOTE(review): load_img is False for the val load but True for the
        # test load — confirm this asymmetry is intended.
        val_bundle = self._load_split('val', load_img=False)
        test_bundle = self._load_split('test', load_img=True)

        # Train and val are subsampled; construction order is kept as
        # train-then-val because each may draw from torch's global RNG.
        keep = self.subset_fraction
        train_src = Inat2018DirectDataset(val_bundle, split='train', sample_fraction=keep)
        val_src = Inat2018DirectDataset(val_bundle, split='val', sample_fraction=keep)

        # The test bundle is read with split='val'; presumably load_dataset
        # stores the requested eval split under the 'val_*' keys — confirm.
        test_src = Inat2018DirectDataset(test_bundle, split='val', sample_fraction=1.0)
        test_prior_src = Inat2018GeoPriorDataset(test_bundle, split='val', sample_fraction=1.0)

        self.train_dataset = TensorDataset(train_src.lon_lat_time_vec, train_src.classes)
        self.val_dataset = TensorDataset(val_src.lon_lat_time_vec, val_src.classes)
        self.test_dataset = TensorDataset(test_src.lon_lat_time_vec, test_src.classes)
        self.test_cnn_preds_dataset = TensorDataset(
            test_prior_src.lon_lat_time_vec,
            test_prior_src.cnn_preds,
            test_prior_src.classes,
        )

    def train_dataloader(self):
        """Return a shuffled loader over the training data."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled loader over the validation data."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return an unshuffled loader over the test data with CNN predictions."""
        # NOTE(review): this serves test_cnn_preds_dataset (three tensors per
        # batch), identical to test_cnn_preds_dataloader; self.test_dataset is
        # built in setup() but never served — confirm this is intended.
        return self._make_loader(self.test_cnn_preds_dataset, shuffle=False)

    def test_cnn_preds_dataloader(self):
        """Return an unshuffled loader over the test data with CNN predictions."""
        return self._make_loader(self.test_cnn_preds_dataset, shuffle=False)