import copy
import json
import math
import os
from pathlib import Path
import pickle
import random
import typing
from typing import Callable, Dict, List, Optional, Tuple

import lmdb
import numpy as np
import torch
import torch_geometric
from torch_geometric.data import Data, Dataset

from src.constants import DATASET_PRIORS, DATASET_PRIORS_PATH, MAX_ATOMS
from src.utils import misc_utils, spatial_utils

LOGGER = misc_utils.get_logger(__name__)


def save_dataset_priors(values: List[float], dataset_name: str) -> None:
    """Convert a list of values to a probability distribution and save to JSON.

    Each value is truncated to an integer and used (as a string) as the key of
    the distribution; the associated probability is the fraction of ``values``
    that map to that key. The distribution is stored under ``dataset_name`` in
    the module-level ``DATASET_PRIORS`` mapping (which is updated in place) and
    the whole mapping is then written to ``DATASET_PRIORS_PATH``.

    Args:
        values: List of numerical values to convert to probabilities
        dataset_name: Name of the dataset/distribution to save
    """
    # Count frequency of each unique value
    unique_values, counts = np.unique(values, return_counts=True)
    total_counts = np.sum(counts)

    # Accumulate rather than assign: truncating to int can map several distinct
    # values onto the same key, and their probability mass must be summed.
    probabilities: Dict[str, float] = {}
    for val, count in zip(unique_values, counts):
        key = str(int(val))
        probabilities[key] = probabilities.get(key, 0.0) + float(count / total_counts)

    # Create the target directory if needed. A bare file name has an empty
    # dirname, and os.makedirs("") raises, so only create when there is one.
    priors_dir = os.path.dirname(DATASET_PRIORS_PATH)
    if priors_dir:
        os.makedirs(priors_dir, exist_ok=True)

    # Start from the distributions already held in memory (this aliases the
    # imported mapping, so the update below is visible to other users of it).
    distributions = DATASET_PRIORS

    # Add new distribution, overwriting if it already exists
    if dataset_name in distributions:
        LOGGER.warning(f"Warning: Overwriting existing distribution for {dataset_name}")

    distributions[dataset_name] = probabilities
    LOGGER.info(distributions)

    # Save updated distributions
    with open(DATASET_PRIORS_PATH, "w") as f:
        json.dump(distributions, f, indent=4)


class GraphDataset(Dataset):
    """Dataset over a list of graph ``Data`` objects with conformer data attached on access.

    The graphs themselves are held in memory in ``data_list``. Coordinates (and,
    when ``config.spatial.pharmacophore_conditioning`` is set, pharmacophore
    features) are looked up through ``spatial_utils`` each time an item is
    fetched, using paths derived from ``cache_dir``.
    """

    def __init__(
        self,
        config: Dict,
        cache_dir: str,
        data_list: List[Data],
        sample_conformer: bool = False,
        coord_mask_value: float = 0.0,
    ):
        """Store the graph list and resolve the conformer/pharmacophore paths.

        Args:
            config: Configuration object; ``config.paths.use_lmdb`` selects the
                conformer location and ``config.spatial`` is read in ``get``.
            cache_dir: Directory containing the conformer and pharmacophore stores.
            data_list: Graphs served by this dataset.
            sample_conformer: Forwarded as ``random`` to
                ``spatial_utils.select_conformer``.
            coord_mask_value: Forwarded to ``spatial_utils.xyz_to_coordinates``.
        """
        super(GraphDataset, self).__init__()
        self.data_list = data_list
        self.coord_mask_value = coord_mask_value
        self.sample_conformer = sample_conformer
        self.config = config
        if config.paths.use_lmdb:
            self.conformers_path = os.path.join(cache_dir, "conformers.lmdb")
        else:
            self.conformers_path = os.path.join(cache_dir, "conformers/final_conformers")

        self.pharmacophore_path = os.path.join(cache_dir, "pharmacophores.lmdb")

    def len(self):
        """Return the number of graphs in the dataset."""
        return len(self.data_list)

    def get(self, idx: int):
        """Return graph ``idx`` with ``coordinates``/``coords_mask`` (and pharmacophores) set.

        A shallow copy of the stored graph is returned, so the attributes added
        here are not written back onto the entries of ``data_list``.
        """
        # Shallow copy: existing attributes are shared with the cached graph, but
        # the per-access attributes assigned below live only on the returned
        # object instead of accumulating on every element of ``data_list``.
        data = copy.copy(self.data_list[idx])

        # Pick a conformer
        conf_xyz = spatial_utils.select_conformer(
            data.data_index, self.conformers_path, random=self.sample_conformer
        )

        num_nodes = data.x.size(0)
        coordinates, coords_mask = spatial_utils.xyz_to_coordinates(
            conf_xyz,
            data.smiles,  # Assuming SMILES is stored in the Data object
            data.x,  # Use existing node features
            # Use existing edge features, viewed as (num_nodes, num_nodes, -1)
            data.edge_attr.reshape(num_nodes, num_nodes, -1),
            self.coord_mask_value,
        )

        data.coordinates = coordinates
        data.coords_mask = coords_mask

        if self.config.spatial.pharmacophore_conditioning:
            # The pharmacophore lookup is keyed by the conformer's base name
            # (file name without directory or extension).
            conf_sdf = Path(conf_xyz).with_suffix(".sdf").stem
            pharm_types, pharm_pos, pharm_padding_mask = spatial_utils.get_pharmacophore(
                conf_sdf, self.pharmacophore_path, self.config.spatial.pharmacophore_subset
            )
            data.pharm_types = pharm_types
            data.pharm_pos = pharm_pos
            data.pharm_padding_mask = pharm_padding_mask

        return data


