import torch
import numpy as np
from torch.utils import data
from collections import defaultdict
from motiflow.utils.rigid_helpers import assemble_rigid_mat


class FragmentDataset(data.Dataset):
    """Dataset of molecules described as fragments with rigid frames.

    Each item of the processed split provides fragment class ids
    (``frag_ids``), per-fragment translations (``trans``) and rotations
    (``rots``), plus the field needed by the configured conditioning type.
    For every fragment class in the library, the rotational symmetries are
    padded once, at construction time, to a fixed ``S_max`` so that samples
    can be assembled without any per-item padding logic.

    Attributes set by ``__init__``:
        data_list: the loaded split (indexable collection of item dicts).
        library: the loaded fragment library (mapping class id -> entry dict).
        cond_type: ``data_conf.conditioning_type``.
        S_max: maximum number of symmetries per fragment class.
        vocab_size: id used for the [MASK] state.
        lib_symmetries: mapping class id -> ``(padded_syms, mask)`` with
            shapes ``(S_max, 3, 3)`` and ``(S_max,)``.
    """

    # Composition counts are divided by the max number of atoms in QM9.
    _MAX_ATOMS_QM9 = 29

    def __init__(self, split, data_conf, is_training):
        """Load the processed split and library, and pad library symmetries.

        Args:
            split: name of the split; the file
                ``{processed_data_dir}/{split}_processed.pt`` is loaded.
            data_conf: config object exposing ``processed_data_dir``,
                ``conditioning_type``, ``library_path``, ``S_max``,
                ``vocab_size`` and (read in ``__getitem__`` when training)
                ``use_frame_augmentation``.
            is_training: enables random frame augmentation in
                ``__getitem__`` when the config also enables it.

        Raises:
            ValueError: if a library entry's symmetries are not shaped
                ``(num_syms, 3, 3)`` or exceed ``S_max``.
        """
        super().__init__()
        self._is_training = is_training
        self._data_conf = data_conf

        # NOTE(review): torch.load unpickles these files; they are assumed to
        # be trusted, project-generated artifacts.
        self.data_list = torch.load(f"{data_conf.processed_data_dir}/{split}_processed.pt")
        self.cond_type = data_conf.conditioning_type
        self.library = torch.load(data_conf.library_path)

        self.S_max = int(data_conf.S_max)
        self.vocab_size = int(data_conf.vocab_size)

        # Pre-process library symmetries into fixed-size tensors.
        self.lib_symmetries = {
            cid: self._pad_symmetries(cid, entry['symmetries'])
            for cid, entry in self.library.items()
        }

        # The [MASK] state (id == vocab_size) only has the identity symmetry.
        identity_only = torch.zeros(self.S_max, dtype=torch.bool)
        identity_only[0] = True
        self.lib_symmetries[self.vocab_size] = (
            torch.eye(3).unsqueeze(0).repeat(self.S_max, 1, 1),  # (S_max, 3, 3)
            identity_only,
        )

    def _pad_symmetries(self, cid, symmetries):
        """Return ``(padded_syms, mask)`` for one fragment class.

        ``padded_syms`` is ``(S_max, 3, 3)`` with the given symmetries first
        and identity matrices as padding; ``mask`` is ``(S_max,)`` and True
        for the real symmetries only.
        """
        syms = torch.as_tensor(symmetries, dtype=torch.float32)

        # Without this check a single (3, 3) matrix would be silently
        # interpreted as three symmetries and broadcast into the table.
        if syms.ndim != 3 or tuple(syms.shape[1:]) != (3, 3):
            raise ValueError(
                f"Class {cid} symmetries must have shape (num_syms, 3, 3), "
                f"got {tuple(syms.shape)}"
            )

        num_syms = syms.shape[0]
        if num_syms > self.S_max:
            raise ValueError(f"Class {cid} has {num_syms} symmetries, > S_max {self.S_max}")

        padded_syms = torch.eye(3).unsqueeze(0).repeat(self.S_max, 1, 1)  # (S_max, 3, 3)
        padded_syms[:num_syms] = syms

        mask = torch.zeros(self.S_max, dtype=torch.bool)
        mask[:num_syms] = True
        return padded_syms, mask

    def __len__(self):
        """Return the number of items in the loaded split."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Assemble one training/evaluation sample.

        Args:
            idx: index into ``data_list``; ``None`` is passed through and
                yields ``None``.

        Returns:
            ``None`` if ``idx`` is ``None``, otherwise a dict with:
              * ``frag_ids``: ``(K,)`` long tensor of fragment class ids.
              * ``frag_mask``: ``(K,)`` float tensor of ones.
              * ``rigids_0``: output of ``assemble_rigid_mat(...).to_tensor_7()``
                built from the (possibly augmented) rotations and the
                centered translations.
              * ``symmetries``: ``(K, S_max, 3, 3)`` padded symmetries.
              * ``sym_mask``: ``(K, S_max)`` bool mask of valid symmetries.
              * ``eval_target``: un-normalised conditioning target, or
                ``zeros(1)`` when the conditioning type is neither
                ``"composition"`` nor ``"structure"``.
              * ``condition``: present only for the two conditioning types
                above.

        Raises:
            ValueError: if a fragment id has no entry in ``lib_symmetries``.
        """
        if idx is None:
            return None

        item = self.data_list[idx]

        # Raw data; rotations are cloned because augmentation writes in place.
        frag_ids = item['frag_ids'].long()
        trans_gt = item['trans'].float()
        rots_gt = item['rots'].float().clone()
        K = frag_ids.shape[0]

        frag_mask = torch.ones(K, dtype=torch.float32)

        # Center the fragments on their mean position.
        trans_gt_centered = trans_gt - trans_gt.mean(dim=0, keepdim=True)

        mol_syms = torch.zeros((K, self.S_max, 3, 3), dtype=torch.float32)
        mol_sym_mask = torch.zeros((K, self.S_max), dtype=torch.bool)

        # Loop-invariant: decide once whether random frame augmentation runs.
        augment = self._is_training and self._data_conf.use_frame_augmentation

        for k, cid in enumerate(frag_ids.tolist()):
            try:
                syms, valid = self.lib_symmetries[cid]
            except KeyError:
                raise ValueError(f"Fragment ID {cid} not found in library.") from None

            mol_syms[k] = syms
            mol_sym_mask[k] = valid

            if not augment:
                continue

            num_valid = int(valid.sum().item())
            if num_valid > 1:
                # Uniformly pick one of the valid symmetries; index 0 leaves
                # the rotation unchanged.
                choice = torch.randint(0, num_valid, (1,)).item()
                if choice > 0:
                    # R_new = R_old @ S (rotate the local frame).
                    rots_gt[k] = torch.matmul(rots_gt[k], syms[choice])

        # Clean centered rigids from the (potentially augmented) rotations,
        # converted to the tensor_7 format.
        rigids_0 = assemble_rigid_mat(rots_gt, trans_gt_centered).to_tensor_7()

        # Conditioning input and its evaluation target (ground truth).
        cond_data = {}
        if self.cond_type == "composition":
            # Target is the raw counts; the condition is the normalised counts.
            counts = item['composition_counts'].float()
            cond_data['condition'] = counts / self._MAX_ATOMS_QM9
            eval_target = counts
        elif self.cond_type == "structure":
            fingerprint = item['fingerprint'].float()
            cond_data['condition'] = fingerprint
            eval_target = fingerprint
        else:
            eval_target = torch.zeros(1)

        return {
            "frag_ids": frag_ids,
            "frag_mask": frag_mask,
            "rigids_0": rigids_0,
            "symmetries": mol_syms,
            "sym_mask": mol_sym_mask,
            "eval_target": eval_target,
            **cond_data,
        }


