import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import zoom
from matplotlib.colors import LinearSegmentedColormap
import rasterio

def evaluate(model: torch.nn.Module,
             dataloaders: torch.utils.data.DataLoader,
             metric: torch.nn.Module,
             criterion: torch.nn.Module,
             device: torch.device,
             mode="val"):
    """
    Evaluate model performance on the dataset.

    Every batch yielded by ``dataloaders`` must be a mapping with ``"image"``
    and ``"label"`` entries. Both are moved to ``device``. ``model``'s output is
    scored with ``criterion`` (loss) and ``metric`` (IoU), and the per-batch
    values are averaged.

    Args:
        model (torch.nn.Module): The model to evaluate. It is switched to eval
            mode and moved to ``device``.
        dataloaders (torch.utils.data.DataLoader): Iterable of batches.
        metric (torch.nn.Module): Callable ``metric(outputs, targets)`` returning
            a number or a tensor.
        criterion (torch.nn.Module): Loss callable ``criterion(outputs, targets)``
            returning a scalar tensor.
        device (torch.device): Device to run the evaluation on.
        mode (str, optional): "val" or "test". Kept for backward compatibility.
            The previous "test" branch computed a prediction that was never
            used, so both modes now behave the same.

    Returns:
        float: Mean loss over batches (NaN if there were no batches).
        float: Mean IoU over batches (NaN if there were no batches).
    """
    model.eval()
    model.to(device)

    running_ious, running_losses = [], []

    # Evaluation never needs gradients; disable autograd once for the whole loop.
    with torch.no_grad():
        for sample in dataloaders:
            inputs = sample["image"].to(device)
            targets = sample["label"].to(device)

            outputs = model(inputs)

            running_losses.append(criterion(outputs, targets).item())

            iou_value = metric(outputs, targets)
            if isinstance(iou_value, torch.Tensor):
                # np.mean cannot consume accelerator tensors; bring them to host memory.
                iou_value = iou_value.detach().cpu().numpy()
            running_ious.append(iou_value)

    if not running_losses:
        # Same result as np.mean([]) but without the RuntimeWarning.
        return float("nan"), float("nan")

    mean_loss = np.mean(running_losses)
    mean_iou = np.mean(running_ious)

    return mean_loss, mean_iou

def create_error_map(mask1, mask2):
    """
    Create an error map by comparing two masks element-wise.

    Args:
        mask1 (array-like): First mask.
        mask2 (array-like): Second mask; must have the same shape as ``mask1``.

    Returns:
        numpy.ndarray: Integer array of the same shape where 1 indicates a
        mismatch and 0 indicates a match.

    Raises:
        ValueError: If the two masks do not have the same shape.
    """
    # Convert first so plain lists/tuples are accepted and have a .shape.
    first = np.asarray(mask1)
    second = np.asarray(mask2)
    if first.shape != second.shape:
        raise ValueError("Input masks must have the same shape")

    # Flag disagreements (not agreements), as documented above.
    return np.where(first != second, 1, 0)

def display_images(**images):
    """
    Show the supplied images next to each other in one row.

    Each keyword argument becomes one panel. The panel title is the keyword
    with underscores replaced by spaces, in title case. Colour maps are chosen
    by keyword:

    * ``error_map``: inverted greyscale (``"gray_r"``),
    * any keyword containing ``"mask"``: a white-to-dark-teal ramp,
    * anything else: matplotlib's default colour map.

    Args:
        images (dict): Mapping of panel name to image array.
    """
    plt.figure(figsize=(12, 12))
    mask_palette = ["#ffffff", "#b4c6cc", "#438ba3", "#044359"]
    mask_cmap = LinearSegmentedColormap.from_list("my_list", mask_palette)
    n_panels = len(images)

    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, n_panels, position)
        # Hide axis ticks; only the picture matters.
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace("_", " ").title(), fontsize=15)

        if label == "error_map":
            imshow_kwargs = {"cmap": "gray_r"}
        elif "mask" in label:
            imshow_kwargs = {"cmap": mask_cmap}
        else:
            imshow_kwargs = {}
        plt.imshow(picture, **imshow_kwargs)

    plt.show()

