import os
from typing import Dict, List

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from PIL import Image

from fourm.utils.data_constants import SEG_IGNORE_INDEX
from fourm.utils.misc import denormalize
from datasets_semseg import procthor_classes, replica_classes, scannet_pp_classes



@torch.no_grad()
def log_semseg_wandb(
        images: torch.Tensor, 
        preds: List[np.ndarray], 
        gts: List[np.ndarray],
        dataset_name: str = 'scannet_pp',
        image_count=8, 
        prefix="",
        ignore_index=SEG_IGNORE_INDEX
    ):
    """Log images with predicted and ground-truth segmentation masks to wandb.

    Args:
        images: Batch of normalized images; each one is passed through ``denormalize``
            before being wrapped in ``wandb.Image``.
        preds: Per-image predicted label maps.
        gts: Per-image ground-truth label maps.
        dataset_name: One of ``'procthor'``, ``'replica'`` or ``'scannet_pp'``; selects the
            class-name list used for the mask legend.
        image_count: Maximum number of images to log.
        prefix: Key prefix; images are logged under ``f"{prefix}_{i}"``.
        ignore_index: Label value marking ignored pixels. Prediction pixels whose ground
            truth equals this value are set to it as well, so both masks show them as "ignore".

    Raises:
        ValueError: If ``dataset_name`` is not supported.

    Note:
        The inputs are not modified; masking is applied to a copy of each prediction.
        The log call uses ``commit=False``, so the step is committed by a later ``wandb.log``.
    """
    if dataset_name == 'procthor':
        classes = procthor_classes()
    elif dataset_name == 'replica':
        classes = replica_classes()
    elif dataset_name == 'scannet_pp':
        classes = scannet_pp_classes()
    else:
        raise ValueError(f'Dataset {dataset_name} not supported for logging to wandb.')

    class_labels = {i: cls for i, cls in enumerate(classes)}
    # One extra id past the last real class is labelled "void".
    class_labels[len(classes)] = "void"
    class_labels[ignore_index] = "ignore"

    image_count = min(len(images), image_count)

    images = images[:image_count]
    preds = preds[:image_count]
    gts = gts[:image_count]

    semseg_images = {}

    for i, (image, pred, gt) in enumerate(zip(images, preds, gts)):
        image = denormalize(image)
        gt = np.asarray(gt)
        # Copy so the masking below does not leak into the caller's prediction arrays.
        pred = np.array(pred, copy=True)
        # Honor the caller's ignore_index (previously the module constant was used here).
        pred[gt == ignore_index] = ignore_index

        semseg_image = wandb.Image(image, masks={
            "predictions": {
                "mask_data": pred,
                "class_labels": class_labels,
            },
            "ground_truth": {
                "mask_data": gt,
                "class_labels": class_labels,
            }
        })

        semseg_images[f"{prefix}_{i}"] = semseg_image

    wandb.log(semseg_images, commit=False)
    

def colorize_semseg(index_map, num_classes):
    """Map an integer label map to an RGB image using a fixed pseudo-random palette.

    Args:
        index_map: Integer array of label ids in ``[0, num_classes]``. Any shape works; the
            output has the same shape plus a trailing channel axis of size 3.
        num_classes: Number of real classes. The palette has ``num_classes + 1`` entries so
            that id ``num_classes`` (e.g. an "ignore" bucket) also gets a colour.

    Returns:
        ``uint8`` array of shape ``index_map.shape + (3,)``.

    Raises:
        IndexError: If ``index_map`` contains an id greater than ``num_classes``.
    """
    # Same seed and the same per-class draw sequence as before, so colours are stable
    # across calls and across versions of this function.
    rng = np.random.default_rng(seed=0)
    palette = np.array(
        [rng.integers(0, 255, size=3) for _ in range(num_classes + 1)],
        dtype=np.uint8,
    )
    # A single gather colours every pixel at once; there is no per-label pass over the map.
    return palette[np.asarray(index_map)]

