import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
from collections import defaultdict
from os.path import join
import random
import os



def get_loader(
    data_path,
    filenames,
    device,
    gpu,
    labels_dict,
    Mol_dict,
    split_files_with_boxes,
    bs_dict,
    shuffle,
    num_workers,
    n_gpus,
    rank=0,
    seed=1,
    return_names=False,
    augment=True,
):
    """Build a ``DataLoader`` over ``CNN3D_Dataset`` batched by box size.

    Args:
        data_path: Root directory handed to the dataset.
        filenames: Sample names handed to the dataset.
        device: Stored on the dataset.
        gpu: Stored on the dataset.
        labels_dict: Mapping from sample name to label.
        Mol_dict: Accepted for signature parity; not used by this function.
        split_files_with_boxes: Mapping from sample name to box-size string.
        bs_dict: Mapping from box-size string to batch size.
        shuffle: Whether the batch sampler shuffles.
        num_workers: Number of ``DataLoader`` worker processes.
        n_gpus: Forwarded to the sampler as ``num_replicas``.
        rank: Rank of the current process, forwarded to the sampler.
        seed: Seed forwarded to the sampler.
        return_names: Whether the dataset also returns sample names.
        augment: Whether the dataset applies augmentation.

    Returns:
        A ``DataLoader`` using ``BoxSizeBatchSampler`` as its batch sampler,
        with ``pin_memory=True``.
    """
    dataset_kwargs = dict(
        data_path=data_path,
        filenames=filenames,
        device=device,
        gpu=gpu,
        labels_dict=labels_dict,
        split_files_with_boxes=split_files_with_boxes,
        num_classes=1,  # this factory always configures a single class
        return_names=return_names,
        augment=augment,
    )
    dataset = CNN3D_Dataset(**dataset_kwargs)

    sampler_kwargs = dict(
        dataset=dataset,
        bs_dict=bs_dict,
        shuffle=shuffle,
        num_replicas=n_gpus,
        rank=rank,
        seed=seed,
    )
    batch_sampler = BoxSizeBatchSampler(**sampler_kwargs)

    return DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        pin_memory=True,
    )

def get_loader_inference(
    data_path,
    filenames,
    device,
    gpu,
    Mol_dict,
    split_files_with_boxes,
    bs_dict,
    shuffle,
    num_workers,
    n_gpus,
    rank=0,
    seed=1,
    return_names=False,
    augment=True,
):
    """Build a ``DataLoader`` over ``CNN3D_Dataset_inference`` (no labels).

    Args:
        data_path: Root directory handed to the dataset.
        filenames: Sample names handed to the dataset.
        device: Stored on the dataset.
        gpu: Stored on the dataset.
        Mol_dict: Accepted for signature parity; not used by this function.
        split_files_with_boxes: Mapping from sample name to box-size string.
        bs_dict: Mapping from box-size string to batch size.
        shuffle: Whether the batch sampler shuffles.
        num_workers: Number of ``DataLoader`` worker processes.
        n_gpus: Forwarded to the sampler as ``num_replicas``.
        rank: Rank of the current process, forwarded to the sampler.
        seed: Seed forwarded to the sampler.
        return_names: Whether the dataset also returns sample names.
        augment: Whether the dataset applies augmentation.

    Returns:
        A ``DataLoader`` using ``BoxSizeBatchSampler`` as its batch sampler,
        with ``pin_memory=True``.
    """
    # CNN3D_Dataset_inference.__init__ has no ``num_classes`` parameter, so
    # that keyword must not be forwarded here (doing so raises TypeError).
    dataset = CNN3D_Dataset_inference(
        data_path=data_path,
        filenames=filenames,
        device=device,
        gpu=gpu,
        split_files_with_boxes=split_files_with_boxes,
        return_names=return_names,
        augment=augment,
    )

    sampler = BoxSizeBatchSampler(
        dataset=dataset,
        bs_dict=bs_dict,
        shuffle=shuffle,
        num_replicas=n_gpus,
        rank=rank,
        seed=seed,
    )

    loader = DataLoader(
        dataset=dataset,
        batch_sampler=sampler,
        num_workers=num_workers,
        pin_memory=True,
    )
    return loader


