import torch
import os
from jaxtyping import jaxtyped, Float, Shaped, Int
from beartype import beartype
import numpy as np
from PIL import Image
import glob
import torch.nn.functional as F

from conf.evaluation_params import EvaluationParams
from utils.utils import display_tensor, rgb2mask


@jaxtyped(typechecker=beartype)
def controlnet_blender_0_12(
        params: EvaluationParams,
        idx: int,
) -> tuple[
    Float[torch.Tensor, "3 h w"],
    Float[torch.Tensor, "3 h w"],
    Float[torch.Tensor, "3 h w"],
]:
    """Load blender predictions for domains 1 and 2 plus a noise stand-in for domain 0.

    Domain 1 is read from ``dom_1_{idx}.png`` in the first prediction folder and
    domain 2 from ``dom_2_{idx}.png`` in the second. Both are converted to RGB,
    rescaled from [0, 255] to [-1, 1] and returned channels-first.

    Args:
        params: evaluation settings; ``params.folder_predictions`` must unpack
            into exactly two folders (domain 1 folder, domain 2 folder).
        idx: sample index embedded in the prediction file names.

    Returns:
        ``(dom0, dom1, dom2)``, each a float tensor of shape (3, h, w). ``dom0``
        is not read from disk: it is Gaussian noise clamped to [-1, 1] with the
        shape of ``dom1``. The shared ``h w`` in the annotation means both
        images must have the same size.
    """

    def load_rgb(path: str) -> torch.Tensor:
        # convert("RGB") guarantees 3 channels even for RGBA / grayscale /
        # palette PNGs, which would otherwise fail the "3 h w" check or
        # crash in permute.
        with Image.open(path) as img:
            pixels = np.array(img.convert("RGB"))
        # uint8 -> float in [-1, 1], then HWC -> CHW.
        return (torch.from_numpy(pixels) / 255 * 2 - 1).permute(2, 0, 1)

    dom1_folder, dom2_folder = params.folder_predictions
    dom1 = load_rgb(os.path.join(dom1_folder, f"dom_1_{idx}.png"))
    dom2 = load_rgb(os.path.join(dom2_folder, f"dom_2_{idx}.png"))

    # Domain 0 has no prediction file here; substitute clamped noise.
    dom0 = torch.randn_like(dom1).clamp(-1, 1)

    return dom0, dom1, dom2


@jaxtyped(typechecker=beartype)
def controlnet_brats2020(
        params: EvaluationParams,
        idx: int,
) -> tuple[
    Float[torch.Tensor, "1 h w"],
    Float[torch.Tensor, "1 h w"],
    Float[torch.Tensor, "1 h w"],
    Float[torch.Tensor, "1 h w"],
    Int[torch.Tensor, "2 h w"],
]:
    """Load all five BraTS2020 prediction images for sample ``idx``.

    Files are read from the first prediction folder as ``dom_{k}_{idx}.png``,
    with k = 0 -> t1, 1 -> t2, 2 -> t1ce, 3 -> flair, 4 -> seg.

    The four modalities are converted to single-channel grayscale and scaled
    to [0, 1], the same convention ``controlnet_brats2020_1234_0`` uses for t1.
    This makes them satisfy the declared ``Float "1 h w"`` annotation.
    The segmentation image is binarised at 0.5 into a 2-channel integer mask.

    Args:
        params: evaluation settings; only ``params.folder_predictions[0]`` is used.
        idx: sample index embedded in the prediction file names.

    Returns:
        ``(t1, t1ce, t2, flair, seg)``: four (1, h, w) float tensors in [0, 1]
        and a (2, h, w) int64 mask whose channels are complementary.
    """

    def load_gray(path: str) -> torch.Tensor:
        # Grayscale, uint8 -> float in [0, 1], shape (h, w).
        with Image.open(path) as img:
            pixels = np.array(img.convert("L"))
        return torch.from_numpy(pixels) / 255

    folder_predictions = params.folder_predictions[0]

    def path_for(domain: int) -> str:
        return os.path.join(folder_predictions, f"dom_{domain}_{idx}.png")

    # NOTE(review): dom_1 is read as t2 and dom_2 as t1ce while the return
    # order is (t1, t1ce, t2, flair); kept as in the original code — confirm
    # the file-to-modality mapping.
    t1 = load_gray(path_for(0)).unsqueeze(0)
    t2 = load_gray(path_for(1)).unsqueeze(0)
    t1ce = load_gray(path_for(2)).unsqueeze(0)
    flair = load_gray(path_for(3)).unsqueeze(0)

    # NOTE(review): assumes the saved seg PNG is a binary mask where bright
    # pixels are foreground. Channel 0 = (value < 0.5), channel 1 = its
    # complement, mirroring the placeholder in controlnet_brats2020_1234_0.
    # Confirm how seg predictions are written (utils.rgb2mask may be the
    # intended decoder).
    seg_gray = load_gray(path_for(4))
    background = (seg_gray < 0.5).long()
    seg = torch.stack([background, 1 - background])

    return t1, t1ce, t2, flair, seg