def save_geotiff(output_dir, data, geotransform, epsg=32706):
    """
    Save a 2-D array as a single-band GeoTIFF file.

    The raster size is taken from ``data`` instead of being fixed at 224x224,
    so arrays of any height/width can be written.

    Args:
        output_dir (str): Output file path (a file, despite the parameter name).
        data (numpy.ndarray): 2-D array of shape (rows, cols) written as band 1.
            The file is created with dtype ``uint8``.
        geotransform: Affine transform passed to rasterio as ``transform``.
        epsg (int, optional): EPSG code of the coordinate reference system.
            Defaults to 32706, the value that was previously hard-coded.

    Raises:
        ValueError: If ``data`` is not two-dimensional.
    """
    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError(
            f"save_geotiff expects a 2-D array, got shape {data.shape}")

    height, width = data.shape
    crs = rasterio.crs.CRS.from_epsg(epsg)

    with rasterio.open(output_dir, "w", driver="GTiff", height=height, width=width,
                       count=1, dtype="uint8", crs=crs, transform=geotransform) as dst:
        dst.write(data, 1)

def test_single_volume(image, net, patch_size=(224, 224), device="cuda"):
    """
    Predict a per-pixel class map for a single image or a stack of 2-D slices.

    ``image`` is squeezed along its leading dimension.

    * If the result is 3-D, each slice along axis 0 is handled separately.
      The slice is resized to ``patch_size`` (cubic spline) if needed and fed
      to ``net``. The argmax over the class dimension is then resized back to
      the slice's original size (nearest neighbour) and stored.
    * Otherwise the squeezed array is fed to ``net`` directly, without resizing.

    Args:
        image (torch.Tensor): Input with a leading dimension of size 1,
            e.g. (1, D, H, W) or (1, H, W).
        net (torch.nn.Module): Model. Presumably it maps a (1, 1, h, w) tensor
            to (1, C, h, w) class scores; verify against the model in use.
        patch_size (tuple, optional): (height, width) fed to the network for
            3-D inputs.
        device (str or torch.device, optional): Device to run inference on.
            Defaults to "cuda", the previously hard-coded device.

    Returns:
        numpy.ndarray: For 3-D input, a float array of shape (D, H, W) holding
        class indices. Otherwise, the integer argmax map of the squeezed input.
    """
    volume = image.squeeze(0).cpu().detach().numpy()
    net.eval()

    if volume.ndim != 3:
        tensor = torch.from_numpy(volume).unsqueeze(0).unsqueeze(0).float().to(device)
        with torch.no_grad():
            out = torch.argmax(torch.softmax(net(tensor), dim=1), dim=1).squeeze(0)
        return out.cpu().detach().numpy()

    # One prediction plane per input slice (previously sized by patch_size,
    # which could not hold the per-slice results).
    prediction = np.zeros(volume.shape)
    target_h, target_w = patch_size[0], patch_size[1]

    for ind in range(volume.shape[0]):
        plane = volume[ind, :, :]
        h, w = plane.shape[0], plane.shape[1]
        needs_resize = h != target_h or w != target_w
        if needs_resize:
            plane = zoom(plane, (target_h / h, target_w / w), order=3)

        tensor = torch.from_numpy(plane).unsqueeze(0).unsqueeze(0).float().to(device)
        with torch.no_grad():
            outputs = net(tensor)
            out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0)
        out = out.cpu().detach().numpy()

        if needs_resize:
            # Nearest-neighbour keeps class indices intact when scaling back.
            out = zoom(out, (h / target_h, w / target_w), order=0)
        prediction[ind] = out

    # Return only after every slice is processed (previously returned inside the loop).
    return prediction


# Inference helper, not a pytest test: stop pytest from collecting it by its
# "test_" prefix when it is imported into a test module.
test_single_volume.__test__ = False
