from typing import Optional, Tuple
import torch
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

from .util import bchw2hwc


def set_figsize(*args):
    """Set matplotlib's default figure size (``rcParams["figure.figsize"]``).

    Usage:
        set_figsize()        -> restore matplotlib's default figure size
        set_figsize(s)       -> square figure of size (s, s)
        set_figsize(w, h)    -> figure of size (w, h)

    Raises:
        RuntimeError: if more than two arguments are given.
    """
    count = len(args)
    if count > 2:
        raise RuntimeError(
            'Supported argument types: set_figsize() or set_figsize(int) or set_figsize(int, int)')
    if count == 0:
        size = plt.rcParamsDefault["figure.figsize"]
    elif count == 1:
        size = (args[0], args[0])
    else:
        size = tuple(args)
    plt.rcParams["figure.figsize"] = size


def show_hwc(image: torch.Tensor):
    """Display an HWC image tensor with matplotlib.

    A non-uint8 tensor is converted with ``Tensor.to(torch.uint8)``, which is a
    plain cast (no rescaling). A single-channel image is replicated to three
    channels before it is handed to PIL and shown with ``plt.imshow``.
    """
    converted = image if image.dtype == torch.uint8 else image.to(torch.uint8)
    if converted.size(2) == 1:
        converted = converted.repeat(1, 1, 3)
    plt.imshow(Image.fromarray(converted.cpu().numpy()))
    plt.show()


def show_bchw(image: torch.Tensor):
    """Display ``image`` after converting it to HWC layout via ``bchw2hwc``."""
    hwc_image = bchw2hwc(image)
    show_hwc(hwc_image)


def show_bhw(image: torch.Tensor):
    """Display a channel-less image by inserting a channel axis at dim 1 first."""
    with_channel = image.unsqueeze(dim=1)
    show_bchw(with_channel)

# Variants that return the converted image instead of displaying it.
def get_bhw(image: torch.Tensor):
    """Channel-less variant of ``get_bchw``: insert a channel axis at dim 1, then convert."""
    with_channel = image.unsqueeze(dim=1)
    return get_bchw(with_channel)

def get_bchw(image: torch.Tensor):
    """Return ``bchw2hwc(image)`` as a uint8 tensor.

    A non-uint8 result is converted with ``Tensor.to(torch.uint8)`` (a plain
    cast, no rescaling). A single-channel result is replicated to three
    channels; other channel counts are returned unchanged.
    """
    hwc = bchw2hwc(image)
    hwc = hwc if hwc.dtype == torch.uint8 else hwc.to(torch.uint8)
    return hwc.repeat(1, 1, 3) if hwc.size(2) == 1 else hwc

def test():
    """Return ``True`` unconditionally.

    NOTE(review): looks like a placeholder / smoke-test stub; confirm whether
    any caller still relies on it.
    """
    ok = True
    return ok

def get_bhw_no_contour(image: torch.Tensor):
    """Channel-less variant of ``get_bchw_no_contour``: insert a channel axis at dim 1 first."""
    with_channel = image.unsqueeze(dim=1)
    return get_bchw_no_contour(with_channel)

def get_bchw_no_contour(image: torch.Tensor, *, old_values=None, new_value=(0, 0, 0)):
    """Return ``bchw2hwc(image)`` as a uint8 NumPy array with selected colors replaced.

    The image is converted the same way as ``get_bchw``: a plain uint8 cast
    (no rescaling) and replication of a single channel to three. It is then
    copied into a NumPy array. Every pixel whose channel values exactly equal
    one of ``old_values`` is set to ``new_value``.

    Args:
        image: tensor accepted by ``bchw2hwc``.
        old_values: colors to replace. Defaults to
            ``[[23, 23, 23], [231, 231, 231]]``.
        new_value: replacement color. Defaults to ``(0, 0, 0)``.

    Returns:
        A NumPy array that does not share memory with ``image``.
    """
    if old_values is None:
        # NOTE(review): presumably the grey levels used to draw contours; confirm.
        old_values = [[23, 23, 23], [231, 231, 231]]
    result = bchw2hwc(image)
    if result.dtype != torch.uint8:
        result = result.to(torch.uint8)
    if result.size(2) == 1:
        result = result.repeat(1, 1, 3)
    # Tensor.numpy() shares memory with a CPU tensor. When no cast, repeat or
    # device copy happened above, writing into that array would modify the
    # caller's tensor, so take an explicit copy before the in-place edit.
    result = result.cpu().numpy().copy()
    # Build one mask of pixels matching any target color, then assign once.
    mask = np.zeros(result.shape[:-1], dtype=bool)
    for old_value in old_values:
        mask |= np.all(result == np.asarray(old_value), axis=-1)
    result[mask] = new_value
    return result