from torch.utils.data import DataLoader, TensorDataset, random_split
import matplotlib.pyplot as plt
import numpy as np
import pickle
import time
import copy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.datasets import QM9, MD17, ModelNet
from datasets import QM9AtomicPyG
# OC20 support needs the optional `fairchem` package (imported lazily in get_dataset)
import torch_geometric
from torch_geometric.data import Batch
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from torch_geometric.utils import add_self_loops, degree, to_dense_batch
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.transform import Rotation as R
from sklearn.model_selection import train_test_split
import time
from torch.utils.data import Subset
from pathlib import Path
import torch_geometric.transforms as T

import utils 
import models
from datasets import MetaPairDetectionDataset,MetaDetectionDataset, MetaAugmentedDataset, ModelNetDataset, LocalQM9Dataset, EfficientLocalQM9, ToyCircleDataset, SwissRollDataset, AugmentedModelNetDataset, QM7TensorDataset #RotatedQM9Dataset, DetectRotatedQM9Dataset, AugmentedQM9Dataset
#RotatedQM9Dataset, DetectRotatedQM9Dataset, AugmentedQM9Dataset

from train import train_model
import hydra
import omegaconf
from omegaconf import DictConfig, OmegaConf
import random
import os
import wandb
from datetime import datetime
import pandas as pd
from ctypes import cdll
from argparse import Namespace
import torchvision
from torchvision import transforms

def set_seed(seed=42):
    """
    Seed every random number generator this project relies on.

    Seeds, in order: PyTorch's CPU generator, the current and all CUDA devices
    (only when CUDA is available), Python's ``random`` module and NumPy's global
    RNG.

    cuDNN determinism flags are not changed. For bitwise-reproducible GPU runs,
    set ``torch.backends.cudnn.deterministic = True`` and
    ``torch.backends.cudnn.benchmark = False`` as well.
    """
    seeders = [torch.manual_seed]
    if torch.cuda.is_available():
        seeders += [torch.cuda.manual_seed, torch.cuda.manual_seed_all]
    seeders += [random.seed, np.random.seed]

    for seed_fn in seeders:
        seed_fn(seed)

# Load the CUDA runtime library eagerly at import time.
libcudart = cdll.LoadLibrary('libcudart.so')

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Root directory for cached artefacts (e.g. the fixed quaternion table).
# The DATA_DIR environment variable must be set; a KeyError is raised otherwise.
data_dir = os.environ['DATA_DIR']

def get_fixed_quaternions():
    """
    Return the shared table of fixed random quaternions.

    The table is cached as ``fixed_quaternions.pkl`` under ``data_dir``. On the
    first call it is generated with ``utils.generate_many_random_quaternions`` and
    written to that file; later calls unpickle it.

    NOTE: the cache is read with ``pickle``; only point ``DATA_DIR`` at trusted
    directories.
    """
    cache_path = os.path.join(data_dir, 'fixed_quaternions.pkl')

    if not os.path.exists(cache_path):
        quaternions = utils.generate_many_random_quaternions(627983)  # 130831)
        with open(cache_path, 'wb') as handle:
            pickle.dump(quaternions, handle)
        return quaternions

    with open(cache_path, 'rb') as handle:
        return pickle.load(handle)

def get_transform_operator(dataset_cfg, return_g=False):
    """
    Return a transformation operator for the dataset named in ``dataset_cfg``.

    The operator is a callable ``(data, idx) -> transformed_data``. It produces a
    transformed (augmented) version of one dataset element. It receives exactly
    what the base dataset yields, which may or may not already include the label.
    No label is passed to it separately.

    Parameters:
        dataset_cfg (Namespace or dict-like): Configuration containing at least
            ``name``. May also contain ``randomized_transform`` (bool) to select a
            stochastic rotation instead of the memoized, index-keyed one.
        return_g (bool): Forwarded to ``utils.randomized_rotation`` only, i.e. it is
            honoured only by the randomized QM9/MD17/OC20/ModelNet rotation.

    Returns:
        Callable or None: ``(data, idx) -> transformed_data``, or None when the
        dataset name is not recognized.

    Notes:
        - QM9 variants, MD17, OC20 and ModelNet: randomized or memoized rotation
          driven by the fixed quaternion table.
        - QM7 variants: the QM7-specific randomized or memoized rotation.
        - ToyCircle: multiplies the first element by ``np.random.rand()``.
        - SwissRoll: ``utils.transform_swiss_roll``.
        - MNIST: rotates the image by a random choice of 0/90/180/270 degrees.
    """
    fixed_quaternions = get_fixed_quaternions()
    dset_name = dataset_cfg.name
    # Default for unrecognized datasets (previously an UnboundLocalError).
    transform_operator = None

    if 'qm9' in dset_name or dset_name == 'md17' or dset_name == 'oc20' or dset_name == 'modelnet':
        ### TODO MD17 forces and energies should be equivariant
        if "randomized_transform" in dataset_cfg and dataset_cfg.randomized_transform:
            transform_operator = (lambda data, idx: utils.randomized_rotation(data, idx, fixed_quaternions, return_g=return_g))
        else:
            # NOTE(review): return_g is not forwarded to the memoized rotation.
            transform_operator = (lambda data, idx: utils.memoized_rotation(data, idx, fixed_quaternions))
    # TODO implement rotation version for QM7, here we have vectors and rank 2 tensors!!!
    ### maybe would be best to just rotate everything for consistency and then we don't have to worry
    ### about passing in the target???
    elif 'qm7' in dset_name:
        if "randomized_transform" in dataset_cfg and dataset_cfg.randomized_transform:
            transform_operator = (lambda data, idx: utils.randomized_rotation_qm7(data, idx, fixed_quaternions))
        else:
            transform_operator = (lambda data, idx: utils.memoized_rotation_qm7(data, idx, fixed_quaternions))
    elif 'toycircle' in dset_name:
        transform_operator = (lambda data, idx: (np.random.rand() * data[0], data[1]))
    elif 'swissroll' in dset_name:
        # data is batch x 3
        transform_operator = (lambda data, idx: utils.transform_swiss_roll(data, idx))
    elif 'MNIST' in dset_name:
        transform_operator = (lambda data, idx: (transforms.functional.rotate(data[0], random.choice([0, 90, 180, 270])), data[1]))
    return transform_operator