def get_loader_SM(
    data_path,
    filenames,
    device,
    gpu,
    labels_dict,
    Mol_dict,
    split_files_with_boxes,
    bs_dict,
    shuffle,
    num_workers,
    n_gpus,
    rank=0,
    seed=1,
    return_names=False,
    augment=True
):
    """Build a ``DataLoader`` over ``CNN3D_Dataset_SM`` batched by box size.

    Args:
        data_path: Root directory handed to the dataset.
        filenames: Sample names handed to the dataset.
        device: Stored on the dataset.
        gpu: Stored on the dataset.
        labels_dict: Mapping from sample name to label.
        Mol_dict: Mapping handed to the dataset for molecule lookups.
        split_files_with_boxes: Mapping from sample name to box-size string.
        bs_dict: Mapping from box-size string to batch size.
        shuffle: Whether the batch sampler shuffles.
        num_workers: Number of ``DataLoader`` worker processes.
        n_gpus: Forwarded to the sampler as ``num_replicas``.
        rank: Rank of the current process, forwarded to the sampler.
        seed: Seed forwarded to the sampler.
        return_names: Whether the dataset also returns sample names.
        augment: Whether the dataset applies augmentation.

    Returns:
        A ``DataLoader`` using ``BoxSizeBatchSampler`` as its batch sampler,
        with ``pin_memory=True``.
    """
    dataset_kwargs = dict(
        data_path=data_path,
        filenames=filenames,
        device=device,
        gpu=gpu,
        labels_dict=labels_dict,
        Mol_dict=Mol_dict,
        split_files_with_boxes=split_files_with_boxes,
        num_classes=1,  # this factory always configures a single class
        return_names=return_names,
        augment=augment,
    )
    dataset = CNN3D_Dataset_SM(**dataset_kwargs)

    sampler_kwargs = dict(
        dataset=dataset,
        bs_dict=bs_dict,
        shuffle=shuffle,
        num_replicas=n_gpus,
        rank=rank,
        seed=seed,
    )
    batch_sampler = BoxSizeBatchSampler(**sampler_kwargs)

    return DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        pin_memory=True,
    )

def get_loader_SM_inference(
    data_path,
    filenames,
    device,
    gpu,
    Mol_dict,
    split_files_with_boxes,
    bs_dict,
    shuffle,
    num_workers,
    n_gpus,
    rank=0,
    seed=1,
    return_names=False,
    augment=True
):
    """Build a ``DataLoader`` over ``CNN3D_Dataset_SM_inference`` (no labels).

    Args:
        data_path: Root directory handed to the dataset.
        filenames: Sample names handed to the dataset.
        device: Stored on the dataset.
        gpu: Stored on the dataset.
        Mol_dict: Mapping handed to the dataset for molecule lookups.
        split_files_with_boxes: Mapping from sample name to box-size string.
        bs_dict: Mapping from box-size string to batch size.
        shuffle: Whether the batch sampler shuffles.
        num_workers: Number of ``DataLoader`` worker processes.
        n_gpus: Forwarded to the sampler as ``num_replicas``.
        rank: Rank of the current process, forwarded to the sampler.
        seed: Seed forwarded to the sampler.
        return_names: Whether the dataset also returns sample names.
        augment: Whether the dataset applies augmentation.

    Returns:
        A ``DataLoader`` using ``BoxSizeBatchSampler`` as its batch sampler,
        with ``pin_memory=True``.
    """
    dataset_kwargs = dict(
        data_path=data_path,
        filenames=filenames,
        device=device,
        gpu=gpu,
        Mol_dict=Mol_dict,
        split_files_with_boxes=split_files_with_boxes,
        num_classes=1,  # this factory always configures a single class
        return_names=return_names,
        augment=augment,
    )
    dataset = CNN3D_Dataset_SM_inference(**dataset_kwargs)

    sampler_kwargs = dict(
        dataset=dataset,
        bs_dict=bs_dict,
        shuffle=shuffle,
        num_replicas=n_gpus,
        rank=rank,
        seed=seed,
    )
    batch_sampler = BoxSizeBatchSampler(**sampler_kwargs)

    return DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        pin_memory=True,
    )