class TrainSampler(data.Sampler):
    """Sampler that emits dataset indices in length-grouped, shuffled batches.

    Indices are bucketed by ``len(item["frag_ids"])``. On every iteration each
    bucket is shuffled and cut into chunks of ``batch_size``, the chunks are
    shuffled across buckets, and the result is yielded as one flat stream of
    integer indices. Shuffling is seeded with the current epoch, so the order
    is reproducible for a given epoch and changes after ``set_epoch``.

    NOTE(review): because the stream is flat, a consumer that re-chunks it
    with the same ``batch_size`` only recovers single-length batches while
    every emitted chunk is full; a kept partial chunk shifts the following
    boundaries. Confirm the collate function tolerates mixed lengths when
    ``include_partial_batches=True``.
    """

    def __init__(self, dataset, batch_size, include_partial_batches=True):
        """
        Groups dataset indices by sequence length and yields batches grouped by length.
        By default `include_partial_batches=True` so the leftover (partial) batch per length
        is included instead of being dropped.

        Args:
            dataset: dataset with `data_list` accessible and each item having 'frag_ids'.
            batch_size: desired batch_size used by DataLoader to collate mini-batches.
            include_partial_batches: if True, include the last smaller batch for each length group.

        Raises:
            ValueError: if `batch_size` is not positive.
        """
        # Fail early: 0 would divide by zero in __iter__ and a negative value
        # would silently yield nothing while __len__ still reports samples.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        self.dataset = dataset
        self.batch_size = batch_size
        self.epoch = 0
        self.include_partial_batches = include_partial_batches

        # Build length -> indices mapping (insertion-ordered, so iteration
        # over the groups is deterministic).
        self.length_to_indices = defaultdict(list)
        for idx, item in enumerate(self.dataset.data_list):
            self.length_to_indices[len(item["frag_ids"])].append(idx)

        # Number of indices one pass over the sampler yields.
        group_sizes = [len(indices) for indices in self.length_to_indices.values()]
        if self.include_partial_batches:
            self.num_samples = sum(group_sizes)
        else:
            self.num_samples = sum(
                (size // self.batch_size) * self.batch_size for size in group_sizes
            )

    def __iter__(self):
        """Return an iterator over this epoch's flat index stream."""
        # Deterministic per-epoch RNG.
        rng = np.random.default_rng(self.epoch)
        step = self.batch_size
        batches = []

        for indices in self.length_to_indices.values():
            shuffled = np.array(indices)
            rng.shuffle(shuffled)

            # Full batches first.
            full_extent = (len(shuffled) // step) * step
            batches.extend(
                shuffled[start:start + step] for start in range(0, full_extent, step)
            )

            # Leftover partial batch (kept only if include_partial_batches).
            if self.include_partial_batches and full_extent < len(shuffled):
                batches.append(shuffled[full_extent:])

        # Shuffle the order of batches across lengths.
        rng.shuffle(batches)

        # Flatten into a stream of plain Python ints.
        return iter([idx for batch in batches for idx in batch.tolist()])

    def set_epoch(self, epoch):
        """Set the epoch used to seed the next iteration's shuffling."""
        self.epoch = int(epoch)

    def __len__(self):
        """Return the number of indices yielded per iteration."""
        return int(self.num_samples)
