import torch
import os
from jaxtyping import jaxtyped, Float, Shaped
from beartype import beartype
import numpy as np
from PIL import Image
import glob

from conf.evaluation_params import EvaluationParams
from utils.utils import display_tensor, rgb2mask


def get_file(path_to_file: str) -> str:
    """Resolve ``path_to_file`` to exactly one existing path.

    ``path_to_file`` is interpreted as a :mod:`glob` pattern, so ``*``, ``?``
    and ``[...]`` act as wildcards (use ``glob.escape`` on literal parts that
    may contain them).

    Args:
        path_to_file: glob pattern expected to match a single path.

    Returns:
        The single matching path.

    Raises:
        FileNotFoundError: if nothing matches.
        ValueError: if more than one path matches.
    """
    matches = glob.glob(path_to_file)
    if len(matches) == 1:
        return matches[0]
    # Raise explicitly rather than ``assert`` so the check survives ``python -O``.
    message = f"Found {len(matches)} files for {path_to_file}, should be 1"
    if not matches:
        raise FileNotFoundError(message)
    raise ValueError(message)


@jaxtyped(typechecker=beartype)
def process_fused_prediction(
    prediction: Shaped[torch.Tensor, "3 h w"],
    *,
    height: int = 256,
    width: int = 320,
    mode_strip: int = 10,
    separator: int = 10,
    nb_class: int = 14,
) -> list[Shaped[torch.Tensor, "_cperdom hr wr"]]:
    """Split a fused visualisation image into per-domain prediction tiles.

    Horizontal layout assumed by the cropping below, left to right: a
    ``mode_strip``-wide strip, a ``separator``-wide black strip, content
    whose last ``width`` columns are the prediction, another
    ``separator``-wide black strip, and a ``width``-wide ground-truth column.
    Vertically, domain ``i`` starts at row ``i * 3 * height`` and only its
    first ``height`` rows are kept.

    Args:
        prediction: channels-first image tensor.
        height: height in pixels of one domain tile.
        width: width in pixels of one domain tile.
        mode_strip: width in pixels of the strip on the far left.
        separator: width in pixels of each black separator strip.
        nb_class: number of classes forwarded to ``rgb2mask``.

    Returns:
        One entry per domain found.
        - Index 0: the raw 3-channel tile.
        - Index 1: channel 0 of tile 1, with shape ``(1, height, width)``.
        - Index 2: ``rgb2mask(tile 2, one_hot=True, nb_class=nb_class)``.
        - Any further tiles: returned unchanged.

    Raises:
        ValueError: if ``height``/``width`` are not positive, if the cropped
            image is not ``width`` pixels wide, or if fewer than 3 domains fit.
    """
    if height <= 0 or width <= 0:
        raise ValueError(f"height and width must be positive, got {height} and {width}")

    # Remove the mode strip and the black separator on the left.
    prediction = prediction[:, :, mode_strip + separator:]
    # Remove the ground-truth column and the black separator before it on the right.
    prediction = prediction[:, :, : -(width + separator)]
    # Only keep the last column, which is the prediction.
    prediction = prediction[:, :, -width:]

    if prediction.shape[-1] != width:
        raise ValueError(
            f"Expected the cropped prediction to be {width} px wide, "
            f"got {prediction.shape[-1]} (input width {prediction.shape[-1]} after cropping)"
        )

    # Each domain occupies a band of 3 * height rows; only its first `height`
    # rows are kept (NOTE(review): meaning of the other rows not visible here).
    band = 3 * height
    nb_dom = prediction.shape[1] // band
    if nb_dom < 3:
        raise ValueError(
            f"Expected at least 3 domains ({3 * band} rows), found {nb_dom} "
            f"in an image with {prediction.shape[1]} rows"
        )
    pred_res = [prediction[:, i * band : i * band + height] for i in range(nb_dom)]

    # Depth: keep only the first channel, with a leading channel axis.
    pred_res[1] = pred_res[1][0].unsqueeze(0)
    # Mask: decode tile 2 with rgb2mask (read from index 2, consistent with
    # where the result is stored; identical to the old [-1] when nb_dom == 3).
    pred_res[2] = rgb2mask(pred_res[2], normalize_in=None, one_hot=True, nb_class=nb_class)

    return pred_res


@jaxtyped(typechecker=beartype)
def get_prediction_from_fused(
    params: EvaluationParams,
    idx_in_ds: int,
) -> list[Float[torch.Tensor, "_ h w"]]:
    """Load the fused prediction image for one sample and split it.

    Reads ``<params.folder_predictions>/<idx_in_ds>.png`` and splits it with
    ``process_fused_prediction``. The photo and depth tiles are then rescaled
    from ``[0, 255]`` to float ``[-1, 1]``.

    Args:
        params: evaluation config; ``folder_predictions`` is the image folder.
        idx_in_ds: index of the sample, used as the image file stem.

    Returns:
        ``[photo, depth, mask]``. The mask is returned exactly as produced by
        ``process_fused_prediction`` (one-hot, so already in ``[0, 1]``).

    Raises:
        FileNotFoundError / ValueError / AssertionError: propagated from
            ``get_file`` when the image path does not resolve to one file.
    """
    path_to_file = get_file(os.path.join(params.folder_predictions, f"{idx_in_ds}.png"))
    # Context manager releases the file handle deterministically.
    with Image.open(path_to_file) as image:
        prediction = np.array(image.convert('RGB'))
    # HWC (8-bit RGB) -> CHW
    prediction_tensor = torch.from_numpy(prediction).permute(2, 0, 1)

    photo, depth, mask = process_fused_prediction(prediction_tensor)

    # Photo and depth: [0, 255] -> [0, 1] -> [-1, 1].
    # The mask is one-hot, so it is already in [0, 1] and left untouched.
    photo = photo.float() / 255.
    depth = depth.float() / 255.
    photo = photo * 2 - 1
    depth = depth * 2 - 1

    return [photo, depth, mask]


@jaxtyped(typechecker=beartype)
def get_prediction_from(
    params: EvaluationParams,
    idx_in_ds: int,
) -> list[Float[torch.Tensor, "_ h w"]]:
    """Load the prediction for sample ``idx_in_ds`` with the configured loader.

    The loader is selected by ``params.get_prediction_from_function``. Only
    ``"sunrgbd_fused"`` is supported; it delegates to
    ``get_prediction_from_fused``.

    Raises:
        ValueError: if the configured loader name is not recognised.
    """
    loader_name = params.get_prediction_from_function
    if loader_name == "sunrgbd_fused":
        return get_prediction_from_fused(params, idx_in_ds)
    raise ValueError(f"Unknown get_prediction_from: {loader_name}")
