import random
import numpy as np
import torch
from monai import transforms

# Component name -> row index into ``tensor_minmaxv_mapping``.
# (Names look like the six unique diffusion-tensor components — TODO confirm.)
tensor_mapping = {
    component: row
    for row, component in enumerate(('Dxx', 'Dyy', 'Dzz', 'Dxy', 'Dxz', 'Dyz'))
}

# One fixed ``[min, max]`` pair per component, in ``tensor_mapping`` order.
# LastTransform uses these bounds to rescale each tensor volume.
tensor_minmaxv_mapping = [
    [0.0, 0.002],        # Dxx
    [0.0, 0.002],        # Dyy
    [0.0, 0.002],        # Dzz
    [-0.0005, 0.0005],   # Dxy
    [-0.0005, 0.0005],   # Dxz
    [-0.0005, 0.0005],   # Dyz
]

class LastTransform(transforms.MapTransform):
    """Rescale the signal and tensor volumes and stack them into ``d['inputs']``.

    The transform reads a fixed set of entries from the sample dictionary:

    * ``b0``, ``b1000x``, ``b1000y``, ``b1000z`` -- rescaled from
      ``[0, max]`` to ``[-1, 1]``, where ``max`` is ``minmaxv[1]`` for ``b0``
      and ``minmaxv[3]`` for the three ``b1000*`` volumes; clamped to
      ``[-1, 1]``.
    * ``tensor_xx`` ... ``tensor_yz`` -- rescaled with the fixed bounds in the
      module-level ``tensor_minmaxv_mapping``; clamped to ``[-2, 2]``.
    * ``mask`` -- multiplied into every volume after rescaling.
    * ``minmaxv`` -- indexable holder of the per-sample maxima described above.

    It writes three new entries (``inputs``, ``vq_mapping_index``,
    ``diffusion_mapping_index``); every other entry is passed through
    unchanged, and the source volumes themselves are not overwritten.

    NOTE: ``keys`` is only forwarded to ``MapTransform``; ``__call__`` always
    reads the fixed key names listed above.
    """

    # Sample keys of the signal volumes that share the ``minmaxv[3]`` maximum.
    _B1000_KEYS = ('b1000x', 'b1000y', 'b1000z')

    # (sample key, name in ``tensor_mapping``) in output channel order.
    _TENSOR_COMPONENTS = (
        ('tensor_xx', 'Dxx'),
        ('tensor_yy', 'Dyy'),
        ('tensor_zz', 'Dzz'),
        ('tensor_xy', 'Dxy'),
        ('tensor_xz', 'Dxz'),
        ('tensor_yz', 'Dyz'),
    )

    def __init__(
        self,
        keys = ('x', 'y', 'mask'),
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)

    @staticmethod
    def _rescale(volume, low, high, mask, bound):
        """Map ``[low, high]`` to ``[-1, 1]``, apply ``mask``, clamp to ``[-bound, bound]``.

        Args:
            volume: Tensor to rescale.
            low: Value mapped to ``-1``.
            high: Value mapped to ``+1``.
            mask: Multiplied element-wise into the rescaled volume.
            bound: Symmetric clamp limit applied last.

        Returns:
            The rescaled, masked and clamped tensor.
        """
        # The epsilon guards against division by zero when high == low.
        volume = (volume - low) / (high - low + 1e-8)
        volume = volume * 2 - 1
        # Masking happens after the shift, so positions where the mask is 0
        # end up at 0 rather than at -1.
        volume = volume * mask
        return torch.clamp(volume, -bound, bound)

    def __call__(self, data):
        """Return a shallow copy of ``data`` with the stacked network inputs added.

        Args:
            data: Mapping holding the volumes, ``mask`` and ``minmaxv``
                described in the class docstring.

        Returns:
            A new ``dict`` with the original entries plus ``inputs`` (the ten
            rescaled volumes stacked along a new leading axis, signals first,
            then tensor components) and the two index tensors
            ``vq_mapping_index`` and ``diffusion_mapping_index``.

        Raises:
            KeyError: If one of the required entries is missing.
        """
        d = dict(data)

        mask = d['mask']
        minmaxv = d['minmaxv']

        # Signal volumes: minimum fixed at 0.0, per-sample maximum from minmaxv.
        channels = [self._rescale(d['b0'], 0.0, minmaxv[1], mask, 1)]
        for key in self._B1000_KEYS:
            channels.append(self._rescale(d[key], 0.0, minmaxv[3], mask, 1))

        # Tensor volumes: fixed global bounds, looser clamp of [-2, 2].
        for key, component in self._TENSOR_COMPONENTS:
            bounds = tensor_minmaxv_mapping[tensor_mapping[component]]
            channels.append(self._rescale(d[key], bounds[0], bounds[1], mask, 2))

        d['inputs'] = torch.stack(channels, dim=0)
        # One entry per stacked channel; the first groups the channels as
        # b0 / b1000* / diagonal / off-diagonal, the second numbers them 0..9.
        # (How downstream code consumes these indices is not visible here.)
        d['vq_mapping_index'] = torch.tensor([2, 3, 3, 3, 0, 0, 0, 1, 1, 1])
        d['diffusion_mapping_index'] = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

        return d
    
    
