import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningDataModule
import numpy as np
import pandas as pd
from utils import get_presence_map_gt


class Occurance(Dataset):
    """Dataset yielding one ``(text_embed, hist)`` pair per species key.

    The species keys come from the text-embedding file; the dataset length is
    the number of keys in that file, not the number of observation rows.
    """

    def __init__(self, text_embeds_path, parquet_path, bins=None):
        """Load the text embeddings and the observation table.

        Args:
            text_embeds_path: Path handed to ``np.load``. The file must hold a
                0-d object array wrapping a mapping (it is unwrapped with
                ``[()]`` and queried with ``.keys()``) from species key to
                embedding.
            parquet_path: Path to the observation parquet file. Only the
                ``species``, ``year``, ``decimalLatitude`` and
                ``decimalLongitude`` columns are read.
            bins: Forwarded unchanged to ``get_presence_map_gt``. Defaults to
                ``[900, 1800]`` when omitted (or ``None``).
        """
        self.text_embeds_path = text_embeds_path
        # NOTE(security): allow_pickle=True unpickles arbitrary Python objects;
        # only point this at embedding files from a trusted source.
        self.text_embeds = np.load(self.text_embeds_path, allow_pickle=True)
        self.text_keys = list(self.text_embeds[()].keys())
        self.num_species = len(self.text_keys)
        self.parquet_path = parquet_path
        self.obs = pd.read_parquet(
            self.parquet_path,
            engine="fastparquet",
            columns=["species", "year", "decimalLatitude", "decimalLongitude"],
        )
        # Build the default per instance: a list literal in the signature
        # would be one object shared by every instance that relies on it.
        self.bins = [900, 1800] if bins is None else bins

    def __len__(self):
        """Return the number of species keys in the embedding file."""
        return self.num_species

    def __getitem__(self, idx):
        """Return ``(text_embed, hist)`` for the ``idx``-th species key.

        ``text_embed`` is the stored embedding with ``squeeze(0)`` applied
        (drops a leading axis of size 1 if there is one). ``hist`` is the
        output of ``get_presence_map_gt`` with a new leading axis added
        (presumably a channel axis for a 2-D presence map; confirm against
        ``utils.get_presence_map_gt``).
        """
        species = self.text_keys[idx]
        text_embed = torch.tensor(self.text_embeds[()][species]).squeeze(0)
        hist = get_presence_map_gt(self.obs, species, bins=self.bins)
        hist = torch.tensor(hist).unsqueeze(0)
        return text_embed, hist


