import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from pytorch_lightning import LightningDataModule

from pathlib import Path
import sys

# Make TorchSpatial's `main` directory importable so that its `datasets` module
# (imported below for `load_dataset`) can be found. The path is relative to the
# current working directory, and it is appended every time this module is imported.
# NOTE(review): because the entry goes at the END of sys.path, an installed
# top-level package that is also named `datasets` (e.g. Hugging Face `datasets`)
# would be found first — confirm the intended module is the one imported.
torchspatial_path = Path(".") / "location-embeddings-3d" / "dependencies" / "TorchSpatial"
sys.path.append(str(torchspatial_path.joinpath("main")))

from datasets import load_dataset

class Inat2018DirectDataset(Dataset):
    """(location/time, class) dataset for one split of a TorchSpatial load_dataset result.

    Samples whose class id is ``>= variable_cut_off`` are dropped. The per-sample
    predictions (``vision_logits``) are truncated to their first
    ``variable_cut_off`` columns, so everything is restricted to classes
    ``0 .. variable_cut_off - 1``. Items are ``(lon_lat_time_vec[i], classes[i])``.
    """

    def __init__(self, data, split='train',
                 sample_fraction=1.0,
                 variable_cut_off=8142,  # subset prediction classes
                 ):
        """
        Args:
          data (dict): The output of load_dataset. Must provide the keys
            ``{split}_locs``, ``{split}_inds``, ``{split}_classes``,
            ``{split}_dates`` and ``{split}_preds``.
          split (str): The dataset split to use ('train' or 'val').
          sample_fraction (float): If < 1.0, keep a random subset containing
            ``int(n * sample_fraction)`` of the n class-filtered samples
            (drawn with the global torch RNG).
          variable_cut_off (int): Keep only samples whose class id is below
            this value, and only this many prediction columns.
        """
        self.data = data
        self.split = split
        self.variable_cut_off = variable_cut_off
        self.locs = torch.Tensor(data[f'{split}_locs'])
        self.valid_indices = torch.Tensor(data[f'{split}_inds']).long()
        self.classes = torch.Tensor(data[f'{split}_classes']).long()
        # Boolean mask over the unfiltered samples; reused for every per-sample tensor.
        self.selected_variable_indices = self.classes < variable_cut_off
        self.classes = self.classes[self.selected_variable_indices]
        self.dates = torch.Tensor(data[f'{split}_dates'])  # assume in range [0, 1]
        # Shift to [-0.5, 0.5], then scale to [-1, 1].
        self.times = (self.dates - 0.5) * 2
        self.vision_logits = torch.Tensor(data[f'{split}_preds'])
        if split == 'train':
            # Train prediction rows are gathered through `{split}_inds` before
            # filtering (presumably mapping kept samples to prediction rows — confirm).
            self.vision_logits = self.vision_logits[self.valid_indices, :]

        self.vision_logits = self.vision_logits[
            self.selected_variable_indices, :self.variable_cut_off
        ]

        self.lon_lat_time_vec = torch.concat(
            [self.locs, self.times.unsqueeze(1)], dim=1
        )[self.selected_variable_indices]

        # Size everything from the class-filtered tensors. `self.locs` still holds
        # every sample, so using its length would overcount __len__ and draw subset
        # indices past the end of the filtered tensors.
        num_filtered = len(self.classes)
        if sample_fraction < 1.0:
            self.num_samples = int(num_filtered * sample_fraction)
            indices = torch.randperm(num_filtered)[:self.num_samples]
            self.lon_lat_time_vec = self.lon_lat_time_vec[indices]
            self.classes = self.classes[indices]
            self.vision_logits = self.vision_logits[indices]
        else:
            self.num_samples = num_filtered

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """
        Returns a single data point as a tuple ``(lon_lat_time_vec, class)``.
        """
        return self.lon_lat_time_vec[idx], self.classes[idx]