def get_label_operator(dataset_cfg, base_dataset, transform_operator=None):
    """
    Return a labeling operator for the dataset and task in ``dataset_cfg``.

    The operator standardizes label access across datasets. It takes one dataset
    element exactly as the dataset yields it (a data object or a ``(data, label)``
    pair, depending on the dataset) and returns a ``(data, label)`` tuple.

    Parameters:
        dataset_cfg (Namespace or dict-like): Configuration with at least ``name``
            and ``task``. QM9 and QM7 also read ``target``.
        base_dataset (Dataset-like): Dataset used to compute normalization
            statistics (means/stds) for the datasets that need them.
        transform_operator (callable, optional): Currently unused. For the
            ``predict_g`` task a dedicated operator is built internally with
            ``get_transform_operator(dataset_cfg, return_g=True)``.

    Returns:
        Callable or None: ``element -> (data, label)``. The OC20 operator instead
        takes ``(data, idx)``. Returns None for the ``detect_transform`` task and
        for unrecognized dataset names.

    Notes:
        - QM9 (plain and atomic), QM7 and MD17: means/stds are computed from
          ``base_dataset`` and passed to the corresponding ``utils`` label
          function.
        - OC20: the label operator depends on both the element and its index.
        - ToyCircle and SwissRoll: the label is extracted by utility functions.
        - MNIST: elements are already ``(image, label)`` pairs.
    """
    dset_name = dataset_cfg.name
    if dataset_cfg.task == 'detect_transform':
        return None

    if dataset_cfg.task == 'predict_g':
        # get_label_g_from_gx needs an operator that also returns the sampled
        # group element, hence return_g=True.
        transform_return_g_operator = get_transform_operator(dataset_cfg, return_g=True)
        return (lambda data: utils.get_label_g_from_gx(data, transform_return_g_operator))

    # Default for unrecognized datasets (previously an UnboundLocalError).
    label_operator = None

    if 'qm9' in dset_name:
        if 'atomic' in dset_name:
            means, stds = utils.get_qm9_atomic_means_stds(base_dataset, dataset_cfg.target)
            label_operator = (lambda data: utils.get_label_from_qm9_element_atomic(data, means, stds, target=dataset_cfg.target))
        else:
            means, stds = utils.get_qm9_means_stds(base_dataset, dataset_cfg.target)
            label_operator = (lambda data: utils.get_label_from_qm9_element(data, means, stds, target=dataset_cfg.target))
    elif 'qm7' in dset_name:
        # the means and stds are not used for QM7, but we need to pass them in for consistency
        means, stds = utils.get_qm7_means_stds(base_dataset, dataset_cfg.target)
        # Higher-order tensor targets are converted to irreps.
        to_irreps = dataset_cfg.target in ['alpha_tensor', 'quadrupole']
        label_operator = (lambda data: utils.get_label_from_qm7_element(data, means, stds, target=dataset_cfg.target, to_irreps=to_irreps))
    elif dset_name == 'md17':
        means, stds = utils.get_md17_means_stds(base_dataset)
        label_operator = (lambda data: utils.get_label_from_md17_element(data, means, stds))
    elif dset_name == 'modelnet':
        label_operator = (lambda data: utils.get_label_from_modelnet_element(data))
    elif dset_name == 'oc20':
        ### TODO: fix this, implementation is currently hacky
        label_operator = (lambda data, idx: utils.get_label_from_oc20_element(data, idx))
    elif 'toycircle' in dset_name:
        label_operator = (lambda data: utils.get_label_from_toy_circle_element(data))  # TODO: fix this to be consistent, include idx
    elif 'swissroll' in dset_name:
        label_operator = (lambda data: utils.get_label_from_swiss_roll_element(data))
    elif 'MNIST' in dset_name:
        label_operator = (lambda data: (data[0], data[1]))

    return label_operator

def _canonicalize_method_from_cfg(dataset_cfg, default):
    """Return ``dataset_cfg.augment_args.canonicalize_method`` if configured, else ``default``."""
    augment_args = dataset_cfg.augment_args
    if 'canonicalize_method' in augment_args:
        return augment_args.canonicalize_method
    return default


def get_canonicalize_operator(dataset_cfg):
    """
    Return a canonicalization operator for the dataset named in ``dataset_cfg``.

    The operator applies a dataset-specific group transformation that brings an
    input into a canonical form. This reduces variability along the symmetry
    group.

    Parameters:
        dataset_cfg (Namespace or dict-like): Configuration containing ``name``.
            For datasets that support canonicalization, ``augment_args`` is read
            and may contain ``canonicalize_method`` to select the strategy.

    Returns:
        Callable or None: ``(data, idx) -> canonicalized_data``, or None when
        canonicalization is not supported for the dataset.

    Notes:
        - QM9 variants and ModelNet: ``utils.canonicalize_molecule``, default
          method 'PCA'.
        - QM7 variants: ``utils.canonicalize_qm7``, default method 'closest'.
        - ToyCircle: ``utils.canonicalize_toy_circle``, default method 'unit'.
        - SwissRoll: ``utils.canonicalize_swiss_roll``. Without a configured
          method, the utility's own default is used.
        - MD17, OC20 and anything else (e.g. MNIST): not implemented, returns None.
    """
    dset_name = dataset_cfg.name
    if 'qm9' in dset_name or 'modelnet' in dset_name:
        method = _canonicalize_method_from_cfg(dataset_cfg, 'PCA')  # alternative: 'closest'
        canonicalize_operator = (lambda data, idx: utils.canonicalize_molecule(data, method=method))
    elif 'qm7' in dset_name:
        # TODO implement multiple equivariant options for this/align with the dipole moment?
        method = _canonicalize_method_from_cfg(dataset_cfg, 'closest')
        canonicalize_operator = (lambda data, idx: utils.canonicalize_qm7(data, method=method))
    elif dset_name in ('md17', 'oc20'):
        canonicalize_operator = None
    elif dset_name == 'toycircle':
        method = _canonicalize_method_from_cfg(dataset_cfg, 'unit')
        canonicalize_operator = (lambda data, idx: utils.canonicalize_toy_circle(data, method=method))
    elif dset_name == 'swissroll':
        if 'canonicalize_method' in dataset_cfg.augment_args:
            method = dataset_cfg.augment_args.canonicalize_method
            canonicalize_operator = (lambda data, idx: utils.canonicalize_swiss_roll(data, idx, method=method))
        else:
            # No method configured. Let canonicalize_swiss_roll apply its own default
            # instead of failing with a NameError on an unbound `method`.
            # NOTE(review): assumes canonicalize_swiss_roll's `method` has a default; confirm.
            canonicalize_operator = (lambda data, idx: utils.canonicalize_swiss_roll(data, idx))
    else:
        canonicalize_operator = None  # not implemented (e.g. MNIST)
    return canonicalize_operator

def _accuracy_criteria(num_classes):
    """
    Build overall and per-class accuracy metrics, in percent.

    Each metric is a callable ``(outputs, targets) -> float`` with predictions
    ``outputs.argmax(1)``. A per-class metric returns 0.0 when its class does not
    occur in ``targets``.
    """
    def overall_accuracy(outputs, targets):
        return 100 * (outputs.argmax(1) == targets).sum().item() / targets.size(0)

    def make_class_accuracy(label):
        # Factory so each metric binds its own `label` (avoids late binding).
        def class_accuracy(outputs, targets):
            n_in_class = (targets == label).sum().item()
            if n_in_class == 0:
                return 0.0
            n_correct = ((outputs.argmax(1) == targets) & (targets == label)).sum().item()
            return 100 * n_correct / n_in_class
        return class_accuracy

    criteria = {'accuracy': overall_accuracy}
    for label in range(num_classes):
        criteria[f'accuracy_per_class_{label}'] = make_class_accuracy(label)
    return criteria


