import os
import sys

# Add parent directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))

import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import numpy as np
from config import get_default_config

def batch_random_rotation_matrices(batch_size: int):
    """Return ``batch_size`` random 3x3 rotation matrices.

    Each matrix is built from a random unit quaternion ``(w, x, y, z)``
    obtained by normalising a 4-D standard-normal sample.

    Args:
        batch_size: Number of rotation matrices to generate.

    Returns:
        Array of shape ``(batch_size, 3, 3)``.
    """
    quats = np.random.normal(size=(batch_size, 4))
    quats /= np.linalg.norm(quats, axis=1, keepdims=True)
    # Transposing lets the four quaternion components unpack as rows.
    w, x, y, z = quats.T

    # Standard unit-quaternion -> rotation-matrix formula, listed row by row.
    entries = [
        1 - 2 * y**2 - 2 * z**2, 2 * x * y - 2 * w * z, 2 * x * z + 2 * w * y,
        2 * x * y + 2 * w * z, 1 - 2 * x**2 - 2 * z**2, 2 * y * z - 2 * w * x,
        2 * x * z - 2 * w * y, 2 * y * z + 2 * w * x, 1 - 2 * x**2 - 2 * y**2,
    ]

    # (batch_size, 9) in row-major order -> (batch_size, 3, 3).
    return np.stack(entries, axis=-1).reshape(batch_size, 3, 3)

def apply_random_rotations(batch: np.ndarray):
    """Rotate every sample in ``batch`` by its own random rotation.

    Each sample is rotated about its mean over axis 1, so that mean is
    left unchanged by the rotation.

    Args:
        batch: Array indexed as ``(sample, point, coordinate)``; the einsum
            below requires the last axis to have length 3.

    Returns:
        Array with the same shape as ``batch`` holding the rotated points.
    """
    centroid = batch.mean(axis=1, keepdims=True)  # (BS, 1, 3)
    rotations = batch_random_rotation_matrices(batch.shape[0])  # (BS, 3, 3)

    # Apply R_b to every centred point of sample b, then shift back.
    return np.einsum("bij,bnj->bni", rotations, batch - centroid) + centroid

def corebeta_mapping(heavy_beads):
    """Select a fixed subset of six beads from a heavy-bead data mapping.

    Args:
        heavy_beads: Mapping with the keys ``"R"``, ``"F"``, ``"species"``,
            ``"mask"`` and ``"box"``. ``"R"`` and ``"F"`` are indexed along
            their second of three axes, ``"species"`` and ``"mask"`` along
            their second of two axes.

    Returns:
        A new dict with the same five keys; every per-bead array is reduced
        to the beads at indices 1, 3, 4, 5, 6 and 8, and ``"box"`` is passed
        through unchanged.
    """
    kept = np.array([1, 3, 4, 5, 6, 8])

    # Three-axis arrays first, then the two-axis ones.
    reduced = {key: heavy_beads[key][:, kept, :] for key in ("R", "F")}
    reduced.update({key: heavy_beads[key][:, kept] for key in ("species", "mask")})
    reduced["box"] = heavy_beads["box"]
    return reduced

