import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.datasets import QM9, MD17, ModelNet
from atomic_datasets import QM9 as QM9Atomic
from torch_geometric.loader import DataLoader
from torch.utils.data import Dataset
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch.utils.data import Dataset
import numpy as np
from scipy.spatial.transform import Rotation as R
from sklearn.model_selection import train_test_split
import time
import utils
from utils import rotate_molecule, generate_random_quaternion, canonicalize_molecule, generate_many_random_quaternions#, normalize_outputs, unnormalize_outputs
import copy
import random
import os
import torch_geometric
import e3tools
import matplotlib.pyplot as plt
from torch_geometric.transforms import SamplePoints
from torch_geometric.nn import knn_graph



from typing import Optional, Callable
from glob import glob

class RotatedQM9Dataset(QM9):
    """QM9 dataset whose samples are rotated by a caller-supplied quaternion.

    Args:
        root: Root directory forwarded to ``QM9``.
        quaternions: Indexable collection; ``quaternions[idx]`` is the rotation
            applied to sample ``idx``. Each entry must reshape to ``(1, 4)``.
        split: Stored on the instance; not otherwise used by this class.
        apply_random_rot_frac: Probability of applying an additional random
            rotation *before* the per-index rotation. ``False``/``0`` disables it.
        quat_ids: Optional indexable collection; when given, ``quat_ids[idx]``
            is attached to the returned sample as ``quat_id``.
    """

    def __init__(self, root, quaternions, split="train", apply_random_rot_frac=False, quat_ids=None):
        super().__init__(root)
        self.split = split

        self.apply_random_rot_frac = apply_random_rot_frac
        print('RANDOM ROT: ', self.apply_random_rot_frac)

        self.quaternions = quaternions
        self.quat_ids = quat_ids

    def __getitem__(self, idx):
        """Return sample ``idx`` rotated by ``self.quaternions[idx]``.

        The returned object carries the applied rotation as ``quaternion``
        (float32, shape ``(1, 4)``) and, when ``quat_ids`` was supplied, the
        matching ``quat_id``.
        """
        data = super().__getitem__(idx)

        # Optional extra random rotation, drawn independently on every access.
        if self.apply_random_rot_frac > 0:
            if torch.rand(1)[0] < self.apply_random_rot_frac:
                data = rotate_molecule(data, generate_random_quaternion())

        quaternion = self.quaternions[idx]
        rotated_data = rotate_molecule(data, quaternion)
        rotated_data.quaternion = torch.tensor(quaternion, dtype=torch.float32).reshape(1, 4)
        if self.quat_ids is not None:
            rotated_data.quat_id = self.quat_ids[idx]

        return rotated_data
    
class RotatedModelNetDataset(ModelNet):
    """ModelNet dataset whose samples are rotated by a caller-supplied quaternion.

    Args:
        root: Root directory forwarded to ``ModelNet``.
        quaternions: Indexable collection; ``quaternions[idx]`` is the rotation
            applied to sample ``idx``. Each entry must reshape to ``(1, 4)``.
        split: Stored on the instance; not otherwise used by this class.
        apply_random_rot_frac: Probability of applying an additional random
            rotation *before* the per-index rotation. ``False``/``0`` disables it.
        quat_ids: Optional indexable collection; when given, ``quat_ids[idx]``
            is attached to the returned sample as ``quat_id``.
    """

    def __init__(self, root, quaternions, split="train", apply_random_rot_frac=False, quat_ids=None):
        super().__init__(root)
        self.split = split

        self.apply_random_rot_frac = apply_random_rot_frac
        print('RANDOM ROT: ', self.apply_random_rot_frac)

        self.quaternions = quaternions
        self.quat_ids = quat_ids

    def __getitem__(self, idx):
        """Return sample ``idx`` rotated by ``self.quaternions[idx]``.

        The returned object carries the applied rotation as ``quaternion``
        (float32, shape ``(1, 4)``) and, when ``quat_ids`` was supplied, the
        matching ``quat_id``.
        """
        data = super().__getitem__(idx)

        # Optional extra random rotation, drawn independently on every access.
        if self.apply_random_rot_frac > 0:
            if torch.rand(1)[0] < self.apply_random_rot_frac:
                data = rotate_molecule(data, generate_random_quaternion())

        quaternion = self.quaternions[idx]
        rotated_data = rotate_molecule(data, quaternion)
        rotated_data.quaternion = torch.tensor(quaternion, dtype=torch.float32).reshape(1, 4)
        if self.quat_ids is not None:
            rotated_data.quat_id = self.quat_ids[idx]

        return rotated_data
    



class LocalQM9Dataset(QM9):
    """QM9 variant that returns the neighborhood of one randomly chosen node.

    First attempt at a "local" dataset; the neighborhood is extracted on every
    access, which may be too slow in practice.
    """

    def __init__(self, root):
        super().__init__(root)

    def __getitem__(self, idx):
        """Return the neighborhood data of a uniformly drawn node of sample ``idx``."""
        molecule = super().__getitem__(idx)
        # Draw the center node uniformly among the rows of ``x``.
        center = int(torch.randint(molecule.x.shape[0], (1,)))
        return utils.get_neighborhood_data(molecule, center)
    
# TODO:
# class LocalModelNetDataset(ModelNet):
#     # try 1 at local dataset - too slow?
#     def __init__(self, root):
#         super().__init__(root)    


