from torch_geometric.data import Data, Dataset, Batch
import torch
from datasets.aldp import ALDPDataset
from torch_geometric.loader import DataLoader
import logging 

# NOTE(review): configuring the root logger at import time affects every
# module in the process; kept so this module's INFO messages stay visible.
logging.basicConfig(level=logging.INFO)

# Logger shared by the sequence dataset and the loader factory below.
logger = logging.getLogger(name='ALDP-SequenceDataset')

class ALDPSequenceDataset(Dataset):
    """A Dataset that defines sequences of molecules to train the RNN propagator.

    Every item is one sequence built from ``aldpDataset``. For a start index
    ``s`` taken from ``sequence_indices``, it collects the states at
    ``s, s + tau, ..., s + sequence_length * tau``. That is
    ``sequence_length + 1`` states in total; the last state is the target frame.
    All sequences are materialised eagerly in ``__init__``.

    Args:
        args: Configuration object. Only ``args.tau`` (the stride, in frames,
            between consecutive states of a sequence) is read here.
        sequence_length: Number of input states per sequence. The target
            frame is added on top of this.
        sequence_indices: Sized iterable of start indices into ``aldpDataset``.
        aldpDataset: Source dataset. Indexing it must return a 5-tuple
            ``(graph, bond_length, bond_angle, torsion_angle, frame)``.
        transform: Optional transform forwarded to the PyG ``Dataset`` base.

    Raises:
        ValueError: If ``sequence_indices`` is empty.
    """

    def __init__(self,args, sequence_length, sequence_indices,aldpDataset,transform=None):
        # PyG's ``Dataset.__init__`` takes ``root`` as its first positional
        # parameter, so ``transform`` must be passed by keyword. Passing it
        # positionally would store it as ``root`` and silently drop it.
        # NOTE(review): PyG applies ``transform`` to whatever ``get`` returns
        # (here a 5-tuple), so a transform must accept that tuple.
        super().__init__(transform=transform)

        self.sequence_length = sequence_length
        # starting indices of the sequences in the aldp dataset
        self.sequence_indices = sequence_indices
        self.tau = args.tau

        if len(sequence_indices) == 0:
            # torch.stack below cannot stack an empty list; fail clearly instead.
            raise ValueError("ALDPSequenceDataset requires at least one sequence index")

        sequences = []
        bond_lengths = []
        bond_angles = []
        torsion_angles = []
        frames = []

        # one sequence is one item of this dataset
        for start in sequence_indices:
            # indices may be 0-d tensors (e.g. from torch.randperm); use plain ints
            start = int(start)
            # +1 because the last state of the sequence is the target frame
            states = [aldpDataset[start + i * self.tau] for i in range(sequence_length + 1)]
            # transpose the list of 5-tuples into five per-field tuples
            graphs, seq_bond_lengths, seq_bond_angles, seq_torsion_angles, seq_frames = zip(*states)

            # the graphs of one sequence are loaded as a batch to perform parallel
            # processing of the sequence for the Encoder / Decoder
            sequences.append(Batch.from_data_list(list(graphs)))
            # each stacked tensor has leading dimension sequence_length + 1
            bond_lengths.append(torch.stack(seq_bond_lengths))
            bond_angles.append(torch.stack(seq_bond_angles))
            torsion_angles.append(torch.stack(seq_torsion_angles))
            frames.append(torch.stack(seq_frames))

        # stack all sequences: leading dimensions are (num_sequences, sequence_length + 1)
        self.bond_lengths = torch.stack(bond_lengths)
        self.bond_angles = torch.stack(bond_angles)
        self.torsion_angles = torch.stack(torsion_angles)
        self.frames = torch.stack(frames)
        self.sequences = sequences

        logger.info(f"Created dataset with {len(self)} sequences of length {sequence_length} and tau {self.tau}")

    def len(self):
        """Return the number of sequences (one per start index)."""
        return len(self.sequence_indices)

    def get(self, idx):
        """Return ``(graph_batch, bond_lengths, bond_angles, torsion_angles, frames)`` of sequence ``idx``."""
        return self.sequences[idx], self.bond_lengths[idx], self.bond_angles[idx], self.torsion_angles[idx], self.frames[idx]
        


def construct_sequence_loader(args):
    """Build train/validation loaders of ALDP sequences for the RNN propagator.

    The ALDP trajectory is cut into non-overlapping chunks of
    ``(args.sequence_length + 1) * args.tau`` frames. Each chunk yields one
    sequence that starts at its first frame. The chunks are shuffled with a
    fixed seed (1999) and split 70 % / 20 % / 10 % into train / val / test.
    The test split is held out and not loaded here.

    Args:
        args: Configuration object. Reads ``sequence_length``, ``tau``,
            ``testsystem``, ``feature_transform`` and ``batch_size``, and is
            passed on to ``ALDPDataset`` and ``ALDPSequenceDataset``.

    Returns:
        ``(train_loader, val_loader, coordinate_transform, inverse_transform)``.
        The last two are taken from the underlying ``ALDPDataset``.
    """
    sequence_length = args.sequence_length
    tau = args.tau

    aldpDataset = ALDPDataset(args, args.testsystem, transform=args.feature_transform)

    # +1 because the last frame is the target frame
    n_frames_per_sequence = (sequence_length + 1) * tau
    num_sequences = len(aldpDataset) // n_frames_per_sequence

    # Split into train (70 %), val (20 %) and test (10 %). The test part,
    # seq_indices[int(0.9 * num_sequences):], is held out and not loaded here.
    seq_indices = torch.randperm(num_sequences, generator=torch.Generator().manual_seed(1999))
    train_indices = seq_indices[:int(0.7 * num_sequences)] * n_frames_per_sequence
    val_indices = seq_indices[int(0.7 * num_sequences):int(0.9 * num_sequences)] * n_frames_per_sequence

    train_dataset = ALDPSequenceDataset(args, sequence_length, train_indices, aldpDataset)
    val_dataset = ALDPSequenceDataset(args, sequence_length, val_indices, aldpDataset)

    # args.batch_size counts graphs, not sequences. Each sequence holds
    # sequence_length + 1 graphs, so divide to keep memory use comparable to
    # --no_propagator mode, with at least one sequence per batch.
    seqs_per_batch = max(1, int(args.batch_size) // (sequence_length + 1))
    # Round down to a power of two (exact integer arithmetic, no float log2).
    # Cap at the validation set size so the val loader (drop_last=True)
    # still yields at least one batch.
    batch_size = min(
        1 << (seqs_per_batch.bit_length() - 1),
        1 << (len(val_dataset).bit_length() - 1),
    )
    logger.info(f"Using {batch_size} sequences per batch")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, drop_last=True)

    return train_loader, val_loader, aldpDataset.coordinate_transform, aldpDataset.inverse_transform

