import random
import numpy as np
import torch
from monai import transforms

# Maps a tensor-component name to an index into ``tensor_minmaxv_mapping``.
# Names with a repeated axis letter (Dxx, Dyy, Dzz) share index 0; the
# mixed-axis names (Dxy, Dxz, Dyz) share index 1.
tensor_mapping = {
    **dict.fromkeys(('Dxx', 'Dyy', 'Dzz'), 0),
    **dict.fromkeys(('Dxy', 'Dxz', 'Dyz'), 1),
}

# ``[min, max]`` pair per index above; LastTransform uses these fixed bounds
# to rescale the corresponding tensor components.
tensor_minmaxv_mapping = [
    [0.0, 0.002],
    [-0.0005, 0.0005],
]

class LastTransform(transforms.MapTransform):
    """Rescale, mask and clamp the tensor-component and b-value images.

    For every processed image the same steps are applied:

    1. min-max rescale with ``(x - min) / (max - min + 1e-8)``,
    2. shift to a nominal ``[-1, 1]`` range via ``x * 2 - 1``,
    3. multiply by ``data['mask']``,
    4. clamp with ``torch.clamp``.

    Tensor components (``tensor_xx`` ... ``tensor_yz``) use the fixed bounds
    from the module-level ``tensor_minmaxv_mapping`` and are clamped to
    ``[-2, 2]``.  ``b0`` uses ``0.0`` and ``minmaxv[1]`` as bounds, the
    ``b1000x/y/z`` images use ``0.0`` and ``minmaxv[3]``; both are clamped to
    ``[-1, 1]``.

    In addition, the range-group index of every tensor component is stored
    under ``tensor_mapping_index_<suffix>``.

    Required input keys: the six ``tensor_*`` images, ``b0``, ``b1000x``,
    ``b1000y``, ``b1000z``, ``mask`` and ``minmaxv``.  Images must be accepted
    by ``torch.clamp``.

    Note: the ``keys`` argument is only forwarded to ``MapTransform``;
    ``__call__`` always works on the fixed key names listed above.
    """

    # (key suffix, name in ``tensor_mapping``) for every tensor component,
    # in the order the results are written back to the dictionary.
    _TENSOR_COMPONENTS = (
        ('xx', 'Dxx'),
        ('yy', 'Dyy'),
        ('zz', 'Dzz'),
        ('xy', 'Dxy'),
        ('xz', 'Dxz'),
        ('yz', 'Dyz'),
    )
    # Images that all share the ``minmaxv[3]`` upper bound.
    _B1000_KEYS = ('b1000x', 'b1000y', 'b1000z')

    def __init__(
        self,
        keys = ('x', 'y', 'mask'),
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)

    @staticmethod
    def _rescale(image, minv, maxv, mask, bound):
        """Min-max rescale ``image`` to ``[-1, 1]``, mask it, clamp to ``±bound``."""
        # The epsilon guards against a zero-width [minv, maxv] range.
        image = (image - minv) / (maxv - minv + 1e-8)
        image = image * 2 - 1
        # Masking happens after the shift, so voxels where mask == 0 end up
        # at 0 rather than at the lower bound -1.
        image = image * mask
        return torch.clamp(image, -bound, bound)

    def __call__(self, data):
        """Return a shallow copy of ``data`` with the normalized images.

        Raises:
            KeyError: if one of the required keys is missing from ``data``.
        """
        d = dict(data)

        mask = d['mask']
        minmaxv = d['minmaxv']

        # Tensor components: fixed per-group bounds, clamp to [-2, 2].
        for suffix, component in self._TENSOR_COMPONENTS:
            group_index = tensor_mapping[component]
            bounds = tensor_minmaxv_mapping[group_index]
            image_key = f'tensor_{suffix}'
            d[image_key] = self._rescale(d[image_key], bounds[0], bounds[1], mask, 2)

        # Stored after all images so the new keys keep their original
        # insertion order in the output dictionary.
        for suffix, component in self._TENSOR_COMPONENTS:
            d[f'tensor_mapping_index_{suffix}'] = tensor_mapping[component]

        # b-value images: lower bound fixed at 0.0, upper bound taken from
        # the per-sample ``minmaxv`` entry, clamp to [-1, 1].
        d['b0'] = self._rescale(d['b0'], 0.0, minmaxv[1], mask, 1)
        maxv_b1000 = minmaxv[3]
        for image_key in self._B1000_KEYS:
            d[image_key] = self._rescale(d[image_key], 0.0, maxv_b1000, mask, 1)

        return d
    
    