def get_loader_CL(
    data_path,
    filenames,
    device,
    gpu,
    split_files_with_boxes,
    esm_dict,
    bs_dict,
    shuffle,
    num_workers,
    n_gpus,
    rank=0,
    seed=1,
    return_names=False,
    augment=True
):
    """Build a ``DataLoader`` over ``CNN3D_Dataset_CL`` batched by box size.

    Args:
        data_path: Root directory handed to the dataset.
        filenames: Sample names handed to the dataset.
        device: Stored on the dataset.
        gpu: Stored on the dataset.
        split_files_with_boxes: Mapping from sample name to box-size string.
        esm_dict: Mapping handed to the dataset for per-sample ESM arrays.
        bs_dict: Mapping from box-size string to batch size.
        shuffle: Whether the batch sampler shuffles.
        num_workers: Number of ``DataLoader`` worker processes.
        n_gpus: Forwarded to the sampler as ``num_replicas``.
        rank: Rank of the current process, forwarded to the sampler.
        seed: Seed forwarded to the sampler.
        return_names: Whether the dataset also returns sample names.
        augment: Whether the dataset applies augmentation.

    Returns:
        A ``DataLoader`` using ``BoxSizeBatchSampler`` as its batch sampler,
        with ``pin_memory=True``.
    """
    dataset_kwargs = dict(
        data_path=data_path,
        filenames=filenames,
        device=device,
        gpu=gpu,
        esm_dict=esm_dict,
        split_files_with_boxes=split_files_with_boxes,
        return_names=return_names,
        augment=augment,
    )
    dataset = CNN3D_Dataset_CL(**dataset_kwargs)

    sampler_kwargs = dict(
        dataset=dataset,
        bs_dict=bs_dict,
        shuffle=shuffle,
        num_replicas=n_gpus,
        rank=rank,
        seed=seed,
    )
    batch_sampler = BoxSizeBatchSampler(**sampler_kwargs)

    return DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        pin_memory=True,
    )




class CNN3D_Dataset(Dataset):
    """Labelled dataset of dense 3D grids built from per-sample ``.npy`` files.

    Each item is loaded from
    ``<data_path>/numpy_3D_point_lists/<name>.npy`` via ``get_input_matrix``
    using the box size recorded for that sample.
    """

    def __init__(
        self,
        data_path,
        filenames,
        device,
        gpu,
        labels_dict,
        split_files_with_boxes,
        num_classes,
        return_names=False,
        augment = True
    ):
        """Store the configuration; no files are touched here.

        Args:
            data_path: Root directory containing ``numpy_3D_point_lists``.
            filenames: Sequence of sample names; defines the dataset length.
            device: Stored as-is (not used inside this class).
            gpu: Stored as-is (not used inside this class).
            labels_dict: Mapping from sample name (optionally with a
                ``.pdb`` suffix) to its label.
            split_files_with_boxes: Mapping from sample name to a box-size
                string such as ``"[32 32 64]"``.
            num_classes: Stored as-is (not used inside this class).
            return_names: If true, ``__getitem__`` also returns the name.
            augment: Forwarded to ``get_input_matrix``.
        """
        self.data_path = data_path
        self.device = device
        self.gpu = gpu
        self.filenames = filenames
        self.labels_dict = labels_dict
        self.split_files_with_boxes = split_files_with_boxes
        self.num_classes = num_classes
        self.return_names = return_names
        self.augment = augment

    def __len__(self):
        """Return the number of samples."""
        return len(self.filenames)

    def _get_label(self, name):
        """Return the label for ``name``, falling back to ``name + ".pdb"``.

        Only a missing key triggers the fallback; any other error from the
        lookup propagates instead of being silently swallowed.

        Raises:
            KeyError: If neither key is present in ``labels_dict``.
        """
        try:
            return self.labels_dict[name]
        except KeyError:
            return self.labels_dict[name + ".pdb"]

    def __getitem__(self, idx):
        """Return ``(X, y)`` or ``(X, y, name)`` for the sample at ``idx``.

        ``X`` is the float32 tensor built from ``get_input_matrix`` and ``y``
        is ``torch.tensor(label)``.
        """
        name = self.filenames[idx]
        y = self._get_label(name)
        box_size = self.get_box_size(idx)
        X = get_input_matrix(join(self.data_path, "numpy_3D_point_lists", name + ".npy"), box_size, augment = self.augment)
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.tensor(y)
        if self.return_names:
            return X, y, name
        return X, y

    def get_box_size(self, idx):
        """Parse the box-size string of sample ``idx`` into a list of ints.

        The string is expected to look like ``"[a b c]"``: surrounding square
        brackets are stripped and the remainder is split on whitespace.

        Raises:
            ValueError: If any token is not an integer literal.
        """
        box_size_str = self.split_files_with_boxes[self.filenames[idx]]
        # Remove the square brackets and split by whitespace
        box_size_list = box_size_str.strip('[]').split()
        try:
            # Convert each part to an integer
            return [int(x) for x in box_size_list]
        except ValueError:
            raise ValueError(f"Invalid numerical values in box size: {box_size_str}")

