import os
import re
import math
import torch
import numpy as np
from PIL import Image
from torchvision.utils import make_grid
import torch.nn.functional as F

def tensor2np_img(img_tensor, grid_style=True, nrow=None, dtype='uint8'):
    """Convert an image tensor with values in [-1, 1] into a numpy image array.

    Args:
        img_tensor: 2-D, 3-D or 4-D tensor. It is detached and moved to CPU
            before conversion.
            - 4-D: treated as a batch ``(N, C, H, W)``. With
              ``grid_style=True`` it is tiled into one ``(H', W', C')``
              image via ``torchvision.utils.make_grid``; otherwise it is
              permuted to ``(N, H, W, C)``.
            - 3-D: treated as ``(C, H, W)`` and transposed to ``(H, W, C)``.
            - 2-D: returned with its shape unchanged.
        grid_style: for 4-D input, tile the batch into a single grid image.
        nrow: images per grid row; defaults to ``int(sqrt(N))``.
        dtype: ``'int32'`` returns an int32 array that is NOT clipped.
            Any other value returns uint8 clipped to [0, 255].

    Returns:
        numpy array holding ``round((x + 1) / 2 * 255)``.

    Raises:
        ValueError: if the tensor is not 2-D, 3-D or 4-D.
    """
    dim = img_tensor.dim()
    img_tensor = img_tensor.detach().cpu()
    img_tensor = (img_tensor + 1.0) / 2.0     # [-1, 1] -> [0, 1]

    if dim == 4:
        if grid_style:
            if nrow is None:
                # roughly square grid
                nrow = int(math.sqrt(img_tensor.size(0)))
            img_np = make_grid(img_tensor, nrow=nrow, normalize=False).numpy()
            img_np = np.transpose(img_np, (1, 2, 0))
        else:
            img_np = img_tensor.permute(0, 2, 3, 1).contiguous().numpy()
    elif dim == 3:
        img_np = np.transpose(img_tensor.numpy(), (1, 2, 0))
    elif dim == 2:
        img_np = img_tensor.numpy()
    else:
        raise ValueError(f"expected a 2-D, 3-D or 4-D tensor, got a {dim}-D tensor")

    img_np = (img_np * 255.0).round()
    if dtype == 'int32':
        return img_np.astype(np.int32)
    # Clip before the cast: out-of-range floats would otherwise wrap around
    # (e.g. 319 -> 63) instead of saturating at 0 / 255.
    return np.clip(img_np, 0, 255).astype(np.uint8)

def save_np_img(img_np, img_path, mode='RGB'):
    """Save a numpy image array to ``img_path`` using Pillow.

    Args:
        img_np: image array, e.g. as returned by ``tensor2np_img``.
            - Arrays whose last axis has size 1 (e.g. ``(H, W, 1)``) are
              reshaped to ``(H, W)`` and saved in mode ``'L'``.
            - Plain 2-D arrays are also saved in mode ``'L'`` when ``mode``
              is left at its default.
        img_path: destination path, passed to ``PIL.Image.Image.save``.
        mode: Pillow mode passed to ``Image.fromarray`` (default ``'RGB'``).
    """
    if img_np.shape[-1] == 1:
        # (H, W, 1) -> (H, W): grayscale mode 'L' takes a 2-D array.
        mode = 'L'
        img_np = img_np.reshape(img_np.shape[0], img_np.shape[1])
    elif img_np.ndim == 2 and mode == 'RGB':
        # A 2-D array has no channel axis, so it cannot be read as RGB.
        # Keeping the default mode would make Image.fromarray fail.
        mode = 'L'

    img_pil = Image.fromarray(img_np, mode=mode)
    img_pil.save(img_path)


def softmax_with_temperature(z, T, dim=1):
    """Temperature-scaled softmax: ``softmax(z / T)`` along ``dim``.

    Args:
        z: logits tensor.
        T: temperature. Values > 1 flatten the distribution; values < 1
            sharpen it.
        dim: axis to normalize over. Defaults to 1, the axis the original
            implementation normalized over.

    Returns:
        Tensor of the same shape as ``z`` whose slices along ``dim`` sum to 1.
    """
    z = z / T
    # Subtract the max along the SAME axis the softmax normalizes over.
    # The shift is then constant within each softmax slice, so the result is
    # unchanged, while exp() of large logits cannot overflow. Shifting along
    # a different axis, as before, changes the result for inputs with
    # 3 or more dimensions.
    maxz, _ = torch.max(z, dim=dim, keepdim=True)
    z = z - maxz.detach()
    return F.softmax(z, dim=dim)
    