class OccuranceEnv(Dataset):
    """Dataset yielding ``(env_cov, text_embed, hist)`` per species key.

    The same environmental-covariate array is returned with every sample; only
    the text embedding and the presence map vary with the index.
    """

    def __init__(self, text_embeds_path, parquet_path, env_cov_path, bins=None):
        """Load the text embeddings, observation table and covariate array.

        Args:
            text_embeds_path: Path handed to ``np.load``. The file must hold a
                0-d object array wrapping a mapping (it is unwrapped with
                ``[()]`` and queried with ``.keys()``) from species key to
                embedding.
            parquet_path: Path to the observation parquet file. Only the
                ``species``, ``year``, ``decimalLatitude`` and
                ``decimalLongitude`` columns are read.
            env_cov_path: Path to a ``.npy`` file holding a 3-D array; its
                last axis is moved to the front (presumably H x W x C to
                C x H x W; confirm against the file producer).
            bins: Forwarded unchanged to ``get_presence_map_gt``. Defaults to
                ``[900, 1800]`` when omitted (or ``None``).
        """
        self.text_embeds_path = text_embeds_path
        # NOTE(security): allow_pickle=True unpickles arbitrary Python objects;
        # only point this at embedding files from a trusted source.
        self.text_embeds = np.load(self.text_embeds_path, allow_pickle=True)
        self.text_keys = list(self.text_embeds[()].keys())
        self.num_species = len(self.text_keys)
        self.parquet_path = parquet_path
        self.obs = pd.read_parquet(
            self.parquet_path,
            engine="fastparquet",
            columns=["species", "year", "decimalLatitude", "decimalLongitude"],
        )
        self.env_cov_path = env_cov_path
        # Materialise the transposed view as a contiguous array once, so each
        # sample can wrap it as a contiguous tensor without copying.
        self.env_cov = np.ascontiguousarray(
            np.load(self.env_cov_path).transpose(2, 0, 1)
        )
        # Build the default per instance: a list literal in the signature
        # would be one object shared by every instance that relies on it.
        self.bins = [900, 1800] if bins is None else bins

    def __len__(self):
        """Return the number of species keys in the embedding file."""
        return self.num_species

    def __getitem__(self, idx):
        """Return ``(env_cov, text_embed, hist)`` for the ``idx``-th species key.

        ``env_cov`` is a tensor view of ``self.env_cov``: it shares memory
        with the array held by the dataset, so treat it as read-only. Batching
        with the default ``DataLoader`` collate stacks samples into a new
        tensor, so batches do not alias the dataset's storage.

        ``text_embed`` is the stored embedding with ``squeeze(0)`` applied.
        ``hist`` is the output of ``get_presence_map_gt`` with a new leading
        axis added (presumably a channel axis; confirm against
        ``utils.get_presence_map_gt``).
        """
        species = self.text_keys[idx]
        text_embed = torch.tensor(self.text_embeds[()][species]).squeeze(0)
        hist = get_presence_map_gt(self.obs, species, bins=self.bins)
        hist = torch.tensor(hist).unsqueeze(0)
        # torch.tensor() here would deep-copy the whole covariate array for
        # every sample; as_tensor wraps the existing buffer instead.
        env_cov = torch.as_tensor(self.env_cov)
        return env_cov, text_embed, hist


class OccuranceDataModule(LightningDataModule):
    """Lightning data module wrapping ``Occurance`` / ``OccuranceEnv``.

    Only a training dataset and training dataloader are provided.
    """

    def __init__(
        self,
        text_embeds_path,
        parquet_path,
        env_cos_path=None,
        bins=None,
        batch_size=1,
        shuffle=True,
        num_workers=12,
    ):
        """Store the configuration; no data is loaded until ``setup``.

        Args:
            text_embeds_path: Forwarded to the dataset constructor.
            parquet_path: Forwarded to the dataset constructor.
            env_cos_path: Optional path forwarded as ``OccuranceEnv``'s
                ``env_cov_path``. When falsy (``None`` or ``""``) the plain
                ``Occurance`` dataset is used instead. The parameter name is
                kept as-is for backward compatibility.
            bins: Forwarded to the dataset constructor. Defaults to
                ``[900, 1800]`` when omitted (or ``None``).
            batch_size: ``DataLoader`` batch size.
            shuffle: Whether the training ``DataLoader`` shuffles.
            num_workers: Number of ``DataLoader`` worker processes.
        """
        super().__init__()
        self.text_embeds_path = text_embeds_path
        self.parquet_path = parquet_path
        self.env_cos_path = env_cos_path
        # Build the default per instance: a list literal in the signature
        # would be one object shared by every instance that relies on it.
        self.bins = [900, 1800] if bins is None else bins
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Build ``self.train_dataset``; ``stage`` is accepted but ignored."""
        if self.env_cos_path:
            self.train_dataset = OccuranceEnv(
                self.text_embeds_path,
                self.parquet_path,
                self.env_cos_path,
                bins=self.bins,
            )
        else:
            self.train_dataset = Occurance(
                self.text_embeds_path, self.parquet_path, bins=self.bins
            )

    def train_dataloader(self):
        """Return a ``DataLoader`` over ``self.train_dataset``.

        ``setup`` must have run first, since it is what creates
        ``self.train_dataset``.
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
        )