def get_rgfn_graph_dataset(
    cache_dir: str,
    config: Dict,
    train_size: float = 0.9,
    validation_size: float = 0.1,
    test_size: float = 0.0,
) -> Dict[str, "GraphDataset"]:
    """Load (or create and cache) the train/validation/test splits as datasets.

    Splits are cached as ``{split}.pt`` files in a sub-directory of
    ``cache_dir``: ``overfit_<n_overfit>`` when ``config.spatial.overfit`` is
    set, ``full`` otherwise. If any of the three files is missing, the complete
    list is loaded from ``dataset_list_full.pt``, optionally truncated to the
    first ``config.spatial.n_overfit`` graphs, split contiguously in order, and
    saved. Creating the splits also records the node-count distribution via
    ``save_dataset_priors`` under the name ``"rgfn_graph"``.

    Args:
        cache_dir: Directory holding ``dataset_list_full.pt`` and the split caches.
        config: Configuration object; ``config.spatial`` is read here.
        train_size: Fraction of the graphs assigned to the training split.
        validation_size: Fraction assigned to the validation split.
        test_size: Fraction assigned to the test split.

    Returns:
        Mapping from split name (``"train"``, ``"validation"``, ``"test"``) to
        a ``GraphDataset`` wrapping that split.
    """
    split_names = ["train", "validation", "test"]

    extension = f"/overfit_{config.spatial.n_overfit}" if config.spatial.overfit else "/full"
    cache_dir_splits = cache_dir + extension
    os.makedirs(cache_dir_splits, exist_ok=True)

    # Generator (not a list) so the check stops at the first missing split file.
    splits_cached = all(
        misc_utils.fsspec_exists(os.path.join(cache_dir_splits, f"{split}.pt"))
        for split in split_names
    )

    # NOTE: torch.load unpickles its input, so the cache files must be trusted.
    if splits_cached:
        # Load existing splits
        splits = {
            split: torch.load(os.path.join(cache_dir_splits, f"{split}.pt"))
            for split in split_names
        }
    else:
        # Load the complete data list
        data_list = torch.load(os.path.join(cache_dir, "dataset_list_full.pt"))

        # Handle overfit case
        if config.spatial.overfit:
            data_list = data_list[: config.spatial.n_overfit]

        # Save node length distribution
        node_lengths = [data.x.size(0) for data in data_list]
        save_dataset_priors(node_lengths, "rgfn_graph")

        # Calculate split boundaries. Each size is truncated to an integer, so
        # trailing graphs may end up in no split.
        total_size = len(data_list)
        train_end = int(total_size * train_size)
        val_end = train_end + int(total_size * validation_size)
        test_end = val_end + int(total_size * test_size)

        # Split the data
        splits = {
            "train": data_list[:train_end],
            "validation": data_list[train_end:val_end],
            "test": data_list[val_end:test_end],
        }

        # Save the splits
        for split, data in splits.items():
            torch.save(data, os.path.join(cache_dir_splits, f"{split}.pt"))

    # Wrap data lists into Dataset objects; conformers are only sampled for
    # the training split, and only when the config asks for it.
    datasets = {}
    for split in split_names:
        datasets[split] = GraphDataset(
            config=config,
            cache_dir=cache_dir,
            data_list=splits[split],
            sample_conformer=bool(split == "train" and config.spatial.sample_conformer),
            coord_mask_value=config.spatial.coord_mask_value,
        )

    return datasets