class CNN3D_Dataset_inference(Dataset):
    """Label-free dataset of dense 3D grids built from per-sample ``.npy`` files.

    Each item is loaded from
    ``<data_path>/numpy_3D_point_lists/<name>.npy`` via ``get_input_matrix``
    using the box size recorded for that sample.
    """

    def __init__(
        self,
        data_path,
        filenames,
        device,
        gpu,
        split_files_with_boxes,
        return_names=False,
        augment = True
    ):
        """Store the configuration; no files are touched here.

        Args:
            data_path: Root directory containing ``numpy_3D_point_lists``.
            filenames: Sequence of sample names; defines the dataset length.
            device: Stored as-is (not used inside this class).
            gpu: Stored as-is (not used inside this class).
            split_files_with_boxes: Mapping from sample name to a box-size
                string such as ``"[32 32 64]"``.
            return_names: If true, ``__getitem__`` also returns the name.
            augment: Forwarded to ``get_input_matrix``.
        """
        self.filenames = filenames
        self.data_path = data_path
        self.split_files_with_boxes = split_files_with_boxes
        self.device = device
        self.gpu = gpu
        self.augment = augment
        self.return_names = return_names

    def __len__(self):
        """Return the number of samples."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``X`` or ``(X, name)`` for the sample at ``idx``."""
        name = self.filenames[idx]
        dims = self.get_box_size(idx)
        npy_path = join(self.data_path, "numpy_3D_point_lists", name + ".npy")
        grid = get_input_matrix(npy_path, dims, augment=self.augment)
        X = torch.from_numpy(grid.astype(np.float32))
        return (X, name) if self.return_names else X

    def get_box_size(self, idx):
        """Parse the box-size string of sample ``idx`` into a list of ints.

        Surrounding square brackets are stripped and the remainder is split
        on whitespace.

        Raises:
            ValueError: If any token is not an integer literal.
        """
        box_size_str = self.split_files_with_boxes[self.filenames[idx]]
        tokens = box_size_str.strip('[]').split()
        try:
            dims = list(map(int, tokens))
        except ValueError:
            raise ValueError(f"Invalid numerical values in box size: {box_size_str}")
        return dims
        