def _build_criteria(dataset_cfg, means, stds):
    """
    Return ``(criterion, aux_criteria)`` for the dataset and task in ``dataset_cfg``.

    ``means``/``stds`` are the normalization statistics computed while loading the
    dataset, or None when none were computed. They are captured by the QM9/QM7
    MAE metrics. ``aux_criteria`` maps metric names to ``(outputs, targets)``
    callables. It is empty for datasets without dedicated metrics (e.g. md17, oc20).
    """
    dset_name = dataset_cfg.name
    dset_task = dataset_cfg.task
    aux_criteria = {}

    is_classification = dset_task == 'detect_transform' or (
        dset_task == 'task_dependent'
        and (dataset_cfg.task_dependent_args.binary_detection or ('modelnet' in dset_name))
    )
    if is_classification:
        criterion = nn.CrossEntropyLoss()
        if dset_task == 'task_dependent' and ('modelnet' in dset_name):
            aux_criteria = _accuracy_criteria(dataset_cfg.num_classes)
        else:
            aux_criteria = {'accuracy': utils.classification_loss}
        # TODO: add per_class acc for binary detection
        return criterion, aux_criteria

    criterion = nn.MSELoss()  # could make this smarter for group elements in the future
    if 'qm9' in dset_name:
        if dset_task == 'predict_g':
            # for predict_g task, still ok to track entrywise error via MAE!
            angle_loss = (lambda outputs, targets: utils.rotation_error(outputs, targets, to_float=False))
            if 'angle_loss' in dataset_cfg and dataset_cfg.angle_loss:
                criterion = angle_loss
            aux_criteria = {
                'MAE': (lambda outputs, targets: utils.qm9_regression_mae(outputs, targets, means=None, stds=None, normalize=False)),
                'rot_angle': (lambda outputs, targets: utils.rotation_error(outputs, targets, to_float=True)),
            }
        else:
            aux_criteria = {'MAE': (lambda outputs, targets: utils.qm9_regression_mae(outputs, targets, means[:, :], stds[:, :]))}
        formatted_means = ", ".join(f"{x:.2f}" for x in means[0])
        formatted_stds = ", ".join(f"{x:.2f}" for x in stds[0])
        print('means', formatted_means)
        print('stds', formatted_stds)
    elif 'qm7' in dset_name:
        # per component logging for dipole (would need to change for other tensors)
        if dataset_cfg.target == 'dipole':
            aux_criteria = {
                'MAE_x': (lambda outputs, targets: utils.qm7_regression_mae(outputs, targets, means, stds)[0].item()),
                'MAE_y': (lambda outputs, targets: utils.qm7_regression_mae(outputs, targets, means, stds)[1].item()),
                'MAE_z': (lambda outputs, targets: utils.qm7_regression_mae(outputs, targets, means, stds)[2].item()),
                'MAE_mean_component': (lambda outputs, targets: torch.mean(utils.qm7_regression_mae(outputs, targets, means, stds)).item()),
                'MAE_vector_norm': (lambda outputs, targets: torch.mean(torch.norm(outputs - targets, dim=1)).item()),
            }
        elif dataset_cfg.target == 'alpha_tensor' or dataset_cfg.target == 'quadrupole':
            # stored in xx,yy,zz,xy,xz,yz format
            aux_criteria = {
                'MAE_xx': (lambda outputs, targets: utils.qm7_regression_mae(outputs, targets, means, stds, from_irreps=True)[0].item()),
                'MAE_yy': (lambda outputs, targets: utils.qm7_regression_mae(outputs, targets, means, stds, from_irreps=True)[1].item()),
                'MAE_zz': (lambda outputs, targets: utils.qm7_regression_mae(outputs, targets, means, stds, from_irreps=True)[2].item()),
                'MAE_xy': (lambda outputs, targets: utils.qm7_regression_mae(outputs, targets, means, stds, from_irreps=True)[3].item()),
                'MAE_xz': (lambda outputs, targets: utils.qm7_regression_mae(outputs, targets, means, stds, from_irreps=True)[4].item()),
                'MAE_yz': (lambda outputs, targets: utils.qm7_regression_mae(outputs, targets, means, stds, from_irreps=True)[5].item()),
                'MAE_mean_component': (lambda outputs, targets: torch.mean(utils.qm7_regression_mae(outputs, targets, means, stds, from_irreps=True)).item()),
                'MAE_tensor_norm': (lambda outputs, targets: torch.mean(torch.norm(utils.get_e3nn_matrix(outputs, outputs.device, to_irreps=False) - utils.get_e3nn_matrix(targets, targets.device, to_irreps=False), dim=1)).item()),
            }
        elif dataset_cfg.target == 'alpha_iso' or dataset_cfg.target == 'alpha_aniso':
            aux_criteria = {
                'MAE': (lambda outputs, targets: utils.qm7_regression_mae(outputs, targets, means[:, :], stds[:, :], target_type="alpha_iso").item())
            }
    elif 'toycircle' in dset_name or 'swissroll' in dset_name or 'MNIST' in dset_name or 'modelnet' in dset_name:
        criterion = nn.CrossEntropyLoss()
        if dataset_cfg.task == 'classification':
            aux_criteria = _accuracy_criteria(dataset_cfg.num_classes)
        else:
            # Fractional (not percent) accuracy.
            aux_criteria = {'accuracy': (lambda outputs, targets: (outputs.argmax(1) == targets).sum().item() / (targets.size(0)))}

    return criterion, aux_criteria