def get_dataset(dataset_name: str, mode: str, cache_dir: str, config: Dict) -> Dataset:
    """Return the dataset split ``mode`` for the dataset called ``dataset_name``.

    Args:
        dataset_name: Name of the dataset; only ``"rgfn_graph"`` is supported.
        mode: Split to return; used as a key into the mapping produced by
            ``get_rgfn_graph_dataset``.
        cache_dir: Directory passed through to the dataset builder.
        config: Configuration object passed through to the dataset builder.

    Returns:
        The dataset for the requested split.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
        KeyError: If ``mode`` is not one of the available splits.
    """
    LOGGER.info(f"Generating new data at: {cache_dir}")
    if dataset_name == "rgfn_graph":
        dataset = get_rgfn_graph_dataset(cache_dir=cache_dir, config=config)
        return dataset[mode]

    # Fail here with a clear message instead of handing ``None`` to a DataLoader.
    raise ValueError(f"Unsupported dataset name: {dataset_name!r}")


def get_dataloaders(
    config: Dict,
    skip_train: bool = False,
    skip_valid: bool = False,
    valid_seed: Optional[int] = None,
) -> Tuple[
    Optional[torch_geometric.loader.DataLoader], Optional[torch_geometric.loader.DataLoader]
]:
    """Build the training and validation data loaders described by ``config``.

    Args:
        config: Configuration object; the ``data``, ``loader`` and ``trainer``
            sections are read here.
        skip_train: If true, no training dataset/loader is built and ``None``
            is returned in its place.
        skip_valid: If true, no validation dataset/loader is built and ``None``
            is returned in its place.
        valid_seed: When given, the validation loader shuffles using a
            generator seeded with this value; otherwise it does not shuffle.

    Returns:
        ``(train_loader, valid_loader)``; either element is ``None`` when the
        corresponding ``skip_*`` flag is set.

    Raises:
        AssertionError: If ``config.loader.global_batch_size`` does not equal
            batch size x nodes x devices x gradient-accumulation steps.
        ValueError: If a global batch size is not divisible as required.
    """
    loader_class = (
        torch_geometric.loader.DataLoader
        if config.data.train == "rgfn_graph"
        else torch.utils.data.DataLoader
    )
    # device_count() is 0 on CPU-only hosts; count that as one device so the
    # consistency checks below neither compare against 0 nor divide by zero.
    num_gpus = max(torch.cuda.device_count(), 1)
    cache_dir = config.data.cache_dir
    accumulate_grad_batches = config.trainer.accumulate_grad_batches

    expected_global_batch_size = (
        config.loader.batch_size * config.trainer.num_nodes * num_gpus * accumulate_grad_batches
    )
    assert config.loader.global_batch_size == expected_global_batch_size, (
        f"global_batch_size {config.loader.global_batch_size} != "
        f"batch_size * num_nodes * num_gpus * accumulate_grad_batches "
        f"({expected_global_batch_size})"
    )
    if config.loader.global_batch_size % (num_gpus * accumulate_grad_batches) != 0:
        # Report the value that was actually checked.
        raise ValueError(
            f"Train Batch Size {config.loader.global_batch_size} "
            f"not divisible by {num_gpus} gpus with accumulation "
            f"{accumulate_grad_batches}."
        )
    if config.loader.eval_global_batch_size % num_gpus != 0:
        raise ValueError(
            f"Eval Batch Size {config.loader.eval_global_batch_size} "
            f"not divisible by {num_gpus}."
        )

    # Build both datasets before either loader, so dataset errors surface first.
    if skip_train:
        train_set = None
    else:
        train_set = get_dataset(
            config.data.train,
            mode="train",
            cache_dir=cache_dir,
            config=config,
        )

    validation_split = "validation"
    if skip_valid:
        valid_set = None
    else:
        valid_set = get_dataset(
            config.data.valid,
            mode=validation_split,
            cache_dir=cache_dir,
            config=config,
        )

    if skip_train:
        train_loader = None
    else:
        train_loader = loader_class(
            train_set,
            batch_size=config.loader.batch_size,
            num_workers=config.loader.num_workers,
            pin_memory=config.loader.pin_memory,
            shuffle=not config.data.streaming,
            # persistent_workers is only valid when there are worker processes.
            persistent_workers=config.loader.num_workers > 0,
        )

    if skip_valid:
        valid_loader = None
    else:
        if valid_seed is None:
            shuffle_valid = False
            generator = None
        else:
            # Seeded generator makes the shuffled validation order reproducible.
            shuffle_valid = True
            generator = torch.Generator().manual_seed(valid_seed)
        valid_loader = loader_class(
            valid_set,
            batch_size=config.loader.eval_batch_size,
            num_workers=config.loader.num_workers,
            pin_memory=config.loader.pin_memory,
            shuffle=shuffle_valid,
            generator=generator,
        )

    return train_loader, valid_loader