# try 2 at local
# from torch_geometric.data import Data, Dataset
#from torch_geometric.utils import subgraph
# this is the one that's used
class EfficientLocalQM9(torch_geometric.data.Dataset):
    """Local-neighborhood QM9 dataset backed by a pre-computed list on disk.

    The list of neighborhood graphs is loaded from ``root/specific_file``. If
    that file does not exist it is generated first (see ``generate_data``).

    Args:
        root: Directory containing ``specific_file``.
        transform: Forwarded to the PyG ``Dataset`` constructor.
        pre_transform: Forwarded to the PyG ``Dataset`` constructor.
        specific_file: File name of the saved list of graphs.
    """

    def __init__(self, root, transform=None, pre_transform=None, specific_file='small2_processed_qm9.pt'):
        super().__init__(root, transform, pre_transform)
        print(f'loading from {specific_file}')
        self.root = root
        self.specific_file = specific_file
        self.full_file_path = os.path.join(root, specific_file)

        cache_missing = not os.path.exists(self.full_file_path)
        if cache_missing:
            self.generate_data()

        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
        self.data_list = torch.load(self.full_file_path, weights_only=False)

    def generate_data(self):
        """Build the neighborhood graphs from QM9 and save them to ``full_file_path``.

        Three random center nodes are drawn per molecule; generation stops once
        more than 500000 graphs have been produced.
        """
        print('Should not be generating data! Your file was not found.')
        source = QM9(root=os.path.join(os.path.dirname(self.root), 'qm9'))
        total_molecules = len(source)
        subgraphs = []
        centers_per_molecule = 3
        t_start = time.time()
        produced = 0
        for mol_i, mol in enumerate(source):
            if produced > 500000:
                break
            # Progress report every 1000 molecules (skipping the very first).
            if mol_i > 0 and mol_i % 1000 == 0:
                elapsed = (time.time() - t_start) / 60.0
                estimated_remaining = (elapsed / mol_i) * (total_molecules - mol_i)
                print(f'{mol_i} {elapsed} min cumulative.    estimated remaining {estimated_remaining} min, or {estimated_remaining/60.0} hours')
            atom_count = mol.pos.size(0)
            for _ in range(centers_per_molecule):
                center = torch.randint(low=0, high=atom_count, size=(1,))
                subgraphs.append(utils.get_neighborhood_data(mol, center, idx_to_use=produced))
                produced += 1

        # Persist the list of processed graphs.
        torch.save(subgraphs, self.full_file_path)

    def len(self):
        """Number of stored neighborhood graphs."""
        return len(self.data_list)

    def get(self, idx):
        """Return the stored graph at position ``idx``."""
        return self.data_list[idx]
    