def get_dataset(cfg, dataset_dir):
    """
    Build the dataset, split indices, loss and metrics described by ``cfg``.

    Parameters:
        cfg: Full run configuration. Reads ``cfg.dataset`` (name, task,
            augment/detection/split arguments, ...) and ``cfg.training.split_seed``.
        dataset_dir: Root directory of the raw dataset. For the atomic QM9 variant,
            a sibling directory with an ``_atomic`` suffix is used instead (created
            if missing).

    Returns:
        tuple: ``(dataset, criterion, aux_criteria, absolute_inds, label_operator)``
            - dataset: the full processed dataset (a ``MetaAugmentedDataset``,
              wrapped in a detection dataset for detection tasks).
            - criterion: the training loss.
            - aux_criteria: dict of metric name -> ``(outputs, targets)`` callable.
              May be empty.
            - absolute_inds: dict with 'train'/'val'/'test' index collections into
              ``dataset``.
            - label_operator: from ``get_label_operator`` (None for detection).

    Raises:
        ValueError: if ``cfg.dataset.name`` is not a supported dataset.
    """
    dataset_cfg = cfg.dataset
    dset_name = dataset_cfg.name
    dset_task = dataset_cfg.task
    detection_args = dataset_cfg.detection_args
    augment_args = dataset_cfg.augment_args

    transform_operator = get_transform_operator(dataset_cfg)

    # Normalization statistics; only the QM9/QM7 loaders below compute them.
    means, stds = None, None

    if 'qm9' in dset_name:
        if 'local' in dset_name:
            print('dataset_dir', dataset_dir)
            base_dataset = EfficientLocalQM9(root=dataset_dir, specific_file=dataset_cfg.specific_file)
            print(f'done. base_dataset has length {len(base_dataset)}. loaded from {dataset_cfg.specific_file}. now about to get means')
            means, stds = utils.get_qm9_means_stds(base_dataset, dataset_cfg.target)
            print('done')
        elif 'atomic' in dset_name:
            p = Path(dataset_dir)
            dataset_dir = p.with_name(p.name + '_atomic')
            os.makedirs(dataset_dir, exist_ok=True)
            base_dataset = QM9AtomicPyG(root_dir=dataset_dir)
            means, stds = utils.get_qm9_atomic_means_stds(base_dataset, dataset_cfg.target)
        else:
            base_dataset = QM9(root=dataset_dir)
            means, stds = utils.get_qm9_means_stds(base_dataset, dataset_cfg.target)
    elif dset_name == 'modelnet':
        # mol_name = "40" or "10", samplesize = 1024
        transform = T.Compose([
            T.NormalizeScale(),
            T.SamplePoints(dataset_cfg.num_sample_points, include_normals=False),
        ])
        base_dataset = ModelNetDataset(root=dataset_dir, name=dataset_cfg.mol_name, transform=transform, classes=dataset_cfg.classes)
        print(len(base_dataset))
    elif 'qm7' in dset_name:
        base_dataset = QM7TensorDataset(file_dir=dataset_dir, filter_polar=dataset_cfg.filter_polar)
        means, stds = utils.get_qm7_means_stds(base_dataset, dataset_cfg.target)
        # TODO do we need to standardize? not sure
    elif dset_name == 'md17':
        base_dataset = MD17(root=dataset_dir, name=dataset_cfg.mol_name)
    elif dset_name == 'oc20':
        from fairchem.core.datasets import LmdbDataset  # optional dependency
        base_dataset = LmdbDataset({'src': dataset_dir + '/s2ef/200k/train/'})
    elif dset_name == 'toycircle':
        base_dataset = ToyCircleDataset(root=dataset_dir, num_samples=cfg.dataset.num_samples_base, mode="circle")
    elif dset_name == 'swissroll':
        base_dataset = SwissRollDataset(root=dataset_dir, num_samples=cfg.dataset.num_samples_base, prob=cfg.dataset.prob, z_relate_to_xy=cfg.dataset.z_relate_to_xy, max_r=cfg.dataset.max_r, regen=cfg.dataset.regen)
    elif dset_name == 'MNIST':
        mnist_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),  # mean and std for MNIST
        ])
        # Only the MNIST training set is used; it is split into train/val/test below.
        base_dataset = torchvision.datasets.MNIST(root=dataset_dir, train=True, download=True, transform=mnist_transform)
    else:
        raise ValueError(f"Unknown dataset name: {dset_name!r}")

    print('length of base dataset: ', len(base_dataset))

    label_operator = get_label_operator(dataset_cfg, base_dataset, transform_operator=transform_operator)
    canonicalize_operator = get_canonicalize_operator(dataset_cfg)

    # ---- train / val / test indices ----
    if 'qm9' in dset_name and 'local' not in dset_name and 'anderson' in dataset_cfg.split_args.split_type:
        absolute_train_indices, absolute_val_indices, absolute_test_indices = get_train_test_val_qm9_anderson_splits(dataset_dir, len(base_dataset), random_state=cfg.training.split_seed)
    else:
        full_train_indices, absolute_test_indices = get_train_test_inds(len(base_dataset), dataset_cfg, random_state=cfg.training.split_seed)
        # Further split the training portion into train/validation.
        absolute_train_indices, absolute_val_indices = get_validation_inds(len(full_train_indices), full_train_indices=full_train_indices, dataset_cfg=dataset_cfg, random_state=cfg.training.split_seed)

    absolute_inds = {'train': absolute_train_indices, 'val': absolute_val_indices, 'test': absolute_test_indices}

    # ---- augmentation / canonicalization pipeline ----
    if augment_args.do_augment:
        transform_list = []
        transform_first = getattr(augment_args, 'transform_first', False)

        canonicalize_index = None  # position of the canonicalizer in transform_list, if any
        transform_index = None     # position of the transform in transform_list, if any

        if augment_args.canonicalize > 0:
            canonicalize_index = len(transform_list)
            transform_list.append((canonicalize_operator, augment_args.canonicalize))

        if augment_args.transform > 0:
            transform_index = len(transform_list)
            transform_list.append((transform_operator, augment_args.transform))

        if transform_first and len(transform_list) == 2:
            transform_list.reverse()
            canonicalize_index, transform_index = transform_index, canonicalize_index

        processed_dataset = MetaAugmentedDataset(base_dataset, transform_list, label_operator=label_operator)

        for split in ['train', 'val', 'test']:
            if getattr(augment_args, split):
                continue
            # With canonicalize_all, canonicalization stays on for this split.
            exclude = []
            if getattr(augment_args, 'canonicalize_all', False) and canonicalize_index is not None:
                exclude = [canonicalize_index]
                print(f"Turning off {split} augmentation (except canonicalization)")
            else:
                print(f"Turning off {split} augmentation")
            processed_dataset.turn_off_transforms(absolute_inds[split], exclude_indices=exclude)
    else:
        processed_dataset = MetaAugmentedDataset(base_dataset, transform_list=[], label_operator=label_operator)

    # ---- task-specific wrapping ----
    if dset_task == 'detect_transform':
        transformed_class_balance = detection_args.transformed_class_balance
        dataset = MetaDetectionDataset(processed_dataset, transform_operator, transformed_class_balance=transformed_class_balance)
    elif dset_task == 'task_dependent' and dataset_cfg.task_dependent_args.binary_detection:
        task_dependent_args = dataset_cfg.task_dependent_args
        dataset = MetaPairDetectionDataset(processed_dataset, mismatched_pair_balance=task_dependent_args.mismatched_pair_balance)
    else:
        # Plain prediction tasks, including task_dependent without binary
        # detection (predict y(x) from c(x)).
        dataset = processed_dataset

    criterion, aux_criteria = _build_criteria(dataset_cfg, means, stds)

    return dataset, criterion, aux_criteria, absolute_inds, label_operator

def check_splits(train_indices, val_indices, test_indices, verbose=True, expected_total=None):
    """
    Check that train/val/test index collections are pairwise disjoint.

    Parameters:
        train_indices, val_indices, test_indices: Iterables of hashable indices
            (lists, ranges, NumPy arrays, ...) that also support slicing.
        verbose (bool): Print the first few indices of each split.
        expected_total (int, optional): If given, the summed split sizes must
            equal this value. Without it, the size is printed but not checked.

    Returns:
        bool: True when the splits are disjoint (and match ``expected_total`` when
        one is given), False otherwise. Details are printed.
    """
    if verbose:
        num = 10
        print('First few train', train_indices[0:num])
        print('First few val', val_indices[0:num])
        print('First few test', test_indices[0:num])

    train_set = set(train_indices)
    val_set = set(val_indices)
    test_set = set(test_indices)

    train_val_overlap = train_set & val_set
    train_test_overlap = train_set & test_set
    val_test_overlap = val_set & test_set

    if train_val_overlap or train_test_overlap or val_test_overlap:
        print("❌ Overlap detected between splits!")
        if train_val_overlap:
            print(f"Overlap between train and val: {train_val_overlap}")
        if train_test_overlap:
            print(f"Overlap between train and test: {train_test_overlap}")
        if val_test_overlap:
            print(f"Overlap between val and test: {val_test_overlap}")
        return False

    total_length = len(train_indices) + len(val_indices) + len(test_indices)
    print(f'Total length: {total_length}')

    if expected_total is None:
        # The total size was not checked, so do not claim it is correct.
        print("✅ Splits are valid: No overlaps.")
        return True

    if total_length != expected_total:
        print(f"❌ Total size {total_length} does not match expected {expected_total}.")
        return False

    print("✅ Splits are valid: No overlaps and correct total size.")
    return True