class Inat2018VisionLocationDataset(Dataset):
    """(location/time, predictions, class) dataset for one split of a TorchSpatial load_dataset result.

    Items are ``(lon_lat_time_vec[i], vision_logits[i], classes[i])``, where
    ``vision_logits`` holds the rows of ``{split}_preds``.
    """

    def __init__(self, data, split='train', sample_fraction=0.05):
        """
        Args:
          data (dict): The output of load_dataset. Must provide the keys
            ``{split}_locs``, ``{split}_inds``, ``{split}_classes``,
            ``{split}_dates`` and ``{split}_preds``.
          split (str): The dataset split to use ('train' or 'val').
          sample_fraction (float): If < 1.0, keep a random subset containing
            ``int(n * sample_fraction)`` of the n samples (drawn with the
            global torch RNG). Defaults to 0.05.
        """
        self.data = data
        self.split = split
        self.locs = torch.Tensor(data[f'{split}_locs'])
        self.valid_indices = torch.Tensor(data[f'{split}_inds']).long()
        self.classes = torch.Tensor(data[f'{split}_classes']).long()
        self.dates = torch.Tensor(data[f'{split}_dates'])  # assume in range [0, 1]
        # Shift to [-0.5, 0.5], then scale to [-1, 1].
        self.times = (self.dates - 0.5) * 2

        self.vision_logits = torch.Tensor(data[f'{split}_preds'])
        # (The former debug prints of valid_indices.max()/min() were removed:
        # reducing an empty tensor raises, which crashed on an empty split.)
        if split == 'train':
            # Train prediction rows are gathered through `{split}_inds`
            # (presumably mapping kept samples to prediction rows — confirm).
            self.vision_logits = self.vision_logits[self.valid_indices, :]

        self.lon_lat_time_vec = torch.concat(
            [self.locs, self.times.unsqueeze(1)], dim=1
        )

        # Select a random subset of the data if requested.
        num_total = len(self.locs)
        if sample_fraction < 1.0:
            self.num_samples = int(num_total * sample_fraction)
            indices = torch.randperm(num_total)[:self.num_samples]
            self.lon_lat_time_vec = self.lon_lat_time_vec[indices]
            self.classes = self.classes[indices]
            self.vision_logits = self.vision_logits[indices]
        else:
            self.num_samples = num_total

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """
        Returns a single data point as a tuple ``(lon_lat_time_vec, vision_logits, class)``.
        """
        return self.lon_lat_time_vec[idx], self.vision_logits[idx], self.classes[idx]

class TorchSpatialDataModule(LightningDataModule):
    """Lightning data module for location/time -> class training on TorchSpatial data.

    The loader's 'train' split is divided into train/val subsets. Its 'val' split
    is used as the test set; test items also carry the per-sample predictions,
    which are additionally exposed as ``test_vision_logits``.
    """

    def __init__(self, dataset, batch_size=32, num_workers=4, subset_fraction=0.05,
                 variable_cut_off=8142, train_frac=0.9, seed=None):
        """
        Args:
            dataset: Value passed to load_dataset as ``params["dataset"]``.
            batch_size (int): Batch size for the DataLoaders.
            num_workers (int): Number of workers for the DataLoaders.
            subset_fraction (float): Fraction of each split to sample
                (passed to the datasets as ``sample_fraction``).
            variable_cut_off (int): Keep only samples with class id below this
                value (passed to Inat2018DirectDataset).
            train_frac (float): Fraction of the sampled train split used for
                training; the rest becomes the validation set.
            seed (int | None): If given, seeds the train/val random_split so the
                split is identical across setup() calls. It does not control the
                subsampling inside the datasets, which uses the global torch RNG.
        """

        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.subset_fraction = subset_fraction
        self.variable_cut_off = variable_cut_off
        self.train_frac = train_frac
        self.seed = seed

    def setup(self, stage=None):
        """Load the data and build the train/val/test datasets.

        ``stage`` is accepted for Lightning compatibility but ignored, so every
        call reloads the data and rebuilds all splits.
        """
        return_op_val = load_dataset(params = {"dataset" : self.dataset,
                "inat2018_resolution": "standard",
                "load_img" : False,
                "cnn_pred_type" : "full",
                "train_sample_ratio": 1.0,
                "cnn_model": "inception_v3",
                "regress_dataset" : [],
                "meta_type": "orig_meta"
                },
                eval_split='val',
                train_remove_invalid=True,
                eval_remove_invalid=True,
                load_cnn_predictions=True,
                load_cnn_features=False,
                load_cnn_features_train=False,
                )
        train_source = Inat2018DirectDataset(return_op_val, split='train',
            sample_fraction=self.subset_fraction,
            variable_cut_off=self.variable_cut_off,
        )
        test_source = Inat2018DirectDataset(return_op_val, split='val',
            sample_fraction=self.subset_fraction,
            variable_cut_off=self.variable_cut_off,
        )
        self.test_vision_logits = test_source.vision_logits

        self.train_val_dataset = TensorDataset(
            train_source.lon_lat_time_vec,
            train_source.classes,
        )

        num_train_samples = int(len(self.train_val_dataset) * self.train_frac)
        num_val_samples = len(self.train_val_dataset) - num_train_samples

        # Only pass a generator when seeded, so the unseeded path keeps
        # random_split's default (global) generator.
        split_kwargs = {}
        if self.seed is not None:
            split_kwargs["generator"] = torch.Generator().manual_seed(self.seed)
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            self.train_val_dataset, [num_train_samples, num_val_samples], **split_kwargs
        )

        self.test_dataset = TensorDataset(
            test_source.lon_lat_time_vec,
            test_source.vision_logits,
            test_source.classes,
        )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)