# Samplers adapted from: https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/fault_tolerant_sampler.py


class RandomFaultTolerantSampler(torch.utils.data.RandomSampler):
    """Random sampler whose position within an epoch can be saved and restored.

    ``state_dict`` captures the generator state from which the current epoch's
    permutation was drawn together with the number of indices already yielded.
    After ``load_state_dict``, the next iteration redraws that same permutation
    and skips the indices that were already consumed.
    """

    def __init__(self, *args, generator=None, **kwargs):
        # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,
        # which should be reproducible if pl.seed_everything was called beforehand.
        # This means that changing the seed of the experiment will also change the
        # sampling order.
        if generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator().manual_seed(seed)
        # A ``shuffle`` keyword is not accepted by RandomSampler; drop it if present.
        kwargs.pop("shuffle", None)
        super().__init__(*args, generator=generator, **kwargs)
        self.counter = 0
        self.restarting = False
        # Generator state taken just before the in-progress epoch's permutation
        # was drawn; ``None`` while no epoch is in progress.
        self.state = None

    def state_dict(self):
        """Return the generator state and position needed to resume sampling."""
        # Mid-epoch the generator has already advanced past the permutation
        # being served, so the snapshot taken before drawing it must be saved;
        # restoring the live state would yield a different permutation.
        if self.state is not None:
            random_state = self.state.clone()
        else:
            random_state = self.generator.get_state()
        return {"random_state": random_state, "counter": self.counter}

    def load_state_dict(self, state_dict):
        """Restore a state produced by ``state_dict``; takes effect on the next iteration."""
        self.generator.set_state(state_dict.get("random_state"))
        self.counter = state_dict["counter"]
        # The restored generator state is now the one to report until iteration starts.
        self.state = None
        self.restarting = True

    # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
    # epoch, and subsequent epoch will have very few batches.

    def __iter__(self) -> typing.Iterator[int]:
        n = len(self.data_source)

        # Snapshot before drawing so this exact permutation can be reproduced.
        self.state = self.generator.get_state()
        indices = torch.randperm(n, generator=self.generator).tolist()

        if not self.restarting:
            self.counter = 0
        else:
            # Resume: skip what was already yielded before the checkpoint.
            indices = indices[self.counter :]
            self.restarting = False

        for index in indices:
            self.counter += 1
            yield index

        # Epoch finished: the next checkpoint should describe the next epoch.
        self.counter = 0
        self.state = None


class FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler):
    """Distributed sampler that tracks its position so an epoch can be resumed.

    The per-rank index list is built from the epoch, seed, rank and replica
    count, exactly as on every iteration; ``counter`` records how many of
    those indices have been yielded so a restored sampler can skip them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.restarting = False
        self.counter = 0

    def state_dict(self):
        """Return the epoch and the number of indices already yielded in it."""
        return dict(epoch=self.epoch, counter=self.counter)

    def load_state_dict(self, state_dict):
        """Restore epoch and position; takes effect on the next iteration."""
        self.epoch = state_dict["epoch"]
        self.counter = state_dict["counter"]
        self.restarting = True

    # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
    # epoch, and subsequent epoch will have very few batches.
    def __iter__(self):
        dataset_size = len(self.dataset)  # type: ignore[arg-type]
        if not self.shuffle:
            order = list(range(dataset_size))
        else:
            # Shuffle deterministically from the seed and the current epoch.
            rng = torch.Generator()
            rng.manual_seed(self.seed + self.epoch)
            order = torch.randperm(dataset_size, generator=rng).tolist()

        if self.drop_last:
            # Trim the tail so the list divides evenly across replicas.
            order = order[: self.total_size]
        else:
            # Repeat samples from the front so the list divides evenly.
            shortfall = self.total_size - len(order)
            if shortfall <= len(order):
                order += order[:shortfall]
            else:
                repeats = math.ceil(shortfall / len(order))
                order += (order * repeats)[:shortfall]
        assert len(order) == self.total_size

        # Keep every num_replicas-th index, starting at this rank.
        shard = order[self.rank : self.total_size : self.num_replicas]
        assert len(shard) == self.num_samples

        if self.restarting:
            # Resume: drop the indices yielded before the checkpoint.
            shard = shard[self.counter :]
            self.restarting = False
        else:
            self.counter = 0

        for sample_index in shard:
            self.counter += 1
            yield sample_index

        self.counter = 0