def get_train_test_inds(dataset_len, dataset_cfg, random_state=0):
    """
    Split dataset positions into train and test index lists.

    The mode is chosen from ``dataset_cfg.split_args`` in this order:
      1. ``use_full_dataset``: split all of ``range(dataset_len)`` with
         sklearn's ``train_test_split`` (test fraction ``test_prop``).
      2. ``use_predefined_splits`` (optional key): read the test and train indices
         from the header-less CSV files ``test_indices`` / ``train_indices``.
      3. Otherwise: reseed Python's global ``random`` with ``random_state``, draw
         ``subset_size`` distinct positions, then split them like mode 1.

    Returns:
        tuple: ``(train_indices, test_indices)``.
    """
    split_args = dataset_cfg.split_args

    if split_args.use_full_dataset:
        candidates = range(dataset_len)
    elif "use_predefined_splits" in split_args and split_args.use_predefined_splits:
        # The membership test keeps configs without this key working.
        test_ids = pd.read_csv(split_args.test_indices, header=None)[0].tolist()
        train_ids = pd.read_csv(split_args.train_indices, header=None)[0].tolist()
        return train_ids, test_ids
    else:
        random.seed(random_state)
        candidates = random.sample(range(dataset_len), split_args.subset_size)

    train_ids, test_ids = train_test_split(candidates, test_size=split_args.test_prop, random_state=random_state)
    return train_ids, test_ids

def get_validation_inds(train_size, full_train_indices, dataset_cfg, random_state=0):
    """
    Split a validation set off the training indices.

    Positions ``0 .. train_size - 1`` are split with sklearn's ``train_test_split``
    (validation fraction ``dataset_cfg.split_args.validation_args.val_prop``). The
    positions are then mapped through ``full_train_indices``, so the results hold
    absolute dataset indices.

    Returns:
        tuple[list, list]: ``(absolute_train_indices, absolute_val_indices)``.
    """
    val_fraction = dataset_cfg.split_args.validation_args.val_prop
    train_positions, val_positions = train_test_split(range(train_size), test_size=val_fraction, random_state=random_state)

    def to_absolute(positions):
        return [full_train_indices[pos] for pos in positions]

    return to_absolute(train_positions), to_absolute(val_positions)

def get_train_test_val_qm9_anderson_splits(dataset_dir, dataset_len, random_state=0, num_train=100000, test_frac=0.1):
    """
    Return the 'Anderson' train/val/test split of QM9.

    Adapted from
    https://github.com/atomicarchitects/datasets/blob/main/atomic_datasets/datasets/qm9.py#L233

    Uncharacterized molecules are removed first with
    ``utils.remove_uncharacterized_molecules``. Let ``Nmols`` be the number that
    remain. A random permutation of them (seeded with ``random_state``) is then
    cut into:
      - train: the first ``num_train`` molecules (default 100000),
      - val: the remaining ``Nmols - num_train - int(test_frac * Nmols)`` molecules,
      - test: ``int(test_frac * Nmols)`` molecules (default 10%).
    Because the three parts come from one permutation, they are disjoint.

    Parameters:
        dataset_dir: Directory passed to ``utils.remove_uncharacterized_molecules``.
        dataset_len: Unused; kept for interface compatibility.
        random_state (int): Seed. NOTE: this reseeds NumPy's *global* RNG.
        num_train (int): Number of training molecules.
        test_frac (float): Fraction of the remaining molecules used for testing.

    Returns:
        tuple: ``(train_indices, val_indices, test_indices)``, taken from the
        ``included_idxs`` returned by ``remove_uncharacterized_molecules``.

    Raises:
        ValueError: if fewer than ``num_train + int(test_frac * Nmols)`` molecules
            remain. A negative validation size would otherwise make ``np.split``
            silently produce overlapping splits.
    """
    # atomic_datasets doesn't remove uncharacterized molecules automatically so need this
    included_idxs, excluded_idxs = utils.remove_uncharacterized_molecules(dataset_dir)

    Nmols = len(included_idxs)
    Ntrain = num_train
    Ntest = int(test_frac * Nmols)
    Nval = Nmols - (Ntrain + Ntest)
    if Nval < 0:
        raise ValueError(
            f"Not enough molecules for the Anderson split: {Nmols} available, "
            f"but {Ntrain} train + {Ntest} test requested."
        )

    np.random.seed(random_state)
    data_permutation = np.random.permutation(Nmols)

    train_indices, val_indices, test_indices, extra_indices = np.split(
        data_permutation, [Ntrain, Ntrain + Nval, Ntrain + Nval + Ntest]
    )

    assert len(extra_indices) == 0, (
        f"Split was inexact {len(train_indices)} {len(val_indices)} {len(test_indices)} with {len(extra_indices)} extra."
    )

    # Map permutation positions back to indices into the full dataset.
    train_indices = included_idxs[train_indices]
    val_indices = included_idxs[val_indices]
    test_indices = included_idxs[test_indices]

    return train_indices, val_indices, test_indices

def get_datasets_from_inds(absolute_inds, dataset):
    """
    Wrap ``dataset`` in one ``torch.utils.data.Subset`` per split.

    Parameters:
        absolute_inds (dict): Must contain 'train', 'val' and 'test' index
            collections into ``dataset``.
        dataset: Any indexable dataset.

    Returns:
        tuple: ``(train_set, val_set, test_set)``.
    """
    return tuple(
        torch.utils.data.Subset(dataset, absolute_inds[split])
        for split in ('train', 'val', 'test')
    )

def get_loaders(datasets, batch_size, num_workers, dataset_type):
    """
    Build one DataLoader per split.

    Graph datasets (names containing 'qm9' or 'qm7', plus 'md17', 'oc20' and
    'modelnet') use ``torch_geometric.loader.DataLoader`` so PyG ``Data`` objects
    are collated into ``Batch`` objects. Everything else uses the plain PyTorch
    ``DataLoader``. Only the 'train' split is shuffled.

    Returns:
        dict: split name -> DataLoader, with the same keys as ``datasets``.
    """
    uses_graph_batches = (
        'qm9' in dataset_type
        or 'qm7' in dataset_type
        or dataset_type in ('md17', 'oc20', 'modelnet')
    )
    loader_cls = torch_geometric.loader.DataLoader if uses_graph_batches else torch.utils.data.DataLoader

    return {
        split: loader_cls(dset, batch_size=batch_size, shuffle=(split == 'train'), num_workers=num_workers)
        for split, dset in datasets.items()
    }