class CNN3D_Dataset_SM(Dataset):
    """Labelled dataset pairing a dense 3D grid with a padded SMILES embedding.

    Sample names have the form ``"<UID>_<MID>"``: the part before the first
    underscore selects the grid file
    ``<data_path>/numpy_3D_point_lists/<UID>.npy`` and the second part is the
    key into ``Mol_dict``.
    """

    def __init__(
        self,
        data_path,
        filenames,
        device,
        gpu,
        labels_dict,
        Mol_dict,
        split_files_with_boxes,
        num_classes,
        return_names=False,
        augment = True
    ):
        """Store the configuration; no files are touched here.

        Args:
            data_path: Root directory containing ``numpy_3D_point_lists``.
            filenames: Sequence of ``"<UID>_<MID>"`` sample names.
            device: Stored as-is (not used inside this class).
            gpu: Stored as-is (not used inside this class).
            labels_dict: Mapping from full sample name to its label.
            Mol_dict: Mapping from ``MID`` to a ``torch.Tensor`` or NumPy
                array holding the molecule embedding.
            split_files_with_boxes: Mapping from sample name to a box-size
                string such as ``"[32 32 64]"``.
            num_classes: Stored as-is (not used inside this class).
            return_names: If true, ``__getitem__`` also returns the name.
            augment: Forwarded to ``get_input_matrix``.
        """
        self.data_path = data_path
        self.device = device
        self.gpu = gpu
        self.filenames = filenames
        self.labels_dict = labels_dict
        self.split_files_with_boxes = split_files_with_boxes
        self.num_classes = num_classes
        self.Mol_dict = Mol_dict
        self.return_names = return_names
        self.augment = augment
        # Fixed length every SMILES embedding is padded (or cropped) to.
        self.max_smiles_seq_len = 256

    def __len__(self):
        """Return the number of samples."""
        return len(self.filenames)

    def _pad_smiles(self, mol_data):
        """Return ``(embedding, mask)`` padded to ``max_smiles_seq_len`` rows.

        ``mol_data`` is copied, detached and squeezed. It is presumably a
        ``(seq_len, emb_dim)`` array, possibly with singleton batch
        dimensions - TODO confirm against the producer of ``Mol_dict``.

        The mask is a bool vector that is ``True`` on padded positions.
        Sequences longer than ``max_smiles_seq_len`` end up cropped because
        the resulting negative pad length removes trailing rows.
        """
        if isinstance(mol_data, torch.Tensor):
            smiles_emb = mol_data.clone().detach().squeeze()
        else:
            smiles_emb = torch.from_numpy(mol_data).clone().detach().squeeze()

        # squeeze() also removes a sequence axis of length 1; restore it so
        # shape[0] is the sequence length and the 2-axis pad below is valid.
        if smiles_emb.dim() < 2:
            smiles_emb = smiles_emb.reshape(1, -1)

        seq_len = smiles_emb.shape[0]
        padding_len = self.max_smiles_seq_len - seq_len

        # Pad spec is (last-dim left, last-dim right, rows top, rows bottom).
        smiles_emb = torch.nn.functional.pad(
            smiles_emb, (0, 0, 0, padding_len), mode='constant', value=0
        )

        smiles_attn_mask = torch.zeros(self.max_smiles_seq_len, dtype=torch.bool)
        smiles_attn_mask[seq_len:] = True
        return smiles_emb, smiles_attn_mask

    def __getitem__(self, idx):
        """Return ``(X, smiles_emb, smiles_attn_mask, y[, name])`` for ``idx``."""
        name = self.filenames[idx]
        parts = name.split("_")
        UID = parts[0]
        MID = parts[1]
        y = self.labels_dict[name]

        box_size = self.get_box_size(idx)
        X = get_input_matrix(join(self.data_path, "numpy_3D_point_lists", UID + ".npy"), box_size, augment = self.augment)
        X = torch.from_numpy(X.astype(np.float32))
        smiles_emb, smiles_attn_mask = self._pad_smiles(self.Mol_dict[MID])
        y = torch.tensor(y)
        if self.return_names:
            return X, smiles_emb, smiles_attn_mask, y, name
        return X, smiles_emb, smiles_attn_mask, y

    def get_box_size(self, idx):
        """Parse the box-size string of sample ``idx`` into a list of ints.

        Surrounding square brackets are stripped and the remainder is split
        on whitespace.

        Raises:
            ValueError: If any token is not an integer literal.
        """
        box_size_str = self.split_files_with_boxes[self.filenames[idx]]
        # Remove the square brackets and split by whitespace
        box_size_list = box_size_str.strip('[]').split()
        try:
            # Convert each part to an integer
            return [int(x) for x in box_size_list]
        except ValueError:
            raise ValueError(f"Invalid numerical values in box size: {box_size_str}")