def get_train_transforms():
    """Build the training pipeline.

    Steps, in order: add a channel axis to every volume, apply one shared
    random affine augmentation, run ``LastTransform`` to build ``inputs``,
    then convert the outputs to tensors.

    Returns:
        The composed ``transforms.Compose`` object.
    """
    tensor_keys = ["tensor_xx", "tensor_yy", "tensor_zz", "tensor_xy", "tensor_xz", "tensor_yz"]
    signal_keys = ["b0", "b1000x", "b1000y", "b1000z"]
    volume_keys = tensor_keys + signal_keys + ["mask"]

    # Rotation limit of +/- pi/36 rad (5 degrees) for each listed axis.
    max_angle = np.pi / 36
    rotation_limits = tuple((-max_angle, max_angle) for _ in range(3))

    # One interpolation mode per key, in ``volume_keys`` order.
    # NOTE(review): the b0/b1000 volumes use "nearest" like the mask, while
    # the tensor volumes use "bilinear" — confirm this is intended.
    interpolation = ("bilinear",) * len(tensor_keys) + ("nearest",) * (len(signal_keys) + 1)

    add_channel = transforms.EnsureChannelFirstd(keys=list(volume_keys), channel_dim='no_channel')
    # A RandFlipd(prob=0.5, spatial_axis=0) step over the same keys used to
    # sit here; it is currently disabled.
    augment = transforms.RandAffined(
        keys=list(volume_keys),
        mode=interpolation,
        prob=0.9,
        rotate_range=rotation_limits,
        translate_range=None,
        # NOTE(review): two scale entries but three rotation entries —
        # confirm against the spatial dimensionality of the data.
        scale_range=(0.001, 0.001),
        padding_mode="border",
    )
    build_inputs = LastTransform()
    volumes_to_tensor = transforms.ToTensord(keys=["inputs", "mask"])
    indices_to_tensor = transforms.ToTensord(
        keys=["vq_mapping_index", "diffusion_mapping_index"],
        dtype=torch.long,
    )

    return transforms.Compose(
        [add_channel, augment, build_inputs, volumes_to_tensor, indices_to_tensor]
    )

def get_val_transforms():
    """Build the validation pipeline (no random augmentation).

    Steps, in order: add a channel axis to every volume, run
    ``LastTransform`` to build ``inputs``, then convert ``inputs``/``mask``
    and the two index entries to tensors.

    Returns:
        The composed ``transforms.Compose`` object.
    """
    volume_keys = [
        "tensor_xx", "tensor_yy", "tensor_zz", "tensor_xy", "tensor_xz", "tensor_yz",
        "b0", "b1000x", "b1000y", "b1000z",
        "mask",
    ]
    steps = [
        transforms.EnsureChannelFirstd(keys=volume_keys, channel_dim='no_channel'),
        LastTransform(),
        transforms.ToTensord(keys=["inputs", "mask"]),
        # The index entries are converted with an explicit integer dtype.
        transforms.ToTensord(keys=["vq_mapping_index", "diffusion_mapping_index"], dtype=torch.long),
    ]
    return transforms.Compose(steps)

def get_test_transforms():
    """Build the test pipeline (no random augmentation).

    Steps, in order: add a channel axis to every volume, run
    ``LastTransform`` to build ``inputs``, then convert ``inputs``/``mask``
    and the two index entries to tensors.

    Returns:
        The composed ``transforms.Compose`` object.
    """
    volume_keys = [
        "tensor_xx", "tensor_yy", "tensor_zz", "tensor_xy", "tensor_xz", "tensor_yz",
        "b0", "b1000x", "b1000y", "b1000z",
        "mask",
    ]
    add_channel = transforms.EnsureChannelFirstd(keys=volume_keys, channel_dim='no_channel')
    build_inputs = LastTransform()
    volumes_to_tensor = transforms.ToTensord(keys=["inputs", "mask"])
    # The index entries are converted with an explicit integer dtype.
    indices_to_tensor = transforms.ToTensord(
        keys=["vq_mapping_index", "diffusion_mapping_index"],
        dtype=torch.long,
    )
    test_transform = transforms.Compose(
        [add_channel, build_inputs, volumes_to_tensor, indices_to_tensor]
    )
    return test_transform


def get_test_transforms_forward():
    """Build the test pipeline that also converts the source signal volumes.

    Steps, in order: add a channel axis to every volume, run
    ``LastTransform`` to build ``inputs``, then convert ``inputs``, the four
    signal volumes (``b0``, ``b1000x``, ``b1000y``, ``b1000z``) and ``mask``
    to tensors, followed by the two index entries.

    ``LastTransform`` does not write the rescaled signals back under their
    own keys, so the signal entries converted here are the un-normalised ones.

    Returns:
        The composed ``transforms.Compose`` object.
    """
    signal_keys = ["b0", "b1000x", "b1000y", "b1000z"]
    volume_keys = (
        ["tensor_xx", "tensor_yy", "tensor_zz", "tensor_xy", "tensor_xz", "tensor_yz"]
        + signal_keys
        + ["mask"]
    )
    steps = [
        transforms.EnsureChannelFirstd(keys=volume_keys, channel_dim='no_channel'),
        LastTransform(),
        transforms.ToTensord(keys=["inputs"] + signal_keys + ["mask"]),
        # The index entries are converted with an explicit integer dtype.
        transforms.ToTensord(keys=["vq_mapping_index", "diffusion_mapping_index"], dtype=torch.long),
    ]
    return transforms.Compose(steps)