@jaxtyped(typechecker=beartype)
def controlnet_brats2020_1234_0(
        params: EvaluationParams,
        idx: int,
) -> tuple[
    Float[torch.Tensor, "1 h w"],
    Float[torch.Tensor, "1 h w"],
    Float[torch.Tensor, "1 h w"],
    Float[torch.Tensor, "1 h w"],
    Float[torch.Tensor, "2 h w"],
]:
    """Load the BraTS2020 t1 prediction plus random stand-ins for the other domains.

    Only ``dom_0_{idx}.png`` in the first prediction folder is read. It is
    scaled from [0, 255] to [0, 1], and a leading channel axis is added.
    The image is expected to be single-channel: the ``unsqueeze(0)`` only
    yields (1, h, w) for a 2-D pixel array.

    Args:
        params: evaluation settings; only ``params.folder_predictions[0]`` is used.
        idx: sample index embedded in the prediction file name.

    Returns:
        ``(t1, t1ce, t2, flair, seg)``:
        t1 is the loaded image, (1, h, w) in [0, 1].
        t1ce, t2 and flair are Gaussian noise clamped to [0, 1], each (1, h, w).
        seg is a (2, h, w) float random binary mask whose two channels are
        complementary.
    """
    folder_predictions = params.folder_predictions[0]
    t1_path = os.path.join(folder_predictions, f"dom_0_{idx}.png")

    with Image.open(t1_path) as img:
        pixels = np.array(img)
    t1 = (torch.from_numpy(pixels) / 255).unsqueeze(0)

    # Placeholders for the domains that are not loaded from disk.
    t1ce = torch.randn_like(t1).clamp(0, 1)
    t2 = torch.randn_like(t1).clamp(0, 1)
    flair = torch.randn_like(t1).clamp(0, 1)
    seg = torch.randn_like(t1).clamp(0, 1)

    # Binarise into a 2-channel mask: channel 1 is the complement of channel 0.
    # Kept as float to match the Float return annotation. seg is already
    # (1, h, w), so no reshape is needed; the previous hard-coded
    # reshape(1, 256, 256) broke every non-256x256 input.
    seg = (seg < 0.5).float().repeat(2, 1, 1)
    seg[1] = 1 - seg[0]

    return t1, t1ce, t2, flair, seg


@jaxtyped(typechecker=beartype)
def controlnet_celeba_12_0(
        params: EvaluationParams,
        idx: int,
) -> tuple[
    Float[torch.Tensor, "3 h w"],
    Float[torch.Tensor, "1 h w"],
    Float[torch.Tensor, "19 h w"],
]:
    """Load the CelebA face prediction plus random stand-ins for sketch and segmentation.

    Only ``dom_0_{idx}.png`` in the first prediction folder is read. It is
    converted to RGB, rescaled from [0, 255] to [-1, 1] and returned
    channels-first.

    Args:
        params: evaluation settings; only ``params.folder_predictions[0]`` is used.
        idx: sample index embedded in the prediction file name.

    Returns:
        ``(face, sketch, seg)``:
        face is the loaded image, (3, h, w) in [-1, 1].
        sketch is (1, h, w) Gaussian noise clamped to [-1, 1].
        seg is a (19, h, w) float one-hot encoding of uniformly random class
        labels in [0, 19).
    """
    folder_predictions = params.folder_predictions[0]
    face_path = os.path.join(folder_predictions, f"dom_0_{idx}.png")

    # convert("RGB") guarantees 3 channels even for RGBA / grayscale / palette
    # PNGs, which would otherwise fail the "3 h w" check or crash in permute.
    with Image.open(face_path) as img:
        pixels = np.array(img.convert("RGB"))
    face = (torch.from_numpy(pixels) / 255 * 2 - 1).permute(2, 0, 1)

    # Placeholders for the domains that are not loaded from disk.
    # The sketch noise is drawn at full face shape and then sliced to one channel.
    sketch = torch.randn_like(face).clamp(-1, 1)[:1]
    labels = torch.randint_like(input=face[0], low=0, high=19).long()
    # (h, w) labels -> (h, w, 19) one-hot -> (19, h, w) float.
    seg = F.one_hot(labels, num_classes=19).permute(2, 0, 1).float()

    return face, sketch, seg

@jaxtyped(typechecker=beartype)
def get_prediction_from(
    params: EvaluationParams,
    idx: int,
) -> tuple[Float[torch.Tensor, "_ h w"], ...]:
    """Dispatch to the prediction loader named by ``params.loadfunction``.

    Args:
        params: evaluation settings; ``params.loadfunction`` selects the loader
            and the whole object is forwarded to it.
        idx: sample index forwarded to the loader.

    Returns:
        The tuple of tensors produced by the selected loader.

    Raises:
        ValueError: if ``params.loadfunction`` equals none of the known names.
    """
    # Ordered (name, loader) pairs, compared with == in this order.
    # NOTE: controlnet_brats2020 is not listed; its Int seg output would not
    # satisfy this function's Float return annotation.
    known_loaders = (
        ("controlnet_brats2020_1234_0", controlnet_brats2020_1234_0),
        ("controlnet_celeba_12_0", controlnet_celeba_12_0),
        ("controlnet_blender_0_12", controlnet_blender_0_12),
    )
    for loader_name, loader in known_loaders:
        if params.loadfunction == loader_name:
            return loader(params, idx)
    raise ValueError(f"Unknown get_prediction_from: {params.loadfunction}")