class Ala2Dataset(Dataset):
    """Dataset of particle positions loaded from a NumPy archive.

    Args:
        file_path: Path handed to ``np.load``; the loaded object must provide
            the keys ``'R'``, ``'box'`` and ``'species'``.
        feat_type: How per-particle features are built:
            ``"distinguish"`` -> a column of particle indices ``0..n-1``,
            ``"species"`` -> the species of the first sample as a column,
            ``"none"`` -> no features (``None``).
        cg_level: When equal to ``"high"`` the data is first reduced with
            ``corebeta_mapping``; any other value uses the data as loaded.

    Raises:
        ValueError: If ``feat_type`` is not one of the supported values.
    """

    def __init__(self, file_path, feat_type, cg_level):
        # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary
        # objects, which can execute code. Only load trusted files.
        raw = np.load(file_path, allow_pickle=True)
        try:
            data = corebeta_mapping(raw) if cg_level == "high" else raw

            # Shape comments describe the unmapped data — TODO confirm.
            self.positions = data['R']         # shape: (N, 10, 3)
            self.box = data['box']             # shape: (N, 3, 3)
            self.species = data['species']     # shape: (N, 10)
        finally:
            # An .npz archive keeps its file handle open until closed; the
            # arrays extracted above no longer need it.
            close = getattr(raw, "close", None)
            if callable(close):
                close()

        if feat_type == "distinguish":
            # One unique integer id per particle, shape (num_particles, 1).
            num_particles = self.positions.shape[1]
            self.features = np.arange(num_particles, dtype=np.int32).reshape(-1, 1)
        elif feat_type == "none":
            self.features = None
        elif feat_type == "species":
            # Species of the first sample are reused for every item.
            self.features = self.species[0].reshape(-1, 1)
        else:
            # Fail here instead of with an AttributeError in __getitem__.
            raise ValueError(
                f"Unknown feat_type {feat_type!r}; expected 'distinguish', "
                "'none' or 'species'."
            )

    def __len__(self):
        """Return the number of samples."""
        return len(self.positions)

    def __getitem__(self, idx):
        """Return positions of sample ``idx`` plus the shared features."""
        return {
            "x": self.positions[idx],
            "features": self.features,
        }
    
def collate_fn(batch):
    """Stack a list of dataset items into one batch dict.

    The ``"x"`` entries are stacked and then passed through
    ``apply_random_rotations``; every other key is stacked with ``np.stack``.

    Args:
        batch: Non-empty list of dicts sharing the same keys, including "x".

    Returns:
        Dict with the keys of the first item, each mapped to a stacked array.
    """
    # (BS, num_atoms, 3), rotated once for the whole batch.
    rotated = apply_random_rotations(np.stack([sample["x"] for sample in batch]))

    return {
        key: rotated if key == "x" else np.stack([sample[key] for sample in batch])
        for key in batch[0]
    }

def get_ala2_dataloader(config):
    """Build a ``DataLoader`` over a random subset of an ``Ala2Dataset``.

    Args:
        config: Nested mapping providing
            ``config["general"]["cg_level"]``,
            ``config["dataset"]["ala2_datafile"]``,
            ``config["dataset"]["num_samples"]``,
            ``config["dataset"]["shuffle"]``,
            ``config["dataset"]["drop_last"]``,
            ``config["model"]["GraphTransformer"]["feat_type"]`` and
            ``config["trainer"]["batch"]``.

    Returns:
        A ``DataLoader`` that yields batches assembled by ``collate_fn`` from
        ``num_samples`` distinct, randomly chosen dataset items.

    Raises:
        ValueError: From ``np.random.choice`` when ``num_samples`` exceeds the
            dataset size (sampling is without replacement).
    """
    cg_level = config["general"]["cg_level"]
    file_path = config["dataset"]["ala2_datafile"]
    num_samples = config["dataset"]["num_samples"]
    shuffle = config["dataset"]["shuffle"]
    drop_last = config["dataset"]["drop_last"]
    feat_type = config["model"]["GraphTransformer"]["feat_type"]
    batch_size = config["trainer"]["batch"]

    dataset = Ala2Dataset(file_path=file_path, feat_type=feat_type, cg_level=cg_level)
    print(f"Dataset size: {len(dataset)}")
    indices = np.random.choice(len(dataset), num_samples, replace=False)

    # DataLoader rejects `shuffle=True` together with a custom `sampler`, so
    # the shuffle flag is honoured through the choice of sampler instead.
    if shuffle:
        # Re-permutes the chosen subset on every pass over the loader.
        sampler = SubsetRandomSampler(indices)
    else:
        # Any sized iterable of indices is a valid sampler; this one visits
        # the chosen subset in fixed ascending order.
        sampler = np.sort(indices).tolist()

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=drop_last,
        collate_fn=collate_fn,
    )

if __name__ == "__main__":
    # Smoke test: build the loader from the default config and report the
    # shapes of the first batch only.
    config = get_default_config()
    dataloader = get_ala2_dataloader(config)

    batch = next(iter(dataloader), None)
    if batch is not None:
        print("x shape:", batch["x"].shape)
        print("features shape:", batch["features"].shape)
        print(f"fist batch features: {batch['features'][0]}")