class QM7TensorDataset(torch_geometric.data.Dataset):
    """
    Data from QM7b https://www.nature.com/articles/s41597-019-0157-8/tables/2 CCSD level.
    Loads data directly from .xyz files into list.

    Args:
        file_dir: Directory that is searched for ``pattern``.
        pattern: Glob pattern of the .xyz files to load.
        transform: Optional callable applied to every sample in ``__getitem__``.
        pre_transform: Accepted for interface compatibility; not used here.
        filter_polar: If True, keep only files whose parsed ``'dipole'`` has a
            norm greater than 0.1.
        filter_low_symm: Placeholder; currently has no effect.
    """
    def __init__(self, file_dir, pattern = "molecule*.xyz", transform=None, pre_transform=None, filter_polar = False,filter_low_symm = False):
        super().__init__()
        self.file_dir = file_dir
        self.files = sorted(glob(os.path.join(self.file_dir, pattern)))
        # The dataset is small enough to hold in memory. Samples are NOT
        # canonicalized here, otherwise the whole dataset would be canonical.
        # Each file is parsed exactly once; filtering reuses the parsed result.
        parsed = [(path, utils.parse_qm7_xyz_to_data(path)) for path in self.files]
        if filter_polar:
            total_files = len(parsed)
            parsed = [(path, sample) for path, sample in parsed
                      if np.linalg.norm(sample['dipole']) > 0.1]
            self.files = [path for path, _ in parsed]
            print(f'Filtered {len(self.files)} files with nonzero dipoles from {total_files} total files.')
        ### TODO: low-symmetry filtering is not implemented yet
        if filter_low_symm:
            pass
        self.data_list = [sample for _, sample in parsed]
        self.transform = transform

    def __len__(self):
        return len(self.data_list)

    def add_edge_index(self, data):
        """Attach a radius-graph ``edge_index`` (cutoff 5.0) to ``data`` in place."""
        import e3tools
        data.edge_index = e3tools.radius_graph(data.pos,5.0,data.batch)

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the data sample.

        Returns:
            sample (Data): The data sample at the given index.
        """
        sample = self.data_list[idx]

        # Mutates the cached sample; the edge index is recomputed on each access.
        self.add_edge_index(sample)

        if self.transform:
            sample = self.transform(sample)

        return sample

class ModelNetDataset(Dataset):
    """Thin wrapper around PyG ``ModelNet`` with optional class filtering.

    Args:
        root: Root directory forwarded to ``ModelNet``.
        name: ModelNet variant name forwarded to ``ModelNet``.
        transform: Transform forwarded to ``ModelNet``.
        classes: ``-1`` keeps every sample; an int or a collection of ints
            keeps only samples whose label is among them.
    """

    def __init__(self, root, name='40', transform=None, classes=-1):
        self.root = root
        self.name = name
        self.transform = transform
        self.classes = classes
        self.base_dataset = ModelNet(root=root, name=name, transform=transform)

        if classes == -1:
            # No filtering requested: expose the full dataset.
            self.data = self.base_dataset
            return

        wanted = [classes] if isinstance(classes, (int, np.integer)) else classes
        keep = torch.tensor([label in wanted for label in self.base_dataset.data.y])
        self.data = self.base_dataset[torch.where(keep)[0]]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    
class QM9AtomicPyG(torch.utils.data.Dataset):
    """QM9 (from ``atomic_datasets``) exposed as PyTorch Geometric ``Data`` objects."""

    def __init__(self, root_dir, start_index = None, end_index = None,transform: Optional[Callable] = None):
        super().__init__()
        self.dataset = QM9Atomic(root_dir, start_index, end_index, check_with_rdkit=False)
        self.transform = transform
        self.root_dir = root_dir
        # Recorded only; samples are not yet filtered by these index lists.
        self.included_idxs, self.excluded_idxs = self.remove_uncharacterized()

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def reorder_properties(prop_dict):
        """Return the 19 target values as a float32 tensor in a fixed order.

        Raises:
            KeyError: if a required key is absent from ``prop_dict``.
        """
        # (target name, key in prop_dict), listed in output order.
        ordered_targets = (
            ("mu", "mu"), ("alpha", "alpha"), ("homo", "homo"), ("lumo", "lumo"),
            ("gap", "gap"), ("r2", "r2"), ("zpve", "zpve"),
            ("U0", "u0"), ("U", "u298"), ("H", "h298"), ("G", "g298"), ("Cv", "cv"),
            ("U0_atom", "u0_atom"), ("U_atom", "u298_atom"),
            ("H_atom", "h298_atom"), ("G_atom", "g298_atom"),
            ("A", "A"), ("B", "B"), ("C", "C"),
        )

        collected = []
        for name, key in ordered_targets:
            if key not in prop_dict:
                raise KeyError(f"Missing key '{key}' for target '{name}' in property dictionary.")
            collected.append(prop_dict[key])

        return torch.as_tensor(collected, dtype=torch.float32)

    #### TODO: actually drop uncharacterized molecules inside the class
    def remove_uncharacterized(self):
        """Return ``(included_idxs, excluded_idxs)`` computed by the ``utils`` helper."""
        return utils.remove_uncharacterized_molecules(self.root_dir)

    def add_edge_index(self, data):
        """Attach a radius-graph ``edge_index`` (cutoff 5.0) to ``data`` in place."""
        import e3tools
        data.edge_index = e3tools.radius_graph(data.pos, 5.0, data.batch)

    def __getitem__(self, idx):
        """Build a ``Data`` object (pos, z, y, edge_index) for sample ``idx``."""
        raw = self.dataset[idx]
        positions = torch.as_tensor(raw["nodes"]["positions"], dtype=torch.float32)
        species = torch.as_tensor(raw["nodes"]["species"], dtype=torch.long)
        targets = self.reorder_properties(raw['properties']).unsqueeze(0)
        sample = torch_geometric.data.Data(pos=positions, z=species, y=targets)

        self.add_edge_index(sample)
        if self.transform:
            sample = self.transform(sample)
        return sample

class MetaDetectionDataset(torch.utils.data.Dataset):
    """
    Binary classification dataset: was a datapoint transformed (1) or not (0)?

    Parameters:
        base_dataset (torch.utils.data.Dataset): the dataset whose elements get transformed

        transform_operator (function): called as ``transform_operator(data, idx)`` on an element
            of base_dataset (label included, if the element has one). The index is passed so the
            operator can be deterministic per datapoint (e.g. precomputed rotations) or random.

        transformed_class_balance (float): number between 0.0 and 1.0, the fraction of the
            dataset that is assigned to the transformed class

        rotate_everything (bool): if True, the operator is also applied to the untransformed
            class (their label stays 0)

        ignore_label (bool): if True and the (possibly transformed) element is a list/tuple
            with more than one entry, only its first entry is returned as the data
    """
    def __init__(self, base_dataset, transform_operator, transformed_class_balance=0.5, rotate_everything=False, ignore_label=True):
        super().__init__()

        self.base_dataset = base_dataset
        self.transform_operator = transform_operator
        self.transformed_class_balance = transformed_class_balance

        n_items = len(self.base_dataset)

        # Fix, once and for all, which indices belong to the transformed class.
        chosen = random.sample(range(n_items), int(transformed_class_balance*n_items))
        mask = torch.zeros(n_items).bool()
        mask[chosen] = True

        self.transformed_inds = chosen
        self.untransformed_inds = torch.where(torch.logical_not(mask))[0]
        self.transformed_mask = mask

        self.rotate_everything = rotate_everything
        self.ignore_label = ignore_label

    def __len__(self):
        return len(self.base_dataset)

    def set_rotate_everything(self, value):
        self.rotate_everything = value

    def __getitem__(self, idx):
        sample = self.base_dataset[idx]

        in_transformed_class = bool(self.transformed_mask[idx])
        label = 1 if in_transformed_class else 0
        if in_transformed_class or self.rotate_everything:
            sample = self.transform_operator(sample, idx)

        if self.ignore_label:
            # Strip the base label from (x, y)-style elements.
            try:
                has_label = isinstance(sample, (list, tuple)) and len(sample) > 1
                if has_label:
                    sample = sample[0]
            except TypeError:
                pass

        return sample, label

# this should probably be incorporated into the MetaDetectionDataset class
# but for now, it's separate
class CanonicalizationMetaDetectionDataset(Dataset):
    """
    Binary classification dataset to detect whether (c(x), y) is a valid pair.

    For each sample, we either:
    - Return a matching pair (c(x), y(x)) with label 0.
    - Return a mismatched pair (c(x), y(x')) with label 1, where x' is a
      different element drawn uniformly at random on every access.

    Args:
        base_dataset (Dataset): Dataset returning (x, y) pairs.
        equivariant_model (nn.Module): Model to compute canonicalized representation c(x).
        mismatch_fraction (float): Fraction of samples that should be mismatched (label=1).
    """
    def __init__(self, base_dataset, equivariant_model, mismatch_fraction=0.5):
        super().__init__()
        self.base_dataset = base_dataset
        self.equivariant_model = equivariant_model
        self.mismatch_fraction = mismatch_fraction

        dataset_size = len(self.base_dataset)
        num_mismatched = int(dataset_size * mismatch_fraction)

        # Fix which indices are mismatched; the partner is drawn per access.
        mismatched_inds = random.sample(range(dataset_size), num_mismatched)
        self.mismatched_mask = torch.zeros(dataset_size, dtype=torch.bool)
        self.mismatched_mask[mismatched_inds] = True

    def __len__(self):
        return len(self.base_dataset)

    def _random_other_index(self, idx):
        """Draw an index uniformly from all positions except ``idx`` in O(1).

        Raises:
            IndexError: if the dataset has fewer than two elements, so no
                other element exists.
        """
        size = len(self.base_dataset)
        if size < 2:
            raise IndexError("Cannot draw a mismatched label from a dataset with fewer than 2 elements")
        own = int(idx) % size  # normalise negative indices
        # Sample from size-1 slots and skip over the sample's own position.
        alt = random.randrange(size - 1)
        if alt >= own:
            alt += 1
        return alt

    def __getitem__(self, idx):
        """Return ``(pair, label)`` where ``pair`` concatenates c(x) and y on the last dim."""
        x, y = self.base_dataset[idx]
        c_x = self.equivariant_model(x)

        if self.mismatched_mask[idx]:
            # Replace the true label with the label of a different element.
            _, y = self.base_dataset[self._random_other_index(idx)]
            label = 1
        else:
            label = 0

        # Concatenate c(x) and y for the classifier input
        if isinstance(y, torch.Tensor):
            pair = torch.cat([c_x, y], dim=-1)
        else:
            # Assume y is scalar or numpy; convert to tensor first
            y_tensor = torch.tensor(y, dtype=c_x.dtype, device=c_x.device).unsqueeze(0)
            pair = torch.cat([c_x, y_tensor], dim=-1)

        return pair, label

    
class MetaPairDetectionDataset(torch.utils.data.Dataset):
    """
    Binary classification dataset: is (x, y) a matched pair (0) or a mismatched pair (1)?

    Parameters:
        base_dataset (torch.utils.data.Dataset): the base dataset; every element is an
            (x, y) pair, i.e. it already includes the labels

        mismatched_pair_balance (float): number between 0.0 and 1.0, the fraction of the
            dataset whose label is swapped for another element's label. defaults to 0.5

    A mismatched sample at index i takes its y from index ``permutation_of_inds[i]``.
    That mapping is a random single cycle over all indices, so for datasets with at
    least two elements it never maps an index to itself (a plain random permutation
    can, which would hand a "mismatched" sample its own true label).
    """
    def __init__(self, base_dataset, mismatched_pair_balance=0.5):
        super().__init__()

        self.base_dataset = base_dataset
        self.mismatched_pair_balance = mismatched_pair_balance

        dataset_size = len(self.base_dataset)

        # pick out which indices will receive a foreign label
        mismatched_inds = random.sample(range(dataset_size), int(mismatched_pair_balance*dataset_size))

        mismatch_mask = torch.zeros(dataset_size).bool()
        mismatch_mask[mismatched_inds] = True

        self.mismatched_inds = mismatched_inds
        self.matched_inds = torch.where(torch.logical_not(mismatch_mask))[0]
        self.mismatch_mask = mismatch_mask

        # Random cyclic permutation: each index points to its successor in a
        # random ordering, which guarantees partner != index when size >= 2.
        order = np.random.permutation(dataset_size)
        partner = np.empty(dataset_size, dtype=order.dtype)
        partner[order] = np.roll(order, -1)
        self.permutation_of_inds = partner

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """Return ``((x, y), label)``; ``y`` comes from another element when ``label`` is 1."""
        data = self.base_dataset[idx]
        x, y = data

        if self.mismatch_mask[idx]:
            # Swap in the label of this index's fixed partner.
            label_ind = self.permutation_of_inds[idx]
            y = self.base_dataset[label_ind][1]
            label = 1
        else:
            label = 0

        return (x, y), label


class MetaAugmentedDataset(torch.utils.data.Dataset):
    """
    Dataset wrapper that reshapes the distribution over orbits of a base dataset.

    Parameters:
        base_dataset (torch.utils.data.Dataset): the dataset to be transformed

        transform_list (list): tuples ``(transform_operator, transformed_class_balance)``, applied
            **in the listed order**. ``transform_operator(data, idx)`` should be deterministic
            (a random one can simply be a train-time augmentation in the dataloader).
            ``transformed_class_balance`` in [0.0, 1.0] is the fraction of the dataset the
            transform applies to.

            Example: ``[(apply_canonicalization_function, 1.0), (apply_rotation_function, 0.5)]``
            gives a dataset that is half canonicalized, half randomly rotated. The label is
            assumed invariant under the transforms; equivariant tasks may need code changes.

        label_operator (function): maps a (transformed) datapoint to ``(data, label)``. For a plain
            (x, y) dataset this may just split the pair; for qm9 it also standardizes y.
            If None, ``__getitem__`` returns the datapoint alone.
    """
    def __init__(self, base_dataset, transform_list, label_operator=None):
        super().__init__()

        self.base_dataset = base_dataset
        self.transform_list = transform_list
        self.transformed_inds = []
        self.untransformed_inds = []
        self.transformed_masks = []

        n_items = len(self.base_dataset)
        for _, balance in transform_list:
            # Fix which indices this transform applies to.
            chosen = random.sample(range(n_items), int(balance*n_items))
            mask = torch.zeros(n_items).bool()
            mask[chosen] = True

            self.transformed_inds.append(chosen)
            self.transformed_masks.append(mask)
            self.untransformed_inds.append(torch.where(torch.logical_not(mask))[0])

        self.label_operator = label_operator

    def turn_off_transforms(self, indices, exclude_indices=None):
        """
        Disable transforms for the given data indices.

        Args:
            indices: data indices for which transforms are switched off
            exclude_indices: positions in transform_list that stay *on*
        """
        keep_on = set(exclude_indices or [])
        for position, mask in enumerate(self.transformed_masks):
            if position in keep_on:
                continue
            # Masks are edited in place.
            mask[indices] = False

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        sample = self.base_dataset[idx]

        # Transforms first, label_operator last.
        for position, (operator, _) in enumerate(self.transform_list):
            if self.transformed_masks[position][idx]:
                sample = operator(sample, idx)

        if self.label_operator is None:
            return sample
        sample, label = self.label_operator(sample)
        return sample, label


class DetectRotatedQM9Dataset(QM9):
    """QM9 dataset for detecting whether a sample belongs to the transformed set.

    Args:
        root: Root directory forwarded to ``QM9``.
        split: Stored on the instance; not otherwise used by this class.
        transformed_class_balance: Probability that a sample is in the
            transformed set (label 1).
        apply_random_rot_frac: Probability that a sample of the transformed set
            is actually rotated.
        canonicalize: If True, canonicalize each sample before anything else.
        trivial_test: If True, rotated samples get all positions set to 100.
        fixed_transform: If True, membership in the transformed set and the
            rotate/no-rotate decision are drawn once per index at construction;
            otherwise they are redrawn on every access.
        fixed_rotation: If True, a rotated sample always uses
            ``fixed_quaternions[idx]``; otherwise a fresh random quaternion is
            drawn on every access.
        fixed_quaternions: Optional per-index quaternions; generated when
            ``fixed_rotation`` is True and none are given.
    """

    def __init__(self, root, split="train", transformed_class_balance=0.5, apply_random_rot_frac=0, canonicalize=False, trivial_test=False, fixed_transform=True, fixed_rotation=True, fixed_quaternions=None): #, quat_ids=None):
        super().__init__(root)
        self.split = split

        self.apply_random_rot_frac = apply_random_rot_frac
        print('RANDOM ROT: ', self.apply_random_rot_frac)

        self.transformed_class_balance = transformed_class_balance

        num_elts = super().__len__()

        if fixed_transform:
            # Draw order (rotation decision first, then set membership) is kept
            # so seeded runs reproduce earlier results.
            self.do_rotation = torch.rand(num_elts) < apply_random_rot_frac
            self.transformed_mask = torch.rand(num_elts) < transformed_class_balance
        else:
            self.transformed_mask = None
            self.do_rotation = None
        if fixed_rotation and fixed_quaternions is None:
            fixed_quaternions = generate_many_random_quaternions(num_elts)

        self.fixed_quaternions = fixed_quaternions
        self.fixed_transform = fixed_transform
        self.fixed_rotation = fixed_rotation

        self.canonicalize = canonicalize  # before doing anything, canonicalize the data

        self.trivial_test = trivial_test

    def __getitem__(self, idx, verbose=False):
        """Return ``(data, label)``; label is 1 for the transformed set, else 0.

        A sample of the transformed set that is not selected for rotation is
        returned unrotated with label 1. ``verbose`` is accepted but unused.
        """
        data = super().__getitem__(idx)

        if self.canonicalize:
            data = canonicalize_molecule(data)

        if self.fixed_transform:
            in_transformed_set = self.transformed_mask[idx]
            do_rotation = self.do_rotation[idx]
        else:
            in_transformed_set = (torch.rand(1) < self.transformed_class_balance)
            do_rotation = (torch.rand(1)) < self.apply_random_rot_frac

        # Default to the unrotated sample so every branch has a value to return
        # (previously this raised UnboundLocalError when in the transformed set
        # but not selected for rotation).
        new_data = data
        label = 0
        if in_transformed_set:
            label = 1
            if do_rotation:
                if self.fixed_rotation:
                    rot_to_apply = self.fixed_quaternions[idx]
                else:
                    rot_to_apply = generate_random_quaternion()
                new_data = rotate_molecule(data, rot_to_apply)
                if self.trivial_test:
                    new_data.pos = 100*torch.ones(data.pos.shape).to(data.pos.device)

        return new_data, label
    

class DetectRotatedModelNetDataset(ModelNet):
    """ModelNet dataset for detecting whether a sample belongs to the transformed set.

    Args:
        root: Root directory forwarded to ``ModelNet``.
        split: Stored on the instance; not otherwise used by this class.
        transformed_class_balance: Probability that a sample is in the
            transformed set (label 1).
        apply_random_rot_frac: Probability that a sample of the transformed set
            is actually rotated.
        canonicalize: If True, canonicalize each sample before anything else.
        trivial_test: If True, rotated samples get all positions set to 100.
        fixed_transform: If True, membership in the transformed set and the
            rotate/no-rotate decision are drawn once per index at construction;
            otherwise they are redrawn on every access.
        fixed_rotation: If True, a rotated sample always uses
            ``fixed_quaternions[idx]``; otherwise a fresh random quaternion is
            drawn on every access.
        fixed_quaternions: Optional per-index quaternions; generated when
            ``fixed_rotation`` is True and none are given.
    """

    def __init__(self, root, split="train", transformed_class_balance=0.5, apply_random_rot_frac=0, canonicalize=False, trivial_test=False, fixed_transform=True, fixed_rotation=True, fixed_quaternions=None): #, quat_ids=None):
        super().__init__(root)
        self.split = split

        self.apply_random_rot_frac = apply_random_rot_frac
        print('RANDOM ROT: ', self.apply_random_rot_frac)

        self.transformed_class_balance = transformed_class_balance

        num_elts = super().__len__()

        if fixed_transform:
            # Draw order (rotation decision first, then set membership) is kept
            # so seeded runs reproduce earlier results.
            self.do_rotation = torch.rand(num_elts) < apply_random_rot_frac
            self.transformed_mask = torch.rand(num_elts) < transformed_class_balance
        else:
            self.transformed_mask = None
            self.do_rotation = None
        if fixed_rotation and fixed_quaternions is None:
            fixed_quaternions = generate_many_random_quaternions(num_elts)

        self.fixed_quaternions = fixed_quaternions
        self.fixed_transform = fixed_transform
        self.fixed_rotation = fixed_rotation

        self.canonicalize = canonicalize  # before doing anything, canonicalize the data

        self.trivial_test = trivial_test

    def __getitem__(self, idx, verbose=False):
        """Return ``(data, label)``; label is 1 for the transformed set, else 0.

        A sample of the transformed set that is not selected for rotation is
        returned unrotated with label 1. ``verbose`` is accepted but unused.
        """
        data = super().__getitem__(idx)

        if self.canonicalize:
            data = canonicalize_molecule(data)

        if self.fixed_transform:
            in_transformed_set = self.transformed_mask[idx]
            do_rotation = self.do_rotation[idx]
        else:
            in_transformed_set = (torch.rand(1) < self.transformed_class_balance)
            do_rotation = (torch.rand(1)) < self.apply_random_rot_frac

        # Default to the unrotated sample so every branch has a value to return
        # (previously this raised UnboundLocalError when in the transformed set
        # but not selected for rotation).
        new_data = data
        label = 0
        if in_transformed_set:
            label = 1
            if do_rotation:
                if self.fixed_rotation:
                    rot_to_apply = self.fixed_quaternions[idx]
                else:
                    rot_to_apply = generate_random_quaternion()
                new_data = rotate_molecule(data, rot_to_apply)
                if self.trivial_test:
                    new_data.pos = 100*torch.ones(data.pos.shape).to(data.pos.device)

        return new_data, label
    
class AugmentedQM9Dataset(QM9):
    """QM9 with optional per-sample rotation and standardized regression targets.

    ``__getitem__`` returns ``(data, label)`` where ``label`` is columns 4:16
    of ``data.y`` standardized with the dataset-wide mean and std of ``y``.
    """

    def __init__(self, root, split="train", transformed_class_balance=0.5, apply_random_rot_frac=False, fixed_transform=True, fixed_rotation=True, fixed_quaternions=None): #, quat_ids=None):
        super().__init__(root)
        self.split = split
        self.apply_random_rot_frac = apply_random_rot_frac

        num_elts = super().__len__()

        self.transformed_class_balance = transformed_class_balance
        self.fixed_rotation = fixed_rotation
        self.fixed_transform = fixed_transform
        self.fixed_quaternions = fixed_quaternions
        # One quaternion per sample, generated only if the caller gave none.
        if fixed_rotation and self.fixed_quaternions is None:
            self.fixed_quaternions = generate_many_random_quaternions(num_elts)

        # With a fixed transform the rotate/no-rotate choice is frozen per index.
        if self.fixed_transform:
            self.do_rotation = torch.rand(num_elts) < apply_random_rot_frac

        # Statistics over the whole dataset, used to standardize the labels.
        targets = self.data.y
        self.means = targets.mean(dim=0, keepdim=True)
        self.stds = targets.std(dim=0, keepdim=True)

    # def normalize(self, arr):
    #     return normalize_outputs(arr, means=self.means, stds=self.stds)

    # def unnormalize(self, arr):
    #     return unnormalize_outputs(arr, means=self.means, stds=self.stds)

    def __getitem__(self, idx):
        sample = super().__getitem__(idx)

        if self.fixed_transform:
            rotate = self.do_rotation[idx]
        else:
            rotate = (torch.rand(1)) < self.apply_random_rot_frac

        if rotate:
            quat = self.fixed_quaternions[idx] if self.fixed_rotation else generate_random_quaternion()
            sample = rotate_molecule(sample, quat)

        centered = sample.y[0, 4:16] - self.means[0, 4:16]
        label = centered / self.stds[0, 4:16]
        return sample, label


# Classification task, without label change
class AugmentedModelNetDataset(ModelNet):
    """ModelNet with optional per-sample rotation; the class label is left unchanged.

    ``__getitem__`` returns ``(data, label)`` with ``label = data.y``.
    """

    def __init__(self, root, split="train", transformed_class_balance=0.5, apply_random_rot_frac=False, fixed_transform=True, fixed_rotation=True, fixed_quaternions=None): #, quat_ids=None):
        super().__init__(root)
        self.split = split
        self.apply_random_rot_frac = apply_random_rot_frac

        num_elts = super().__len__()

        self.transformed_class_balance = transformed_class_balance
        self.fixed_rotation = fixed_rotation
        self.fixed_transform = fixed_transform
        self.fixed_quaternions = fixed_quaternions
        # One quaternion per sample, generated only if the caller gave none.
        if fixed_rotation and self.fixed_quaternions is None:
            self.fixed_quaternions = generate_many_random_quaternions(num_elts)

        # With a fixed transform the rotate/no-rotate choice is frozen per index.
        if self.fixed_transform:
            self.do_rotation = torch.rand(num_elts) < apply_random_rot_frac
        # self.means = self.data.y.mean(dim=0, keepdim=True)
        # self.stds = self.data.y.std(dim=0, keepdim=True)

    # def normalize(self, arr):
    #     return normalize_outputs(arr, means=self.means, stds=self.stds)

    # def unnormalize(self, arr):
    #     return unnormalize_outputs(arr, means=self.means, stds=self.stds)

    def __getitem__(self, idx):
        sample = super().__getitem__(idx)

        if self.fixed_transform:
            rotate = self.do_rotation[idx]
        else:
            rotate = (torch.rand(1)) < self.apply_random_rot_frac

        if rotate:
            quat = self.fixed_quaternions[idx] if self.fixed_rotation else generate_random_quaternion()
            sample = rotate_molecule(sample, quat)

        return sample, sample.y


class DetectRotatedMD17Dataset(MD17):
    """MD17 rotation-detection dataset, fully materialized at construction.

    Every sample is processed once in ``__init__`` and stored as
    ``(data, label)``; label is 1 only for samples that were actually rotated.
    """

    def __init__(self, root, molecule="aspirin", split="train", rotated_class_balance=0.5, apply_random_rot_frac=False, canonicalize=False):
        super().__init__(root, name=molecule)
        self.split = split
        self.rotated_class_balance = rotated_class_balance
        self.apply_random_rot_frac = apply_random_rot_frac
        self.canonicalize = canonicalize

        # Rotations are decided and applied here, not at access time.
        print('Initializing dataset...')
        self.preprocessed_data = []
        total = super().__len__()
        for position in range(total):
            sample = super().__getitem__(position)
            if self.canonicalize:
                sample = canonicalize_molecule(sample)

            label = 0
            if torch.rand(1)[0] < self.rotated_class_balance:
                # The second draw happens only when a rotation fraction is set.
                if self.apply_random_rot_frac > 0 and torch.rand(1)[0] < self.apply_random_rot_frac:
                    sample = rotate_molecule(sample, generate_random_quaternion())
                    label = 1

            self.preprocessed_data.append((sample, label))
        print('Dataset initialized!')

    def __getitem__(self, idx):
        return self.preprocessed_data[idx]

    def __len__(self):
        return len(self.preprocessed_data)
    
class AugmentedMD17Dataset(MD17):
    """MD17 dataset with optional per-sample rotation and standardized energy labels.

    The label returned for each sample is its ``energy`` standardized with the
    mean and standard deviation of ``self.data.energy`` taken along dim 0.
    """

    def __init__(
        self,
        root,
        split="train",
        molecule="aspirin",
        transformed_class_balance=0.5,
        apply_random_rot_frac=False,
        fixed_transform=True,
        fixed_rotation=True,
        fixed_quaternions=None,
    ):
        """
        Args:
            root (str): Root directory passed to ``MD17``.
            split (str): Stored on the instance; not otherwise used in this class.
            molecule (str): Forwarded to ``MD17`` as ``name``.
            transformed_class_balance (float): Stored on the instance; not
                otherwise used in this class.
            apply_random_rot_frac (float): Threshold the random draws are
                compared against when deciding whether to rotate a sample.
            fixed_transform (bool): If True, the rotate/don't-rotate decision is
                drawn once per sample here and reused on every access.
            fixed_rotation (bool): If True, each sample uses a fixed quaternion.
            fixed_quaternions: Per-sample quaternions; generated here when
                ``fixed_rotation`` is set and none are supplied.
        """
        super().__init__(root, name=molecule)
        self.split = split
        self.molecule = molecule
        self.apply_random_rot_frac = apply_random_rot_frac
        self.transformed_class_balance = transformed_class_balance
        self.fixed_rotation = fixed_rotation
        self.fixed_transform = fixed_transform

        num_samples = super().__len__()

        # Quaternions are drawn before the rotation mask, one per sample.
        if fixed_rotation and fixed_quaternions is None:
            fixed_quaternions = generate_many_random_quaternions(num_samples)
        self.fixed_quaternions = fixed_quaternions

        if self.fixed_transform:
            self.do_rotation = torch.rand(num_samples) < apply_random_rot_frac

        # Statistics used to standardize the energy label in __getitem__.
        self.means = self.data.energy.mean(dim=0, keepdim=True)
        self.stds = self.data.energy.std(dim=0, keepdim=True)

    #def normalize(self, arr):
    #    return normalize_outputs(arr, means=self.means, stds=self.stds)

    #def unnormalize(self, arr):
    #    return unnormalize_outputs(arr, means=self.means, stds=self.stds)

    def __getitem__(self, idx):
        """Return ``(data, label)`` for ``idx``; ``label`` is the standardized energy."""
        sample = super().__getitem__(idx)

        should_rotate = (
            self.do_rotation[idx]
            if self.fixed_transform
            else torch.rand(1) < self.apply_random_rot_frac
        )
        if should_rotate:
            quaternion = (
                self.fixed_quaternions[idx]
                if self.fixed_rotation
                else generate_random_quaternion()
            )
            sample = rotate_molecule(sample, quaternion)

        standardized_energy = (sample.energy - self.means) / self.stds
        return sample, standardized_energy

# to edit
class ToyCircleDataset(Dataset):
    """2-D toy binary-classification dataset, cached on disk.

    Points lie either on a circle of the given radius (``mode="circle"``) or
    are spread uniformly inside the disc (``mode="random"``). The label is 1
    when the x coordinate is strictly positive and 0 otherwise.
    """

    def __init__(self, root, num_samples=1000, radius=1.0, mode="circle", regen=False):
        """
        Args:
            root (str): Directory in which the generated dataset is cached.
            num_samples (int): Number of points to generate.
            radius (float): Maximum radius for random mode, fixed radius for circular mode.
            mode (str): "circle" to place points on a circle, "random" to spread within radius.
            regen (bool): If True, regenerate and overwrite any cached file.
        """
        self.num_samples = num_samples
        self.radius = radius
        self.mode = mode.lower()

        # The cache key uses the mode string exactly as passed by the caller.
        save_name = os.path.join(root, f'toy_circle-mode_{mode}-num_samples_{num_samples}-radius_{radius}.pt')

        if os.path.exists(save_name) and not regen:
            loaded = torch.load(save_name)
            self.data = loaded['data']
            self.labels = loaded['labels']
        else:
            self.data, self.labels = self.generate_data()
            self.save(save_name)

    def generate_data(self):
        """Sample the points and labels.

        Returns:
            tuple: ``(points, labels)`` — a float32 tensor of shape
            ``(num_samples, 2)`` and a long tensor of shape ``(num_samples,)``.

        Raises:
            ValueError: If ``self.mode`` is neither "circle" nor "random".
        """
        if self.mode == "circle":
            angles = 2 * np.pi * np.random.rand(self.num_samples)
            x = self.radius * np.cos(angles)
            y = self.radius * np.sin(angles)
        elif self.mode == "random":
            # sqrt of a uniform draw gives a uniform density over the disc area.
            r = self.radius * np.sqrt(np.random.rand(self.num_samples))
            angles = 2 * np.pi * np.random.rand(self.num_samples)
            x = r * np.cos(angles)
            y = r * np.sin(angles)
        else:
            raise ValueError("Mode must be 'circle' or 'random'.")

        points = np.column_stack((x, y))
        labels = (x > 0).astype(int)  # 0 for x <= 0, 1 for x > 0
        return torch.tensor(points, dtype=torch.float32), torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        """Return the number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the ``(point, label)`` pair at ``idx``."""
        return self.data[idx], self.labels[idx]

    def save(self, filename):
        """Write ``data`` and ``labels`` to ``filename``, creating its directory if needed."""
        parent_dir = os.path.dirname(filename)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        torch.save({"data": self.data, "labels": self.labels}, filename)

    @staticmethod
    def load(filename):
        """Build a dataset directly from a file written by :meth:`save`.

        The file stores only the points and labels, so ``radius`` is set to
        ``None`` and ``mode`` to ``"unknown"`` on the returned instance.
        ``torch.load`` unpickles the file; only load files you trust.
        """
        loaded = torch.load(filename)
        # Bypass __init__: it requires ``root`` and would regenerate or re-save data.
        dataset = ToyCircleDataset.__new__(ToyCircleDataset)
        dataset.data = loaded["data"]
        dataset.labels = loaded["labels"]
        dataset.num_samples = len(dataset.data)
        dataset.radius = None
        dataset.mode = "unknown"
        return dataset

    def visualize(self, num_points=500):
        """Scatter-plot up to ``num_points`` samples, colored by class."""
        num_points = min(num_points, self.num_samples)
        x, y = self.data[:num_points, 0].numpy(), self.data[:num_points, 1].numpy()
        labels = self.labels[:num_points].numpy()

        plt.figure(figsize=(6, 6))
        plt.scatter(x[labels == 0], y[labels == 0], color='blue', label="Class 0 (x < 0)")
        plt.scatter(x[labels == 1], y[labels == 1], color='red', label="Class 1 (x > 0)")
        plt.axvline(0, color='black', linestyle='--', linewidth=1)  # Decision boundary
        plt.legend()
        plt.title(f"Binary Classification Dataset ({self.mode} mode)")
        plt.show()


class SwissRollDataset(Dataset):
    """Two interleaved spirals in the xy-plane with a binary z coordinate, cached on disk.

    Class 0 is a spiral with z = 0 everywhere. Class 1 is the same spiral with
    its angle shifted by 3 radians; a ``prob``-controlled subset of its points
    has z = 1.
    """

    def __init__(self, root, prob=1.0, z_relate_to_xy=True, max_r=0.1, num_samples=1000, regen=False):
        """
        Args:
            root (str): path to save the data
            prob (float): control the probability of flipping z value
            z_relate_to_xy (bool): if True, flip z value based on radial distance; if False, flip z value randomly
            max_r (float): upper bound of the radial parameter ``r``; the xy
                coordinates are ``10 * r`` times the unit direction
            num_samples (int): the number of points to sample from the swiss roll to construct the dataset
            regen (bool): whether to resample data even if saved data exists
        """
        assert 0 <= prob <= 1
        self.prob = prob
        self.z_relate_to_xy = z_relate_to_xy
        self.max_r = max_r
        self.num_samples = num_samples
        self.num_points = 100000

        # The cache name is built from the caller's ``prob``, before the inversion below.
        save_name = os.path.join(root, f'swiss_roll-prob_{prob}-z_relate_to_xy_{z_relate_to_xy}-num_samples_{num_samples}-max_r_{max_r}.pt')

        # hack to make consistent (HL): in radial mode points with
        # r >= self.prob * max_r are flipped, so storing 1 - prob makes the
        # flipped fraction equal to the caller's ``prob``.
        if self.z_relate_to_xy:
            print('z relates to xy??')
            self.prob = 1 - self.prob

        if os.path.exists(save_name) and not regen:
            loaded = torch.load(save_name)
            self.data = loaded['data']
            self.labels = loaded['labels']
        else:
            self.data, self.labels = self.generate_data()
            self.save(save_name)

    def generate_data(self, plot=True):
        """Generate both spirals, shuffle them, and keep ``num_samples`` points.

        Args:
            plot (bool): If True, save and show 2-D and 3-D plots of the full
                (unshuffled) spirals.

        Returns:
            tuple: ``(data, label)`` — a float32 tensor of shape
            ``(num_samples, 3)`` and a long tensor of shape ``(num_samples,)``.
        """
        r = np.arange(0, self.max_r, self.max_r/self.num_points)[:self.num_points]
        assert len(r) == self.num_points

        theta = 90 * np.pi * r
        spiral_0 = np.stack((10 * r * np.cos(theta), 10 * r * np.sin(theta), np.zeros_like(r)), 1)
        spiral_1 = np.stack((10 * r * np.cos(theta + 3), 10 * r * np.sin(theta + 3), np.zeros_like(r)), 1)

        if self.z_relate_to_xy:
            # Flip the outer part of the class-1 spiral.
            flip_r = self.prob * self.max_r
            flip_idx = np.where(r >= flip_r)[0]
            spiral_1[flip_idx, 2] = 1
        else:
            # Flip whole contiguous chunks of the class-1 spiral, chosen at random.
            n_chunk = 100
            chunks = np.split(np.arange(len(r)), n_chunk)
            n_flip_chunk = int(n_chunk * self.prob)
            if n_flip_chunk > 0:
                selected_chunk_indices = np.random.choice(len(chunks), n_flip_chunk, replace=False)
                selected_chunks = np.concatenate([chunks[i] for i in selected_chunk_indices])
                spiral_1[selected_chunks, 2] = 1

        if plot:
            self._plot_spirals(spiral_0, spiral_1)

        data = torch.cat([torch.tensor(spiral_0), torch.tensor(spiral_1)], dim=0).float()
        label = torch.cat(
            [
                torch.zeros(len(spiral_0), dtype=torch.long),
                torch.ones(len(spiral_1), dtype=torch.long),
            ],
            dim=0,
        )

        # Shuffle both classes together, then truncate to the requested size.
        permutation = torch.tensor(np.random.permutation(data.shape[0]))
        data = data[permutation][:self.num_samples]
        label = label[permutation][:self.num_samples]

        return data, label

    def _plot_spirals(self, spiral_0, spiral_1):
        """Save and show a 2-D and a 3-D scatter of the two spirals."""
        plt.figure(dpi=300, figsize=(5, 5))
        plt.plot(spiral_0[:, 0], spiral_0[:, 1], 'o', color='orange', markersize=0.1)
        plt.plot(spiral_1[:, 0], spiral_1[:, 1], 'o', color='blue', markersize=0.1)
        plt.tight_layout()
        plt.savefig("swiss_roll_2d.png")
        plt.show()

        extent = self.max_r * 10
        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(spiral_0[:, 0], spiral_0[:, 1], spiral_0[:, 2], color='orange', s=0.1)
        ax.scatter(spiral_1[:, 0], spiral_1[:, 1], spiral_1[:, 2], color='blue', s=0.1)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel('z')
        ax.set_xticks(np.linspace(-extent, extent, 5))
        ax.set_yticks(np.linspace(-extent, extent, 5))
        ax.set_zticks([0, 1])
        ax.set_xlim(-extent, extent)
        ax.set_ylim(-extent, extent)
        plt.tight_layout()
        plt.savefig("swiss_roll_3d.png")
        plt.show()

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(point, label)`` pair at ``idx``."""
        return self.data[idx], self.labels[idx]

    def save(self, filename):
        """Write ``data`` and ``labels`` to ``filename``, creating its directory if needed."""
        parent_dir = os.path.dirname(filename)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        torch.save({"data": self.data, "labels": self.labels}, filename)

    @staticmethod
    def load(filename):
        """Build a dataset directly from a file written by :meth:`save`.

        Only ``data`` and ``labels`` are set on the returned instance.
        ``torch.load`` unpickles the file; only load files you trust.
        """
        loaded = torch.load(filename)
        # Bypass __init__ so nothing is regenerated, plotted, or re-saved.
        dataset = SwissRollDataset.__new__(SwissRollDataset)
        dataset.data = loaded["data"]
        dataset.labels = loaded["labels"]
        return dataset

    
