import random
import numpy as np
import torch
from monai import transforms

# Maps a diffusion-tensor component name to the index of its normalisation
# range in ``tensor_minmaxv_mapping``: 0 for the diagonal components
# (Dxx, Dyy, Dzz), 1 for the off-diagonal components (Dxy, Dxz, Dyz).
tensor_mapping = {
    **dict.fromkeys(('Dxx', 'Dyy', 'Dzz'), 0),
    **dict.fromkeys(('Dxy', 'Dxz', 'Dyz'), 1),
}
# Fixed [min, max] range per component group, used to rescale the tensor
# target to [-1, 1].
# NOTE(review): units are not stated here; presumably the same units as the
# stored tensor maps - confirm.
tensor_minmaxv_mapping = [
    [0.0, 0.002],  # index 0: diagonal components
    [-0.0005, 0.0005],  # index 1: off-diagonal components
]

class LastTransform(transforms.MapTransform):
    """Rescale the tensor target and the DWI inputs of one sample to about [-1, 1].

    Reads the fixed keys ``'tensor'``, ``'tensor_direction'``, ``'b0'``,
    ``'b1000x'``, ``'b1000y'``, ``'b1000z'``, ``'mask'`` and ``'minmaxv'``
    from the sample dict and returns a shallow copy in which:

    * ``'tensor'`` is linearly mapped from the fixed range
      ``tensor_minmaxv_mapping[tensor_mapping[tensor_direction]]`` to [-1, 1],
      multiplied by the mask and clamped to [-2, 2];
    * ``'b0'`` is mapped from [0, minmaxv[1]] and each of ``'b1000x'``,
      ``'b1000y'``, ``'b1000z'`` from [0, minmaxv[3]] to [-1, 1], multiplied by
      the mask and clamped to [-1, 1];
    * ``'tensor_mapping_index'`` (the range index, 0 or 1) is added.

    ``'mask'`` itself is left unchanged.

    NOTE(review): the ``keys`` forwarded to ``MapTransform`` are never used by
    ``__call__``; the key set above is hard-coded.
    """

    # Images that share the b=1000 normalisation range (upper bound minmaxv[3]).
    _B1000_KEYS = ('b1000x', 'b1000y', 'b1000z')

    def __init__(
        self,
        keys = ('x', 'y', 'mask'),
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)

    @staticmethod
    def _rescale(img, minv, maxv, mask, lo=-1, hi=1, eps=1e-8):
        """Map [minv, maxv] linearly to [-1, 1], zero outside ``mask``, clamp to [lo, hi].

        ``eps`` guards against division by zero when ``maxv == minv``.
        """
        img = (img - minv) / (maxv - minv + eps)
        img = img * 2 - 1
        # Masking after the shift puts background voxels at 0, the middle of the range.
        img = img * mask
        return torch.clamp(img, lo, hi)

    def __call__(self, data):
        d = dict(data)
        mask = d['mask']
        minmaxv = d['minmaxv']

        # Diagonal (index 0) and off-diagonal (index 1) components use
        # different fixed ranges.
        tensor_mapping_index = tensor_mapping[d['tensor_direction']]
        tensor_minmaxv = tensor_minmaxv_mapping[tensor_mapping_index]
        minv_tensor, maxv_tensor = tensor_minmaxv[0], tensor_minmaxv[1]

        # Clamped to [-2, 2] rather than [-1, 1], leaving headroom for target
        # values outside the nominal range.
        d['tensor'] = self._rescale(d['tensor'], minv_tensor, maxv_tensor, mask, -2, 2)
        d['tensor_mapping_index'] = tensor_mapping_index

        # NOTE(review): both lower bounds are fixed at 0.0; minmaxv[0] and
        # minmaxv[2] are never read - confirm this is intended.
        d['b0'] = self._rescale(d['b0'], 0.0, minmaxv[1], mask)
        for key in self._B1000_KEYS:
            d[key] = self._rescale(d[key], 0.0, minmaxv[3], mask)

        return d
    
    
def get_train_transforms():
    """Build the training pipeline: channel-first, random flip / affine, normalise, to-tensor.

    Returns:
        A ``transforms.Compose`` applying, in order: ``EnsureChannelFirstd``,
        a random flip along spatial axis 0 (p=0.5), a random affine (p=0.9,
        up to +/-5 degrees rotation per axis, negligible scaling),
        ``LastTransform`` and ``ToTensord``.
    """
    image_keys = ["tensor", "b0", "b1000x", "b1000y", "b1000z", "mask"]
    max_angle = np.pi / 36  # 5 degrees, in radians

    pipeline = [
        transforms.EnsureChannelFirstd(keys=image_keys, channel_dim='no_channel'),
        # NOTE(review): flipping / rotating the per-component tensor maps and
        # the direction-specific b1000 images does not transform the
        # underlying tensor (e.g. off-diagonal components change sign under a
        # reflection) - confirm this approximation is intended.
        transforms.RandFlipd(keys=image_keys, prob=0.5, spatial_axis=0),
        transforms.RandAffined(
            keys=image_keys,
            # Bilinear for the five image keys, nearest-neighbour for the mask.
            mode=("bilinear",) * 5 + ("nearest",),
            prob=0.9,
            rotate_range=((-max_angle, max_angle),) * 3,
            translate_range=None,
            scale_range=(0.001, 0.001),
            padding_mode="border",
        ),
        LastTransform(),
        transforms.ToTensord(
            keys=["tensor", "tensor_mapping_index", "b0", "b1000x", "b1000y", "b1000z", "mask"]
        ),
    ]
    return transforms.Compose(pipeline)

def get_val_transforms():
    """Build the deterministic validation pipeline (no augmentation).

    Returns:
        A ``transforms.Compose`` applying ``EnsureChannelFirstd``,
        ``LastTransform`` and ``ToTensord``.
    """
    image_keys = ["tensor", "b0", "b1000x", "b1000y", "b1000z", "mask"]
    output_keys = ["tensor", "tensor_mapping_index", "b0", "b1000x", "b1000y", "b1000z", "mask"]
    return transforms.Compose([
        transforms.EnsureChannelFirstd(keys=image_keys, channel_dim='no_channel'),
        LastTransform(),
        transforms.ToTensord(keys=output_keys),
    ])

def get_test_transforms():
    """Build the test pipeline.

    The test pipeline is currently identical to the validation pipeline
    (channel-first, ``LastTransform``, to-tensor, no augmentation), so this
    delegates to ``get_val_transforms``. Each call returns a new ``Compose``.
    """
    return get_val_transforms()