class CNN3D_Dataset_SM_inference(Dataset):
    """Label-free dataset pairing a dense 3D grid with a padded SMILES embedding.

    Sample names have the form ``"<UID>_<MID>"``: the part before the first
    underscore selects the grid file
    ``<data_path>/numpy_3D_point_lists/<UID>.npy`` and the second part is the
    key into ``Mol_dict``.
    """

    def __init__(
        self,
        data_path,
        filenames,
        device,
        gpu,
        Mol_dict,
        split_files_with_boxes,
        num_classes,
        return_names=False,
        augment = True
    ):
        """Store the configuration; no files are touched here.

        Args:
            data_path: Root directory containing ``numpy_3D_point_lists``.
            filenames: Sequence of ``"<UID>_<MID>"`` sample names.
            device: Stored as-is (not used inside this class).
            gpu: Stored as-is (not used inside this class).
            Mol_dict: Mapping from ``MID`` to a ``torch.Tensor`` or NumPy
                array holding the molecule embedding.
            split_files_with_boxes: Mapping from sample name to a box-size
                string such as ``"[32 32 64]"``.
            num_classes: Accepted for signature parity; it is not stored.
            return_names: If true, ``__getitem__`` also returns the name.
            augment: Forwarded to ``get_input_matrix``.
        """
        self.data_path = data_path
        self.device = device
        self.gpu = gpu
        self.filenames = filenames
        self.split_files_with_boxes = split_files_with_boxes
        self.Mol_dict = Mol_dict
        self.return_names = return_names
        self.augment = augment
        # Fixed length every SMILES embedding is padded (or cropped) to.
        self.max_smiles_seq_len = 256

    def __len__(self):
        """Return the number of samples."""
        return len(self.filenames)

    def _pad_smiles(self, mol_data):
        """Return ``(embedding, mask)`` padded to ``max_smiles_seq_len`` rows.

        ``mol_data`` is copied, detached and squeezed. It is presumably a
        ``(seq_len, emb_dim)`` array, possibly with singleton batch
        dimensions - TODO confirm against the producer of ``Mol_dict``.

        The mask is a bool vector that is ``True`` on padded positions.
        Sequences longer than ``max_smiles_seq_len`` end up cropped because
        the resulting negative pad length removes trailing rows.
        """
        if isinstance(mol_data, torch.Tensor):
            smiles_emb = mol_data.clone().detach().squeeze()
        else:
            smiles_emb = torch.from_numpy(mol_data).clone().detach().squeeze()

        # squeeze() also removes a sequence axis of length 1; restore it so
        # shape[0] is the sequence length and the 2-axis pad below is valid.
        if smiles_emb.dim() < 2:
            smiles_emb = smiles_emb.reshape(1, -1)

        seq_len = smiles_emb.shape[0]
        padding_len = self.max_smiles_seq_len - seq_len

        # Pad spec is (last-dim left, last-dim right, rows top, rows bottom).
        smiles_emb = torch.nn.functional.pad(
            smiles_emb, (0, 0, 0, padding_len), mode='constant', value=0
        )

        smiles_attn_mask = torch.zeros(self.max_smiles_seq_len, dtype=torch.bool)
        smiles_attn_mask[seq_len:] = True
        return smiles_emb, smiles_attn_mask

    def __getitem__(self, idx):
        """Return ``(X, smiles_emb, smiles_attn_mask[, name])`` for ``idx``."""
        name = self.filenames[idx]
        parts = name.split("_")
        UID = parts[0]
        MID = parts[1]

        box_size = self.get_box_size(idx)
        X = get_input_matrix(join(self.data_path, "numpy_3D_point_lists", UID + ".npy"), box_size, augment = self.augment)
        X = torch.from_numpy(X.astype(np.float32))
        smiles_emb, smiles_attn_mask = self._pad_smiles(self.Mol_dict[MID])
        if self.return_names:
            return X, smiles_emb, smiles_attn_mask, name
        return X, smiles_emb, smiles_attn_mask

    def get_box_size(self, idx):
        """Parse the box-size string of sample ``idx`` into a list of ints.

        Surrounding square brackets are stripped and the remainder is split
        on whitespace.

        Raises:
            ValueError: If any token is not an integer literal.
        """
        box_size_str = self.split_files_with_boxes[self.filenames[idx]]
        # Remove the square brackets and split by whitespace
        box_size_list = box_size_str.strip('[]').split()
        try:
            # Convert each part to an integer
            return [int(x) for x in box_size_list]
        except ValueError:
            raise ValueError(f"Invalid numerical values in box size: {box_size_str}")

class CNN3D_Dataset_CL(Dataset):
    """Dataset pairing a dense 3D grid with a per-sample ESM array.

    Grids are read from
    ``<data_path>/numpy_3D_point_lists/AF-<name>-F1-model_v4.npy`` via
    ``get_input_matrix``; the ESM array is looked up in ``esm_dict`` by name.
    """

    def __init__(
        self,
        data_path,
        filenames,
        device,
        gpu,
        esm_dict,
        split_files_with_boxes,
        return_names=False,
        augment = True
    ):
        """Store the configuration; no files are touched here.

        Args:
            data_path: Root directory containing ``numpy_3D_point_lists``.
            filenames: Sequence of sample names; defines the dataset length.
            device: Stored as-is (not used inside this class).
            gpu: Stored as-is (not used inside this class).
            esm_dict: Mapping from sample name to an array exposing
                ``astype`` (converted to float32 on access).
            split_files_with_boxes: Mapping from sample name to a box-size
                string such as ``"[32 32 64]"``.
            return_names: If true, ``__getitem__`` also returns the name.
            augment: Forwarded to ``get_input_matrix``.
        """
        self.filenames = filenames
        self.data_path = data_path
        self.esm_dict = esm_dict
        self.split_files_with_boxes = split_files_with_boxes
        self.device = device
        self.gpu = gpu
        self.augment = augment
        self.return_names = return_names

    def __len__(self):
        """Return the number of samples."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(X, esm)`` or ``(X, esm, name)`` for the sample at ``idx``."""
        name = self.filenames[idx]
        dims = self.get_box_size(idx)
        npy_path = join(self.data_path, "numpy_3D_point_lists", f"AF-{name}-F1-model_v4.npy")
        grid = get_input_matrix(npy_path, dims, augment=self.augment)
        X = torch.from_numpy(grid.astype(np.float32))
        # The sample name doubles as the key into the ESM mapping.
        esm = torch.from_numpy(self.esm_dict[name].astype(np.float32))
        return (X, esm, name) if self.return_names else (X, esm)

    def get_box_size(self, idx):
        """Parse the box-size string of sample ``idx`` into a list of ints.

        Surrounding square brackets are stripped and the remainder is split
        on whitespace.

        Raises:
            ValueError: If any token is not an integer literal.
        """
        box_size_str = self.split_files_with_boxes[self.filenames[idx]]
        tokens = box_size_str.strip('[]').split()
        try:
            dims = list(map(int, tokens))
        except ValueError:
            raise ValueError(f"Invalid numerical values in box size: {box_size_str}")
        return dims
        