class TorchSpatialDataModule_VisionLocation(LightningDataModule):
    """Lightning data module yielding (location/time, predictions, class) batches from TorchSpatial data.

    The loader's 'train' split is divided into train/val subsets and its 'val'
    split is used as the test set.
    """

    def __init__(self, dataset, batch_size=32, num_workers=4, subset_fraction=0.05,
                 train_frac=0.95, seed=None):
        """
        Args:
            dataset: Value passed to load_dataset as ``params["dataset"]``.
            batch_size (int): Batch size for the DataLoaders.
            num_workers (int): Number of workers for the DataLoaders.
            subset_fraction (float): Fraction of each split to sample
                (passed to the datasets as ``sample_fraction``).
            train_frac (float): Fraction of the sampled train split used for
                training; the rest becomes the validation set.
            seed (int | None): If given, seeds the train/val random_split so the
                split is identical across setup() calls. It does not control the
                subsampling inside the datasets, which uses the global torch RNG.
        """

        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.subset_fraction = subset_fraction
        self.train_frac = train_frac
        self.seed = seed

    def setup(self, stage=None):
        """Load the data and build the train/val/test datasets.

        ``stage`` is accepted for Lightning compatibility but ignored, so every
        call reloads the data and rebuilds all splits.
        """
        return_op_val = load_dataset(params = {"dataset" : self.dataset,
                "inat2018_resolution": "standard",
                "load_img" : False,
                "cnn_pred_type" : "full",
                "train_sample_ratio": 1.0,
                "cnn_model": "inception_v3",
                "regress_dataset" : [],
                "meta_type": "orig_meta"
                },
                eval_split='val',
                train_remove_invalid=True,
                eval_remove_invalid=True,
                load_cnn_predictions=True,
                load_cnn_features=False,
                load_cnn_features_train=False,
                )

        # Create datasets for the loader's train and val splits.
        train_source = Inat2018VisionLocationDataset(return_op_val, split='train', sample_fraction=self.subset_fraction)
        test_source = Inat2018VisionLocationDataset(return_op_val, split='val', sample_fraction=self.subset_fraction)

        self.train_val_dataset = TensorDataset(
            train_source.lon_lat_time_vec,
            train_source.vision_logits,
            train_source.classes,
        )

        num_train_samples = int(len(self.train_val_dataset) * self.train_frac)
        num_val_samples = len(self.train_val_dataset) - num_train_samples

        # Only pass a generator when seeded, so the unseeded path keeps
        # random_split's default (global) generator.
        split_kwargs = {}
        if self.seed is not None:
            split_kwargs["generator"] = torch.Generator().manual_seed(self.seed)
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            self.train_val_dataset, [num_train_samples, num_val_samples], **split_kwargs
        )

        self.test_dataset = TensorDataset(
            test_source.lon_lat_time_vec,
            test_source.vision_logits,
            test_source.classes,
        )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
    