def get_train_transforms():
    """Build the training pipeline.

    Steps, in order: add a channel dimension, random flip along spatial
    axis 0 (p=0.5), random affine (p=0.9, rotations within ±pi/36 rad, i.e.
    ±5 degrees, border padding), ``LastTransform`` normalization, and
    conversion to tensors.

    NOTE(review): the flip and affine treat the tensor-component images as
    plain scalar images; confirm that no component reorientation is needed.
    """
    image_keys = [
        "tensor_xx", "tensor_yy", "tensor_zz",
        "tensor_xy", "tensor_xz", "tensor_yz",
        "b0", "b1000x", "b1000y", "b1000z",
    ]
    spatial_keys = image_keys + ["mask"]
    index_keys = [
        "tensor_mapping_index_xx", "tensor_mapping_index_yy", "tensor_mapping_index_zz",
        "tensor_mapping_index_xy", "tensor_mapping_index_xz", "tensor_mapping_index_yz",
    ]

    # Images are resampled bilinearly, the mask (last key) with nearest.
    interpolation = ("bilinear",) * len(image_keys) + ("nearest",)
    angle_range = (-np.pi/36, np.pi/36)

    add_channel = transforms.EnsureChannelFirstd(keys=spatial_keys, channel_dim='no_channel')
    random_flip = transforms.RandFlipd(keys=spatial_keys, prob=0.5, spatial_axis=0)
    random_affine = transforms.RandAffined(
        keys=spatial_keys,
        mode=interpolation,
        prob=0.9,
        rotate_range=(angle_range, angle_range, angle_range),
        translate_range=None,
        scale_range=(0.001, 0.001),
        padding_mode="border",
    )
    normalize = LastTransform()
    to_tensor = transforms.ToTensord(keys=spatial_keys + index_keys)

    return transforms.Compose(
        [add_channel, random_flip, random_affine, normalize, to_tensor]
    )

def get_val_transforms():
    """Build the validation pipeline.

    Steps, in order: add a channel dimension, ``LastTransform``
    normalization, and conversion to tensors. No random augmentation.
    """
    spatial_keys = [
        "tensor_xx", "tensor_yy", "tensor_zz",
        "tensor_xy", "tensor_xz", "tensor_yz",
        "b0", "b1000x", "b1000y", "b1000z",
        "mask",
    ]
    index_keys = [
        "tensor_mapping_index_xx", "tensor_mapping_index_yy", "tensor_mapping_index_zz",
        "tensor_mapping_index_xy", "tensor_mapping_index_xz", "tensor_mapping_index_yz",
    ]

    add_channel = transforms.EnsureChannelFirstd(keys=spatial_keys, channel_dim='no_channel')
    normalize = LastTransform()
    to_tensor = transforms.ToTensord(keys=spatial_keys + index_keys)

    return transforms.Compose([add_channel, normalize, to_tensor])

def get_test_transforms():
    """Build the test pipeline.

    Steps, in order: add a channel dimension, ``LastTransform``
    normalization, and conversion to tensors. No random augmentation.
    """
    spatial_keys = [
        "tensor_xx", "tensor_yy", "tensor_zz",
        "tensor_xy", "tensor_xz", "tensor_yz",
        "b0", "b1000x", "b1000y", "b1000z",
        "mask",
    ]
    index_keys = [
        "tensor_mapping_index_xx", "tensor_mapping_index_yy", "tensor_mapping_index_zz",
        "tensor_mapping_index_xy", "tensor_mapping_index_xz", "tensor_mapping_index_yz",
    ]

    pipeline = [
        transforms.EnsureChannelFirstd(keys=spatial_keys, channel_dim='no_channel'),
        LastTransform(),
        transforms.ToTensord(keys=spatial_keys + index_keys),
    ]
    test_transform = transforms.Compose(pipeline)
    return test_transform