def get_input_matrix(dict_filename, box_size, augment = True):
    """Load a sparse point list from disk and scatter it into a dense grid.

    The ``.npy`` file is read as a 2-D array: columns 0-3 are cast to integer
    indices (three grid coordinates followed by a channel index) and column 4
    holds the value written at that position.

    Args:
        dict_filename: Path of the ``.npy`` file. If it does not exist, the
            sibling file ``AF-<stem>-F1-model_v4.npy`` in the same directory
            is tried, where ``<stem>`` is the base name up to its first dot.
        box_size: Sequence whose first three entries are the grid extents.
        augment: If true, draw one of six equally likely options: a 90-degree
            rotation in one of the three axis planes (applied only when the
            two axes involved have equal extent), a flip along axis 0, a flip
            along axis 1, or no change.

    Returns:
        A float32 array of shape
        ``(5, box_size[0], box_size[1], box_size[2])`` (channels first).

    Raises:
        FileNotFoundError: If neither the given path nor the fallback exists.
        Any other error raised by ``np.load`` on the given path (for example
        for a corrupt file) propagates unchanged.
    """
    try:
        arr = np.load(dict_filename)
    except FileNotFoundError:
        # Only a missing file triggers the alternative naming scheme; other
        # load failures must surface rather than be masked by the fallback.
        directory = os.path.dirname(dict_filename)
        filename = os.path.basename(dict_filename).split(".")[0]
        new_filename = "AF-" + filename + "-F1-model_v4.npy"
        arr = np.load(join(directory, new_filename))

    indices = arr[:, :4].astype(int)
    v = arr[:, 4]
    # Five channels in the last axis while filling; moved to the front below.
    X = np.zeros((box_size[0], box_size[1], box_size[2], 5), dtype=np.float32)
    X[indices[:, 0], indices[:, 1], indices[:, 2], indices[:, 3]] = v

    if augment:
        ran = np.random.randint(6)
        # Rotations are skipped for unequal extents so the shape never changes.
        if ran == 0 and box_size[0] == box_size[1]:
            X = np.rot90(X, axes=(0, 1))
        elif ran == 1 and box_size[1] == box_size[2]:
            X = np.rot90(X, axes=(1, 2))
        elif ran == 2 and box_size[0] == box_size[2]:
            X = np.rot90(X, axes=(0, 2))
        elif ran == 3:
            X = np.flip(X, axis=0)
        elif ran == 4:
            X = np.flip(X, axis=1)

    X = X.transpose(3, 0, 1, 2)
    return X