def get_model(model_cfg):
    """
    Instantiate ``models.<model_cfg.name>(**model_cfg.model_args)`` on ``device``.

    Models whose class name contains "E3" are wrapped with
    ``torch.compile(fullgraph=True, dynamic=True)`` before being moved.
    """
    net = getattr(models, model_cfg.name)(**model_cfg.model_args)

    if "E3" in model_cfg.name:
        net = torch.compile(net, fullgraph=True, dynamic=True)

    return net.to(device)

def get_canon_model(cfg):
    """
    Build the canonicalization module c(.).

    The module is described by ``cfg.dataset.task_dependent_args.c_args`` and
    built via ``get_model``. It may be an equivariant network or a fixed
    function. When ``c_args.init_nz`` is truthy, ``utils.init_weights`` is
    applied to every submodule.
    """
    c_args = cfg.dataset.task_dependent_args.c_args
    canon_model = get_model(c_args)
    if c_args.init_nz:
        canon_model.apply(utils.init_weights)
    return canon_model


def get_distance_func(cfg, separate_by_label=True):
    """
    Build an MMD-based distance between two batches for the p-value test.

    The returned ``dist_func(batch0, batch1)`` expects each batch to be indexable
    as ``(graph_batch, labels)``, where ``graph_batch`` supports
    ``to_data_list()`` and ``labels`` is a tensor. The two batches are pooled and
    then re-split into two groups:
      - ``separate_by_label=True``: group 0 = elements labelled 0, group 1 =
        elements labelled 1;
      - ``separate_by_label=False``: a random half/half split (calibration).
    The distance is ``baselines.compute_mmd`` between the groups. The kernel is
    selected by ``cfg.pvalue.method`` ('chamfer', 'hausdorff' or 'stupid').

    Raises:
        ValueError: if ``cfg.pvalue.method`` does not contain 'mmd' or names no
            known kernel.
    """
    pval_method = cfg.pvalue.method
    from baselines import chamfer_kernel, compute_mmd, hausdorff_kernel, stupid_kernel

    if 'mmd' not in pval_method:
        raise ValueError(f"get_distance_func only supports MMD methods, got {pval_method!r}")

    if 'chamfer' in pval_method:
        individual_distance_function = chamfer_kernel
    elif 'hausdorff' in pval_method:
        individual_distance_function = hausdorff_kernel
    elif 'stupid' in pval_method:
        individual_distance_function = stupid_kernel
    else:
        raise ValueError(f"Unknown MMD kernel in pvalue.method: {pval_method!r}")

    def dist_func(batch0, batch1):
        list0 = batch0[0].to_data_list()
        list1 = batch1[0].to_data_list()

        all_data = list0 + list1
        all_labels = torch.cat((batch0[1], batch1[1]))
        total_len = len(all_data)

        if separate_by_label:
            compare0inds = torch.where(all_labels == 0)[0]
            compare1inds = torch.where(all_labels == 1)[0]
            # TODO: the two label groups may differ in size; they are passed to
            # compute_mmd unequal. (An earlier unused attempt to equalize them
            # summed the indices instead of counting them.)
        else:
            randperm = random.sample(range(total_len), total_len)
            compare0inds = randperm[:total_len // 2]
            compare1inds = randperm[total_len // 2:]

        # as_tensor avoids re-copying (and warning on) inputs that are already tensors.
        compare0inds = torch.as_tensor(compare0inds)
        compare1inds = torch.as_tensor(compare1inds)

        compare0 = [all_data[i] for i in compare0inds]
        compare1 = [all_data[i] for i in compare1inds]

        print('compare0', len(compare0), 'compare1', len(compare1))

        graphs0, graphs1 = Batch.from_data_list(compare0), Batch.from_data_list(compare1)

        if cfg.dataset.name == 'qm9' or cfg.dataset.name == 'oc20':
            pos0, mask0 = utils.preprocess(graphs0, dataset_type="qm9_mmd")
            pos1, mask1 = utils.preprocess(graphs1, dataset_type="qm9_mmd")
            if len(pos0.shape) > 3:
                # Fail loudly instead of dropping into a debugger, which would
                # hang non-interactive runs.
                raise ValueError(f"Unexpected padded position shape {tuple(pos0.shape)}; expected at most 3 dims.")
            return compute_mmd(pos0, pos1, mask0, mask1, kernel_func=individual_distance_function)

        # Without padding, torch.stack requires every element's .pos to have the same shape.
        # I'm not sure if this works
        pos0 = torch.stack([all_data[i].pos for i in compare0inds])
        pos1 = torch.stack([all_data[i].pos for i in compare1inds])
        return compute_mmd(pos0, pos1, kernel_func=individual_distance_function)

    return dist_func

def _split_sample_size(sample_args):
    """Return ``(train_count, test_count)`` for one subsample described by ``sample_args``.

    ``train_count`` is ``int(sample_args.train_ratio * sample_args.sample_size)`` (truncated);
    the test side gets the remainder, so the two always add up to ``sample_size``.
    """
    sample_size = sample_args.sample_size
    train_count = int(sample_args.train_ratio * sample_size)
    return train_count, sample_size - train_count


def _sample_distances(datasets, cfg, save_dir, criterion, aux_criteria, sample_args, distance_func, phase):
    """Measure one train-vs-test distance per random subsample.

    For each of ``sample_args.num_samples`` rounds, draw random index subsets of
    ``datasets['train']`` and ``datasets['test']`` (sizes from ``_split_sample_size``) and:

    - ``cfg.pvalue.method == 'classifier'``: train a model on them with
      ``basic_model_functionality`` (results under ``save_dir/{phase}_dist_{i}``) and use its
      best test accuracy as the distance;
    - otherwise: average ``distance_func(train_batch, test_batch)`` over the paired loader
      batches.

    Args:
        phase: ``'calibration'`` or ``'actual'``; used in the per-run directory name and the
            progress message.

    Returns:
        list[float]: one distance per round.
    """
    train_dataset_size = len(datasets['train'])
    test_dataset_size = len(datasets['test'])
    train_sample_size, test_sample_size = _split_sample_size(sample_args)

    dists = []
    for sample_ind in range(sample_args.num_samples):
        train_inds = random.sample(range(train_dataset_size), train_sample_size)
        test_inds = random.sample(range(test_dataset_size), test_sample_size)
        subsets = {
            'train': Subset(datasets['train'], train_inds),
            'test': Subset(datasets['test'], test_inds),
        }

        start = time.time()
        if cfg.pvalue.method == 'classifier':
            dataloaders = get_loaders(subsets, batch_size=cfg.training.batch_size, num_workers=cfg.training.num_workers, dataset_type=cfg.dataset.name)
            results, _ = basic_model_functionality(
                dataloaders=dataloaders,
                cfg=cfg,
                save_dir=os.path.join(save_dir, f'{phase}_dist_{sample_ind}'),
                criterion=criterion,
                aux_criteria=aux_criteria,
            )
            dist = results['train_results']['best_test_accuracy']
        else:
            # batch_size is the full sample size, presumably so each loader yields a single
            # batch holding its whole subset.
            dataloaders = get_loaders(subsets, batch_size=sample_args.sample_size, num_workers=cfg.training.num_workers, dataset_type=cfg.dataset.name)
            batch_dists = [
                distance_func(train_batch, test_batch)
                for train_batch, test_batch in zip(dataloaders['train'], dataloaders['test'])
            ]
            dist = torch.mean(torch.tensor(batch_dists))
        end = time.time()
        print(f'Run {sample_ind} of {phase} had distance {dist} and took {(end - start)/60.0} min')

        dists.append(float(dist))
    return dists


def compute_p_value(datasets, cfg, save_dir, criterion, aux_criteria):
    """Estimate a p-value for "train and test subsets are distinguishable".

    Calibration: with ``set_rotate_everything(True)`` on every split, distances between random
    train/test subsamples form the reference distribution (``cfg.pvalue.calibration_args``).
    Actual: with ``set_rotate_everything(False)``, the same procedure gives the observed
    distances (``cfg.pvalue.distance_args``). The distance is either a classifier's best test
    accuracy (``cfg.pvalue.method == 'classifier'``) or the function from ``get_distance_func``.

    Args:
        datasets: dict with ``'train'``, ``'val'``, ``'test'`` Subset objects whose
            ``.dataset`` provides ``set_rotate_everything(bool)``.
        cfg: full config; reads ``cfg.pvalue``, ``cfg.training`` and ``cfg.dataset``.
        save_dir: parent directory for classifier runs.
        criterion, aux_criteria: only used by the classifier method.

    Returns:
        tuple: ``(p_value, actual_dists, calibration_dists)``. The rotate-everything flag is
        left False on all three splits.

    Raises:
        ValueError: if no actual distances were sampled (``distance_args.num_samples < 1``).
    """
    pval_args = cfg.pvalue

    if pval_args.method != 'classifier':
        distance_func_for_calibration = get_distance_func(cfg, separate_by_label=False)
        distance_func_for_actual = get_distance_func(cfg, separate_by_label=True)
    else:
        distance_func_for_calibration = distance_func_for_actual = None

    # ------- COMPUTING CALIBRATION DISTANCE -------
    # Rotate everything so train and test follow the same distribution.
    # NOTE(review): the 'val' split is toggled too although it is never sampled here.
    for split in ('train', 'val', 'test'):
        datasets[split].dataset.set_rotate_everything(True)
    calibration_dists = _sample_distances(
        datasets, cfg, save_dir, criterion, aux_criteria,
        pval_args.calibration_args, distance_func_for_calibration, 'calibration',
    )

    # ------- COMPUTING ACTUAL DISTANCE -------
    # Back to only rotating some samples.
    for split in ('train', 'val', 'test'):
        datasets[split].dataset.set_rotate_everything(False)

    orig_start = time.time()
    train_sample_size, test_sample_size = _split_sample_size(pval_args.distance_args)
    train_dataset_size = len(datasets['train'])
    test_dataset_size = len(datasets['test'])
    print(f'train_sample_size {train_sample_size} test_sample_size {test_sample_size}')
    print(f'train_dataset_size {train_dataset_size} test_dataset_size {test_dataset_size}')
    actual_dists = _sample_distances(
        datasets, cfg, save_dir, criterion, aux_criteria,
        pval_args.distance_args, distance_func_for_actual, 'actual',
    )
    end = time.time()
    print(f'Computing the actual distances took {(end-orig_start)/60.0} min')

    # ------- COMPUTING THE P-VALUE -------
    if not actual_dists:
        raise ValueError('pvalue.distance_args.num_samples must be >= 1 to compute a p-value')
    avg_actual_dist = torch.mean(torch.tensor(actual_dists))
    # Monte-Carlo permutation-test p-value (1 + b) / (1 + m), where b counts calibration draws
    # AT LEAST as large as the observed statistic (Phipson & Smyth, 2010). Counting ties matters:
    # classifier accuracies are discrete and tie often, and a strict '>' would understate p.
    num_at_least_as_large = int((torch.tensor(calibration_dists) >= avg_actual_dist).sum())
    p_value = (1 + num_at_least_as_large) / (1 + len(calibration_dists))

    return float(p_value), actual_dists, calibration_dists

def basic_model_functionality(dataloaders, cfg, save_dir, criterion=None, aux_criteria=None):
    """Train one model on ``dataloaders`` and persist its artefacts.

    Weights go to ``<data_dir>/checkpoints/<save_dir>`` (via ``train_model``). ``save_dir``
    receives a ``model.pt`` symlink to the checkpoint, ``results.pkl`` and the plots written by
    ``utils.plot_results`` (prefix ``save_dir/plot``). Optionally logs to wandb, builds a
    canonicalization model when ``cfg.dataset.task == 'task_dependent'`` and resumes from
    ``cfg.training.load_checkpoint_path``.

    Args:
        dataloaders: passed unchanged to ``train_model``.
        cfg: full config. The caller's object is not modified; when
            ``cfg.model.add_to_input_features`` is set, a deep copy with the adjusted
            ``input_features`` is used (and stored in the results) instead. This matters
            because ``compute_p_value`` calls this function repeatedly with the same config.
        save_dir: results directory; also the wandb run name.
        criterion, aux_criteria: forwarded to ``train_model`` / ``utils.plot_results``.

    Returns:
        tuple: ``(results, results_file)``. ``results`` has keys ``'train_results'``,
        ``'config'``, ``'num_param'`` and ``'num_canon_param'``; ``results_file`` is the
        path of the pickled results.
    """
    # Weights go to a large-storage location; save_dir gets a symlink to them.
    checkpoint_dir = os.path.join(data_dir, os.path.join('checkpoints', save_dir))
    os.makedirs(checkpoint_dir, exist_ok=True)

    os.makedirs(save_dir, exist_ok=True)
    print(f'Saving at {save_dir}')

    if cfg.training.use_wandb:
        primitive_cfg = omegaconf.OmegaConf.to_container(
            cfg, resolve=True, throw_on_missing=True
        )
        wandb.init(project='dist-symm-breaking',
        name=save_dir,
        tags=[cfg.dataset.name, cfg.dataset.task],
        group=cfg.dataset.name,
        config=primitive_cfg)

    if cfg.dataset.name == 'toycircle':
        train_viz_function = utils.plot_decision_boundary
    else:
        train_viz_function = None

    # Optional canonicalization module; its parameters are trained only if c_args.learned.
    add_canon_params = False
    canon_model = None
    num_canon_param = 0
    if cfg.dataset.task == 'task_dependent':
        c_args = cfg.dataset.task_dependent_args.c_args
        if c_args.learned:
            add_canon_params = True
        canon_model = get_canon_model(cfg)
        num_canon_param = utils.count_parameters(canon_model)
        print(f'Number of parameters in canon_model: {num_canon_param}')

    if 'add_to_input_features' in cfg.model and cfg.model.add_to_input_features:
        # Adjust a private copy: mutating the caller's cfg would make input_features grow on
        # every call when this function runs several times with the same config.
        cfg = copy.deepcopy(cfg)
        cfg.model.model_args.input_features += cfg.model.add_to_input_features

    model = get_model(cfg.model)

    num_param = utils.count_parameters(model)
    print(f'Number of parameters in model: {num_param}')

    if cfg.training.use_wandb:
        wandb.log({
            "model/num_parameters": num_param,
            "model/num_parameters_millions": num_param / 1e6,
            "model/model_name": cfg.model.name
        })
        if canon_model is not None:
            wandb.log({
                "model/num_canon_parameters": num_canon_param,
                "model/num_canon_parameters_millions": num_canon_param / 1e6,
                "model/total_parameters": num_param + num_canon_param,
                "model/total_parameters_millions": (num_param + num_canon_param) / 1e6,
                "model/canon_model_name": cfg.dataset.task_dependent_args.c_args.name
            })

    if add_canon_params:
        optimizer = optim.Adam(list(model.parameters()) + list(canon_model.parameters()), lr=cfg.training.learning_rate)
    else:
        optimizer = optim.Adam(model.parameters(), lr=cfg.training.learning_rate)

    # Resume from a checkpoint if one is configured ('None' string or null means "no").
    if 'load_checkpoint_path' in cfg.training.keys():
        checkpoint_path = cfg.training['load_checkpoint_path']
        if checkpoint_path is not None and checkpoint_path != 'None':
            print(f'Loading from checkpoint {cfg.training.load_checkpoint_path}')
            # map_location lets a GPU-saved checkpoint load on the current device.
            checkpoint = torch.load(cfg.training.load_checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # NOTE(review): wandb step/epoch state is not restored.

    train_results = train_model(model=model,
    dataloaders=dataloaders,
    criterion=criterion,  # e.g. cross entropy loss
    optimizer=optimizer,  # TODO: add early stopping + evaluate test error
    aux_criteria=aux_criteria,   # e.g. accuracy
    num_epochs=cfg.training.epochs,
    device=device,
    print_every=cfg.training.print_every,
    verbose=cfg.training.verbose,
    use_wandb=cfg.training.use_wandb,
    dataset_cfg=cfg.dataset,
    save_dir=checkpoint_dir,
    train_viz_function=train_viz_function,
    canon_model=canon_model)

    # Point save_dir/model.pt at the checkpoint. An absolute target keeps the link valid even
    # when checkpoint_path is relative to the working directory rather than to save_dir.
    src = Path(train_results['checkpoint_path']).absolute()
    dst = Path(os.path.join(save_dir, 'model.pt'))
    if dst.exists() or dst.is_symlink():
        dst.unlink()
    dst.symlink_to(src)

    results = {'train_results': train_results, 'config': cfg}
    results['num_param'] = num_param
    results['num_canon_param'] = num_canon_param

    results_file = os.path.join(save_dir, 'results.pkl')
    with open(results_file, 'wb') as f:
        pickle.dump(results, f)

    utils.plot_results(train_results, aux_criteria=aux_criteria, save_prefix=os.path.join(save_dir, 'plot'))

    if cfg.training.use_wandb:
        wandb.finish()

    return results, results_file

def get_name(cfg):
    """Return the results sub-directory name for a run.

    An explicit ``cfg.save_dir`` takes precedence. Otherwise the name encodes the dataset,
    task, split, augmentation/canonicalization settings, epoch count and model name, with
    ``_<canonicalize_method>`` appended to the canon part when that key is configured.
    """
    if 'save_dir' in cfg:
        return f'{cfg.save_dir}'

    aug = cfg.dataset.augment_args
    split = cfg.dataset.split_args
    canon_suffix = f'_{aug.canonicalize_method}' if 'canonicalize_method' in aug.keys() else ""
    return (
        f'{cfg.dataset.name}-setting_{cfg.dataset.task}'
        f'-full_ds_{split.use_full_dataset}-subset_{split.subset_size}'
        f'-augment_{aug.do_augment}_transform_{aug.transform}'
        f'-canon_{aug.canonicalize}{canon_suffix}'
        f'-epochs_{cfg.training.epochs}-model_{cfg.model.name}'
    )

@hydra.main(config_path="configs", config_name="config") #version_base=None, 
def main(cfg):
    """Hydra entry point: estimate a p-value or train a single model, depending on the config.

    Seeds the RNGs, brackets the run with the CUDA profiler (stopped even if the run fails),
    loads the dataset under ``<data_dir>/<cfg.dataset.directory_name>`` and builds the
    train/val/test subsets. Then:

    - if ``cfg.pvalue.compute_pvalue`` is set, runs ``compute_p_value`` and pickles
      ``{'p_value', 'actual_dists', 'calibration_dists'}`` to ``<save_dir>/pvalue.pkl``;
    - otherwise trains once with ``basic_model_functionality``.

    ``save_dir`` is ``results/<get_name(cfg)>``, with a timestamp suffix if that path
    already exists.
    """
    print("Configuration:\n", OmegaConf.to_yaml(cfg))

    # Seed first so every random draw below is reproducible.
    set_seed(cfg.training.training_seed)

    libcudart.cudaProfilerStart()
    try:
        # Check for a clash on the final results/<name> path, not on the bare name.
        save_dir = os.path.join('results', get_name(cfg))
        if os.path.exists(save_dir):
            save_dir = f'{save_dir}_{utils.get_timestamp_for_filename()}'

        dataset_dir = os.path.join(data_dir, cfg.dataset.directory_name)
        os.makedirs(dataset_dir, exist_ok=True)
        dataset, criterion, aux_criteria, absolute_inds, _label_operator = get_dataset(cfg, dataset_dir=dataset_dir)

        print('Checking splits', check_splits(absolute_inds['train'], absolute_inds['val'], absolute_inds['test']))

        train_set, val_set, test_set = get_datasets_from_inds(absolute_inds, dataset)
        datasets = {'train': train_set, 'val': val_set, 'test': test_set}

        if cfg.pvalue.compute_pvalue:
            p_value, actual_dists, calibration_dists = compute_p_value(datasets, cfg, save_dir=save_dir, criterion=criterion, aux_criteria=aux_criteria)

            print(f'p-value: {p_value}')
            print('     actual dists', actual_dists)
            print('     calibration dists', calibration_dists)

            p_value_res = {'p_value': p_value, 'actual_dists': actual_dists, 'calibration_dists': calibration_dists}
            os.makedirs(save_dir, exist_ok=True)
            filename = os.path.join(save_dir, 'pvalue.pkl')
            with open(filename, 'wb') as f:
                pickle.dump(p_value_res, f)
        else:
            dataloaders = get_loaders(datasets, batch_size=cfg.training.batch_size, num_workers=cfg.training.num_workers, dataset_type=cfg.dataset.name)
            basic_model_functionality(dataloaders, cfg, save_dir, criterion=criterion, aux_criteria=aux_criteria)
    finally:
        # Always stop the profiler, even when the run raises.
        libcudart.cudaProfilerStop()

if __name__ == '__main__':
    # Hydra's decorator parses the CLI and supplies the composed config to main().
    main()