def concatenate_and_save(image, gt, pred, save_dir, file_name, postfix='visual_results'):
    """Save ``image``, ``gt`` and ``pred`` side by side (left to right) as one image file.

    Args:
        image: Array whose height/width define the output row height; presumably an
            HxWx3 uint8 RGB image — confirm with callers.
        gt: Ground-truth visualization array; resized to the image's HxW if it differs.
        pred: Prediction visualization array; resized likewise.
        save_dir: Base output directory.
        file_name: File name written inside ``os.path.join(save_dir, postfix)``.
        postfix: Sub-directory of ``save_dir`` to write into (created if missing).
    """
    save_dir = os.path.join(save_dir, postfix)
    # exist_ok avoids the check-then-create race when several workers save concurrently.
    os.makedirs(save_dir, exist_ok=True)

    image_height, image_width = image.shape[:2]
    target_size = (image_width, image_height)  # PIL takes (width, height)

    # gt/pred are label or palette-colour maps: resize with nearest-neighbour so every
    # pixel keeps a value that exists in the source. PIL's default filter for RGB is
    # bicubic, which blends neighbouring labels into colours/ids of no class.
    if gt.shape[:2] != (image_height, image_width):
        gt = np.array(Image.fromarray(gt).resize(target_size, resample=Image.NEAREST))

    if pred.shape[:2] != (image_height, image_width):
        pred = np.array(Image.fromarray(pred).resize(target_size, resample=Image.NEAREST))

    # Concatenate horizontally: image | ground truth | prediction.
    concatenated_image = np.concatenate((image, gt, pred), axis=1)

    save_path = os.path.join(save_dir, file_name)
    Image.fromarray(concatenated_image).save(save_path)
        
def save_semseg_preds(
    images: List[torch.Tensor], 
    preds: List[np.ndarray], 
    gts: List[np.ndarray],
    save_dir: str,
    dataset_name: str = 'scannet_pp',
    image_count=None,
    colorize=True,
):
    """Write ``image | ground truth | prediction`` panels to disk, one PNG per sample.

    Args:
        images: Normalized CHW image tensors; each is passed through ``denormalize``.
        preds: Per-image predicted label maps.
        gts: Per-image ground-truth label maps; pixels equal to ``SEG_IGNORE_INDEX`` are
            treated as ignored in both the gt and the prediction visualization.
        save_dir: Base output directory.
        dataset_name: One of ``'procthor'``, ``'replica'`` or ``'scannet_pp'``; determines
            the number of classes used for the colour palette.
        image_count: Maximum number of samples to save; ``None`` saves all of them.
        colorize: If True, write palette-coloured maps to ``save_dir/visual_results``.
            Otherwise write raw label ids as grayscale to ``save_dir/visual_results_raw``.

    Raises:
        ValueError: If ``dataset_name`` is not supported.

    Note:
        The input ``preds``/``gts`` arrays are not modified; masking works on copies.
    """
    if dataset_name == 'procthor':
        num_classes = len(procthor_classes())
    elif dataset_name == 'replica':
        num_classes = len(replica_classes())
    elif dataset_name == 'scannet_pp':
        num_classes = len(scannet_pp_classes())
    else:
        raise ValueError(f'Dataset {dataset_name} not supported for logging.')

    image_count = len(images) if image_count is None else min(len(images), image_count)

    images = images[:image_count]
    preds = preds[:image_count]
    gts = gts[:image_count]

    for i, (image, pred, gt) in enumerate(zip(images, preds, gts)):
        image = denormalize(image)
        # Clamp before scaling: values slightly outside [0, 1] would otherwise wrap
        # around when cast to uint8. .cpu() lets device tensors be converted as well.
        image = (image.detach().cpu().permute(1, 2, 0).clamp(0.0, 1.0) * 255.0).numpy().astype(np.uint8)

        # Work on copies so the caller's gts keep their ignore labels (and preds stay intact)
        # for any metric or logging that runs after this function.
        gt = np.array(gt, copy=True)
        pred = np.array(pred, copy=True)
        ignore_mask = gt == SEG_IGNORE_INDEX

        if colorize:
            # Ignored pixels go to the extra palette slot `num_classes`.
            pred[ignore_mask] = num_classes
            gt[ignore_mask] = num_classes

            pred_vis = colorize_semseg(pred, num_classes)
            gt_vis = colorize_semseg(gt, num_classes)
            postfix = 'visual_results'
        else:
            # Raw ids as 3-channel grayscale; ignored pixels become 255 (white).
            # NOTE(review): ids above 255 would wrap in the uint8 cast — confirm class counts.
            pred[ignore_mask] = 255
            gt[ignore_mask] = 255

            pred_vis = np.stack((pred,) * 3, axis=-1).astype(np.uint8)
            gt_vis = np.stack((gt,) * 3, axis=-1).astype(np.uint8)
            postfix = 'visual_results_raw'

        concatenate_and_save(image, gt_vis, pred_vis, save_dir, f'{i}.png', postfix=postfix)