class BoxSizeBatchSampler(Sampler):
    """Batch sampler whose batches only contain samples of one box size.

    All batches are precomputed in ``__init__``. They are grouped into
    "steps" of ``num_replicas`` same-size batches, and each replica keeps the
    batch at position ``rank`` of every step. Iterating the sampler replays
    that fixed list.
    """

    def __init__(self, dataset, bs_dict, shuffle=True, num_replicas=1, rank=0, seed=42):
        """
        Initializes the BoxSizeBatchSampler.

        Args:
            dataset (Dataset): The dataset to sample from. It must support
                ``len()`` and provide ``get_box_size(idx)``.
            bs_dict (dict): A dictionary mapping box size strings (for
                example ``"[32 32 64]"``) to batch sizes. Sizes not listed
                use a batch size of 32.
            shuffle (bool): Whether to shuffle the data.
            num_replicas (int): Number of processes participating in distributed training.
            rank (int): Rank of the current process within num_replicas.
            seed (int): Random seed for shuffling to ensure reproducibility across replicas.
        """
        self.dataset = dataset
        self.bs_dict = bs_dict
        self.shuffle = shuffle
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.box_size_to_indices = defaultdict(list)
        self.rng = random.Random(self.seed)  # Local RNG for deterministic shuffling

        self._group_indices_by_box_size()
        self._create_batches()
        self._create_steps()
        self._assign_batches_to_replica()

    def _group_indices_by_box_size(self):
        """
        Groups dataset indices by their corresponding box sizes.
        """
        for idx in range(len(self.dataset)):
            # str([a, b, c]) is "[a, b, c]"; dropping the commas yields the
            # "[a b c]" form used for the bs_dict keys.
            size_str = str(self.dataset.get_box_size(idx)).replace(",", "")
            self.box_size_to_indices[size_str].append(idx)

    def _create_batches(self):
        """
        Creates batches for each box size based on the batch size dictionary.

        The configured batch size is clamped to at least 1. A box size with a
        single sample has that sample duplicated, and a trailing one-sample
        batch is merged into the batch before it, so one-sample batches only
        occur when the configured batch size itself is 1.
        NOTE(review): earlier comments spoke of a minimum batch size of 2
        (to avoid BatchNorm issues), but the clamp has always been 1 -
        confirm which is intended.
        """
        self.batches_per_box = {}
        for size_str, indices in self.box_size_to_indices.items():
            if self.shuffle:
                self.rng.shuffle(indices)
            batch_size = max(1, self.bs_dict.get(size_str, 32))

            # A lone sample is duplicated so its batch holds two entries.
            if len(indices) == 1:
                indices = indices * 2

            # Create batches for this box size
            size_batches = [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]

            # Fold a trailing one-sample batch into its predecessor. A single
            # batch can never have length 1 here because of the duplication
            # above, so no other case needs handling.
            if len(size_batches) > 1 and len(size_batches[-1]) == 1:
                tail = size_batches.pop()
                size_batches[-1].extend(tail)

            if self.shuffle:
                self.rng.shuffle(size_batches)  # Shuffle batches within the box size

            self.batches_per_box[size_str] = size_batches

    def _create_steps(self):
        """
        Creates training steps where each step contains `num_replicas` batches of the same box size.

        With ``shuffle`` the box size of each step is drawn at random,
        weighted by the number of batches that size started with (the weights
        are not reduced as batches are consumed). Without ``shuffle`` box
        sizes are drained one after another in insertion order. A box size is
        retired as soon as it cannot supply a full step; its leftover batches
        are dropped.
        """
        self.steps = []

        if self.shuffle:
            # The shuffled order is never consumed; the call is retained only
            # so seeded runs keep drawing the same random sequence as before.
            self.rng.shuffle(list(self.batches_per_box))

        box_iterators = {size: iter(batches) for size, batches in self.batches_per_box.items()}
        initial_counts = {size: len(batches) for size, batches in self.batches_per_box.items()}

        while box_iterators:
            if self.shuffle:
                available_sizes = list(box_iterators)
                size_weights = [initial_counts[size] for size in available_sizes]
                total_weight = sum(size_weights)
                size_probs = [weight / total_weight for weight in size_weights]
                selected_size = self.rng.choices(available_sizes, weights=size_probs, k=1)[0]
            else:
                selected_size = next(iter(box_iterators))

            iterator = box_iterators[selected_size]
            step_batches = []
            for _ in range(self.num_replicas):
                batch = next(iterator, None)
                if batch is None:
                    break
                step_batches.append(batch)

            if step_batches and len(step_batches) == self.num_replicas:
                self.steps.append((selected_size, step_batches))
            else:
                # Incomplete step: drop it and remove the exhausted box size
                del box_iterators[selected_size]

    def _assign_batches_to_replica(self):
        """
        Assigns batches to the current replica based on its rank.
        """
        self.batches = [step_batches[self.rank] for _, step_batches in self.steps]

    def __iter__(self):
        """
        Yields batches assigned to the current replica.
        """
        yield from self.batches

    def __len__(self):
        """
        Returns the number of batches assigned to the current replica.
        """
        return len(self.batches)