import random
import numpy as np
import torch
from monai import transforms


# Group id for each ``data_type`` name; LastTransform stores the looked-up value
# as ``vq_mapping_index``.  Names sharing a group id share a value here:
# Dxx/Dyy/Dzz -> 0, Dxy/Dxz/Dyz -> 1, b0 -> 2, b1000x/y/z -> 3.
vq_tensor_mapping = {
    **dict.fromkeys(('Dxx', 'Dyy', 'Dzz'), 0),
    **dict.fromkeys(('Dxy', 'Dxz', 'Dyz'), 1),
    'b0': 2,
    **dict.fromkeys(('b1000x', 'b1000y', 'b1000z'), 3),
}

# Pairs of [low, high] values.  Presumably indexed by the vq group ids 0 and 1
# above (the two tensor-component groups); not referenced in this part of the
# file -- TODO confirm against the dataset code.
tensor_minmaxv_mapping = [[0.0, 0.002], [-0.0005, 0.0005]]

# Distinct integer id (0..9) for each ``data_type`` name, in the order listed;
# LastTransform stores the looked-up value as ``diffusion_mapping_index``.
diffusion_tensor_mapping = {
    name: index
    for index, name in enumerate(
        ('b0', 'b1000x', 'b1000y', 'b1000z',
         'Dxx', 'Dyy', 'Dzz', 'Dxy', 'Dxz', 'Dyz')
    )
}

class LastTransform(transforms.MapTransform):
    """Min-max normalise and mask ``data``, then attach its mapping indices.

    ``__call__`` reads these entries of the sample dict directly (it does not
    iterate ``self.keys``):

    * ``data`` -- value to normalise; the result goes through ``torch.clamp``.
    * ``mask`` -- multiplied element-wise into the normalised ``data``.
    * ``minmaxv`` -- indexable pair; ``[0]`` is the lower and ``[1]`` the upper
      bound used for scaling.
    * ``data_idx`` -- compared with 4 to choose the clamp bound.
    * ``data_type`` -- key into ``vq_tensor_mapping`` and
      ``diffusion_tensor_mapping``.

    Returns a shallow copy of the sample dict with ``data`` replaced and the
    Python ints ``vq_mapping_index`` / ``diffusion_mapping_index`` added.
    """

    def __init__(
        self,
        keys=('data', 'mask'),
        allow_missing_keys: bool = False,
    ) -> None:
        """
        Args:
            keys: keys forwarded to ``MapTransform``.  ``__call__`` does not use
                them; the default now names entries that actually exist in the
                pipelines built below (the old default ``('x', 'y', 'mask')``
                was left over from an earlier x/y pipeline).
            allow_missing_keys: forwarded to ``MapTransform``.
        """
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data):
        d = dict(data)

        # ``img`` rather than ``data`` so the method argument is not shadowed.
        img = d['data']
        mask = d['mask']
        minmaxv = d['minmaxv']
        minv, maxv = minmaxv[0], minmaxv[1]
        data_idx = d['data_idx']
        data_type = d['data_type']

        # Scale to [0, 1] (the 1e-8 avoids division by zero when maxv == minv),
        # then shift to [-1, 1].
        img = (img - minv) / (maxv - minv + 1e-8)
        img = img * 2 - 1
        # Masked-out positions become 0 (assumes a 0/1 mask -- TODO confirm).
        img = img * mask
        # data_idx < 4 -> clip to [-1, 1]; otherwise allow up to +/-2, presumably
        # to keep some out-of-range values for those channels -- TODO confirm.
        bound = 1 if data_idx < 4 else 2
        img = torch.clamp(img, -bound, bound)

        d['data'] = img
        d['vq_mapping_index'] = vq_tensor_mapping[data_type]
        d['diffusion_mapping_index'] = diffusion_tensor_mapping[data_type]
        return d
    
    
def get_train_transforms():
    """Build the training pipeline (with random augmentation).

    Stages, in order:

    1. ``EnsureChannelFirstd`` on ``data`` and ``mask`` with
       ``channel_dim='no_channel'`` (adds a leading channel axis).
    2. ``RandFlipd`` on both keys along spatial axis 0, ``prob=0.5``.
    3. ``RandAffined`` on both keys, ``prob=0.9``:
       * two rotation ranges of +/- pi/36 rad (5 degrees);
       * ``scale_range=(0.001, 0.001)``;
       * no translation and ``"border"`` padding;
       * ``data`` is resampled bilinearly and ``mask`` with nearest-neighbour.
    4. ``LastTransform`` (normalisation, masking, index lookup).
    5. ``ToTensord`` on the images, then on the two mapping indices as
       ``torch.long``.

    Returns:
        transforms.Compose: the composed pipeline.
    """
    image_keys = ["data", "mask"]
    index_keys = ["vq_mapping_index", "diffusion_mapping_index"]
    max_angle = np.pi / 36
    angle_range = (-max_angle, max_angle)

    pipeline = [
        transforms.EnsureChannelFirstd(keys=image_keys, channel_dim='no_channel'),
        transforms.RandFlipd(keys=image_keys, prob=0.5, spatial_axis=0),
        transforms.RandAffined(
            keys=image_keys,
            mode=("bilinear", "nearest"),
            prob=0.9,
            rotate_range=(angle_range, angle_range),
            translate_range=None,
            scale_range=(0.001, 0.001),
            padding_mode="border",
        ),
        LastTransform(),
        # NOTE: a RandSpatialCropd stage (roi_size=[128, 128, 64],
        # random_size=False) previously sat here and is currently disabled.
        transforms.ToTensord(keys=image_keys),
        transforms.ToTensord(keys=index_keys, dtype=torch.long),
    ]
    return transforms.Compose(pipeline)

def get_val_transforms():
    """Build the validation pipeline: no random augmentation.

    Stages: add a leading channel axis to ``data`` and ``mask``, apply
    ``LastTransform``, convert the images to tensors, and convert the two
    mapping indices to ``torch.long`` tensors.

    Returns:
        transforms.Compose: the composed pipeline.
    """
    image_keys = ["data", "mask"]
    index_keys = ["vq_mapping_index", "diffusion_mapping_index"]

    steps = [
        transforms.EnsureChannelFirstd(keys=image_keys, channel_dim='no_channel'),
        LastTransform(),
        # NOTE: a CenterSpatialCropd stage (roi_size=[128, 128, 64]) previously
        # sat here and is currently disabled.
        transforms.ToTensord(keys=image_keys),
        transforms.ToTensord(keys=index_keys, dtype=torch.long),
    ]
    return transforms.Compose(steps)

def get_test_transforms():
    """Build the test-time pipeline (currently the same stages as validation).

    Returns:
        transforms.Compose: channel-first conversion, ``LastTransform``, tensor
        conversion of the images, and ``torch.long`` conversion of the two
        mapping indices.
    """
    add_channel = transforms.EnsureChannelFirstd(keys=["data", "mask"], channel_dim='no_channel')
    finalize = LastTransform()
    # NOTE: a CenterSpatialCropd stage (roi_size=[128, 128, 64]) previously sat
    # between ``finalize`` and the tensor conversions and is currently disabled.
    images_to_tensor = transforms.ToTensord(keys=["data", "mask"])
    indices_to_long = transforms.ToTensord(
        keys=["vq_mapping_index", "diffusion_mapping_index"], dtype=torch.long
    )

    test_transform = transforms.Compose(
        [add_channel, finalize, images_to_tensor, indices_to_long]
    )
    return test_transform

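# NOTE: legacy version, kept for reference only. It uses the old "x"/"y" keys,
# and LastTransform no longer accepts x_image/y_image arguments.
# def get_test_transforms(x_image, y_image):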
#     test_transform = transforms.Compose(
#         [
#             transforms.EnsureChannelFirstd(keys=["x", "y", "mask"], channel_dim='no_channel'),
#             LastTransform(
#                 x_image=x_image,
#                 y_image=y_image,
#             ),
#             transforms.ToTensord(keys=["x", "y", "mask"]),
#         ]
#     )
#     return test_transform
