import sys
import torch.nn as nn
import torch
import numpy as np
from scipy import signal
import torch.fft as fft
import torchvision
import torch.nn.functional as f
import torch.nn.functional as F
import wandb
import matplotlib.pyplot as plt
from PIL import Image
from torchmetrics.image import StructuralSimilarityIndexMeasure
from IPython.display import display, clear_output
import math
import random
# Module-level constant. None of the functions visible in this part of the
# file read it (deconvolution_torch declares its own `alpha` parameter, which
# shadows this name) — presumably consumed by importers; TODO confirm.
alpha = 1

def seed_everything(seed: int = 47):
    """Seed every random number generator used by this module.

    The same value is handed to PyTorch (CPU generator, current CUDA device
    and all CUDA devices), Python's ``random`` and NumPy. cuDNN is then put
    in deterministic mode with its benchmark autotuner switched off.

    Args:
        seed: Value passed to every seeding call.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        random.seed,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    
def wandb_log(loglist, epoch, note, cmap='gray', target=(), colorbar=False, accelerator=None):
    """Log every entry of ``loglist`` under the key ``"<Note> <Key>"``.

    Conversion of each value is best-effort:

    * the value is moved to CPU and detached when it supports
      ``.cpu().detach()``; otherwise it is used as-is;
    * keys listed in ``target`` are rendered through ``log_colormap``; if that
      fails the value is wrapped in a plain ``wandb.Image`` instead;
    * all other keys are wrapped in ``wandb.Image``;
    * if wrapping fails altogether, the raw value is logged unchanged.

    Args:
        loglist: Mapping of name -> value to log.
        epoch: Passed as ``step`` to the logging call.
        note: Prefix for every logged key (capitalised).
        cmap: Matplotlib colormap name forwarded to ``log_colormap``.
        target: Container of keys that should be colormapped. The default is
            an immutable empty tuple rather than a shared mutable list.
        colorbar: Forwarded to ``log_colormap``.
        accelerator: Optional object exposing ``log(dict, step=...)``; when
            given it is used instead of ``wandb.log``.
    """
    # NOTE(review): logging through the accelerator is not restricted to the
    # main process here — confirm that is intended.
    log_fn = wandb.log if accelerator is None else accelerator.log

    for key, val in loglist.items():
        # `except Exception` keeps the best-effort fallbacks without also
        # swallowing KeyboardInterrupt / SystemExit like a bare `except:` does.
        try:
            try:
                item = val.cpu().detach()
            except Exception:
                item = val
            if key in target:
                try:
                    log = log_colormap(item, cmap=cmap, colorbar=colorbar)
                except Exception:
                    log = wandb.Image(item)
            else:
                log = wandb.Image(item)
        except Exception:
            log = val

        log_fn({f"{note.capitalize()} {key.capitalize()}": log}, step=epoch)


def log_colormap(image, cmap='gray', colorbar=False):
    """Render ``image`` with a matplotlib colormap and wrap it as ``wandb.Image``.

    Args:
        image: Object exposing ``.numpy()``. A result of shape (N, 1, H, W) is
            reduced to its first sample's single channel, and (1, H, W) to
            (H, W); any other shape is handed to ``imshow`` unchanged.
        cmap: Matplotlib colormap name.
        colorbar: When true, a colorbar is added to the figure.

    Returns:
        A ``wandb.Image`` built from the PNG rendering of the figure.
    """
    # Imported locally: the original body used BytesIO without any import in
    # scope, so every call raised NameError.
    from io import BytesIO

    image = image.numpy()
    if image.ndim == 4 and image.shape[1] == 1:
        image = image[0, 0]
    elif image.ndim == 3 and image.shape[0] == 1:
        image = image[0]

    fig, ax = plt.subplots()
    try:
        mappable = ax.imshow(image, cmap=cmap)
        if colorbar:
            fig.colorbar(mappable)

        buf = BytesIO()
        # Save this figure explicitly instead of relying on pyplot's notion
        # of the "current" figure.
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even when rendering fails; callers may
        # swallow the exception and a leaked figure would accumulate.
        plt.close(fig)

    buf.seek(0)
    return wandb.Image(Image.open(buf))


def preplot_t(image):
    """Resize, vertically flip and crop ``image`` for plotting.

    The input is resized to 270 x 480 (H x W), flipped along its
    second-to-last axis, and then cropped by dropping the first 60 rows, the
    first 62 columns and the last 38 columns.
    """
    # NOTE(review): the crop margins are hard-coded; presumably tuned to one
    # specific sensor/layout — confirm before reusing elsewhere.
    to_plot_size = torchvision.transforms.Resize((270, 480))
    flipped = to_plot_size(image).flip(-2)
    return flipped[..., 60:, 62:-38]

def gaussian_window(size, fwhm):
    """Return a 1-D Gaussian window of ``size`` samples centred on the middle.

    The standard deviation used is ``size / fwhm`` samples (so a larger
    ``fwhm`` argument gives a narrower window). The peak value is 1 and the
    result carries no autograd history.
    """
    with torch.no_grad():
        std = size / fwhm
        offsets = torch.arange(size) - (size - 1) / 2
        window = torch.exp(-0.5 * (offsets / std) ** 2)
    return window.detach()

def gaus_tk(x, fwhm=3):
    """Multiply a 5-D tensor by a separable 2-D Gaussian over its last two axes.

    Args:
        x: Tensor with exactly five dimensions; the window is built for the
            sizes of the last two.
        fwhm: Forwarded to ``gaussian_window`` for both axes.

    Returns:
        ``x`` scaled element-wise by the window, on ``x``'s device.
    """
    _, _, _, w, h = x.size()  # unpacking also enforces the 5-D rank

    window = gaussian_window(w, fwhm)[:, None] * gaussian_window(h, fwhm)[None, :]
    window = window[None, None, None]  # (1, 1, 1, w, h)
    window = window.expand(x.size())

    return x * window.to(x.device)
def gaus_t(x, fwhm=3):
    """Multiply a 4-D tensor by a separable 2-D Gaussian over its last two axes.

    Args:
        x: Tensor with exactly four dimensions; the window is built for the
            sizes of the last two.
        fwhm: Forwarded to ``gaussian_window`` for both axes.

    Returns:
        ``x`` scaled element-wise by the window, on ``x``'s device.
    """
    _, _, w, h = x.size()  # unpacking also enforces the 4-D rank

    window = gaussian_window(w, fwhm)[:, None] * gaussian_window(h, fwhm)[None, :]
    window = window[None, None]  # (1, 1, w, h)
    window = window.expand(x.size())

    return x * window.to(x.device)


def gaus(x, fwhm=2):
    """Multiply a 4-D tensor by a 2-D Gaussian built from SciPy windows.

    The window along each of the last two axes is
    ``signal.windows.gaussian(n, n / fwhm)``; their outer product is tiled
    over the channel axis, cast to float32 and moved to ``x``'s device.

    Args:
        x: Tensor with exactly four dimensions.
        fwhm: Divisor applied to each axis length to obtain the window's
            standard deviation.
    """
    _, channels, rows, cols = x.size()
    row_win = signal.windows.gaussian(rows, rows / fwhm)
    col_win = signal.windows.gaussian(cols, cols / fwhm)

    mask = torch.tensor(row_win[:, None] * col_win[None, :])
    mask = mask.repeat(1, channels, 1, 1).float().to(x.device)
    return x * mask


def deconvolution_torch(raw, psf, alpha=1e3):
    """Deconvolve ``raw`` by ``psf`` with a regularised inverse filter.

    With ``H`` the 2-D FFT of ``psf``, the measurement spectrum is multiplied
    by ``conj(H) / (|H|**2 + alpha)``, transformed back, passed through
    ``ifftshift`` over the last two axes, and its real part is returned.

    Args:
        raw: Measurement; transformed over its last two axes.
        psf: Point spread function; must broadcast against ``raw`` in the
            frequency domain.
        alpha: Regularisation constant added to ``|H|**2``.
    """
    psf_spectrum = torch.fft.fft2(psf)
    inverse_filter = psf_spectrum.conj() / (psf_spectrum.abs() ** 2 + alpha)

    raw_spectrum = torch.fft.fftn(raw, dim=(-2, -1))
    restored = torch.fft.ifftn(raw_spectrum * inverse_filter, dim=(-2, -1))
    return torch.fft.ifftshift(restored, dim=(-2, -1)).real

def convolution_torch(img, psf):
    """Convolve ``img`` with ``psf`` via real FFTs and crop to the PSF size.

    Both inputs are zero-padded on their last two axes by ``psf_size - 1``
    (split as evenly as possible, extra pixel on the right/bottom). The PSF
    spectrum is evaluated at the padded image size, the spectra are
    multiplied, and the inverse transform is passed through ``ifftshift``
    before a centre crop to the PSF's height and width. All transforms use
    ``norm='ortho'``.

    Args:
        img: Image tensor; padded and transformed over its last two axes.
        psf: Point spread function; its last two axes set the padding and
            the output size.

    Returns:
        Real tensor whose last two dimensions equal those of ``psf``.
    """
    pad_h = psf.size(-2) - 1
    pad_w = psf.size(-1) - 1
    # F.pad order is (left, right, top, bottom) for the last two axes.
    padding = (pad_w // 2, pad_w - pad_w // 2,
               pad_h // 2, pad_h - pad_h // 2)

    img_pad = F.pad(img, padding, mode='constant')
    psf_pad = F.pad(psf, padding, mode='constant')
    padded_size = (img_pad.size(-2), img_pad.size(-1))

    img_ft = torch.fft.rfft2(img_pad, dim=(-2, -1), norm='ortho').to(img.device)
    psf_ft = torch.fft.rfft2(psf_pad, s=padded_size, dim=(-2, -1), norm='ortho').to(img.device)

    # The half-spectrum does not record whether the original width was odd or
    # even; without `s`, irfft2 assumes an even width. For an odd padded width
    # (e.g. any even-sized image/PSF pair of equal shape) that returned one
    # column too few and inverted the spectrum on the wrong frequency grid.
    raw = torch.fft.irfft2(img_ft * psf_ft, s=padded_size, dim=(-2, -1), norm='ortho')
    raw = torch.fft.ifftshift(raw, dim=(-2, -1))
    raw = torchvision.transforms.CenterCrop((psf.size(-2), psf.size(-1)))(raw)
    return raw.real

def fft_torch(inputs):
    """Return the centred 2-D FFT of ``inputs`` as a real tensor.

    Args:
        inputs: Tensor of shape (b, c, h, w).

    Returns:
        Real tensor of shape (b, 2 * c, h, w). The real/imaginary pairs are
        not moved onto the channel axis: the (b, c, h, w, 2) buffer is simply
        reshaped, so ``ifft_torch`` recovers the complex values with the
        mirror-image reshape.
    """
    batch, channels, height, width = inputs.size()
    spectrum = torch.fft.fftn(inputs, dim=(-2, -1))
    centered = torch.fft.fftshift(spectrum, dim=(-2, -1))
    as_real = torch.view_as_real(centered)
    return as_real.reshape(batch, channels * 2, height, width)

def ifft_torch(inputs):
    """Invert a centred spectrum stored as a real (b, c, h, w) tensor.

    The input is reshaped to (b, c // 2, h, w, 2) and the last axis is read
    as (real, imaginary). The spectrum is un-centred with ``ifftshift``,
    inverse-transformed over the last two axes, and the magnitude returned.

    Args:
        inputs: Real tensor of shape (b, c, h, w) with even ``c``.

    Returns:
        Real tensor of shape (b, c // 2, h, w).
    """
    batch, channels, height, width = inputs.size()
    pairs = inputs.reshape(batch, channels // 2, height, width, 2)
    spectrum = torch.view_as_complex(pairs)
    uncentered = torch.fft.ifftshift(spectrum, dim=(-2, -1))
    return torch.fft.ifftn(uncentered, dim=(-2, -1)).abs()


def PSNR(img1, img2):
    """Return the batch-averaged PSNR between two 4-D tensors.

    The mean squared error is taken per sample over axes 1-3 and the peak
    signal value is assumed to be 1.

    Returns:
        A scalar tensor with the mean PSNR in dB, or the string
        ``"Same Image"`` when the batch-mean MSE is exactly zero.
    """
    per_sample_mse = ((img1 - img2) ** 2).mean(dim=(1, 2, 3), keepdim=True)
    if per_sample_mse.mean() == 0:
        return "Same Image"
    psnr_db = 10 * torch.log10(1. / per_sample_mse)
    return psnr_db.mean()


# def SSIM(img1, img2):
#     _ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(img1.device)
#     return _ssim(img1, img2)


def _grid_roi_log_weights(C, H, W):
    """Return (C, H, W) log-weights of Gaussians centred on a regular grid."""
    max_rois = 112
    n_rois = max(min(C, max_rois), 1)

    # Near-square grid holding at least n_rois centres; the row count is
    # derived from the uncapped column count, then both are capped to the
    # image size.
    n_cols = math.ceil(math.sqrt(n_rois))
    n_rows = -(-n_rois // n_cols)
    n_cols = min(n_cols, W)
    n_rows = min(n_rows, H)

    x_coords = torch.linspace(0, W - 1, steps=n_cols)
    y_coords = torch.linspace(0, H - 1, steps=n_rows)
    centers_y, centers_x = torch.meshgrid(y_coords, x_coords, indexing='ij')

    # Row-major list of centres, cycled or truncated to exactly C entries.
    repeats = -(-C // (n_cols * n_rows))
    centers_x = centers_x.flatten().repeat(repeats)[:C].view(-1, 1, 1)
    centers_y = centers_y.flatten().repeat(repeats)[:C].view(-1, 1, 1)

    sigma_x = W / (2 * n_cols)
    sigma_y = H / (2 * n_rows)
    if sigma_x == 0:
        sigma_x = 1e-6
    if sigma_y == 0:
        sigma_y = 1e-6

    x = torch.arange(0, W, dtype=torch.float32).view(1, 1, W)
    y = torch.arange(0, H, dtype=torch.float32).view(1, H, 1)
    dx = x - centers_x  # (C, 1, W)
    dy = y - centers_y  # (C, H, 1)
    return -((dx ** 2) / (2 * sigma_x ** 2) + (dy ** 2) / (2 * sigma_y ** 2))


def _radial_roi_log_weights(C, H, W):
    """Return (C, H, W) log-weights of Gaussian rings around the image centre."""
    y = torch.arange(H, dtype=torch.float32).view(1, H, 1)
    x = torch.arange(W, dtype=torch.float32).view(1, 1, W)
    cx0, cy0 = (W - 1) / 2.0, (H - 1) / 2.0
    rr = torch.sqrt((x - cx0) ** 2 + (y - cy0) ** 2)  # (1, H, W)
    max_r = rr.max()

    # Ring radii grow with sqrt(index) between 0 and max_r.
    i = torch.arange(C, dtype=torch.float32)
    r_centers = (torch.sqrt((i + 0.5) / C) * max_r).view(C, 1, 1)

    # Same zero-sigma guard as the grid mode (a 1x1 image has max_r == 0).
    sigma_r = torch.clamp(max_r / (2.0 * C), min=1e-6)

    dr = rr - r_centers  # (C, H, W)
    return -0.5 * (dr / sigma_r) ** 2


def generate_roi(C, H, W,
                 mode='grid',
                 max_angle_deg=60.0
                ):
    """Build ``C`` soft region-of-interest masks that sum to one per pixel.

    Args:
        C: Number of masks (size of the first output axis).
        H: Mask height in pixels.
        W: Mask width in pixels.
        mode: ``'grid'`` places axis-aligned Gaussians on a regular grid of at
            most 112 distinct centres (centres repeat cyclically when ``C`` is
            larger); ``'radial'`` places Gaussian rings around the image
            centre.
        max_angle_deg: Not used by either mode; kept so existing callers that
            pass it keep working.

    Returns:
        float32 tensor of shape (C, H, W), normalised over the first axis.

    Raises:
        ValueError: If ``mode`` is neither ``'grid'`` nor ``'radial'``.

    Side effects:
        Prints the largest deviation of the per-pixel sum from 1.
    """
    if mode == 'grid':
        log_weights = _grid_roi_log_weights(C, H, W)
    elif mode == 'radial':
        log_weights = _radial_roi_log_weights(C, H, W)
    else:
        raise ValueError("mode must be 'grid' or 'radial'")

    # softmax(log w) equals w / sum(w) but subtracts the per-pixel maximum
    # before exponentiating. Dividing exp(log w) by its sum produced 0/0 = NaN
    # wherever every Gaussian underflowed in float32 (e.g. the centre pixel in
    # radial mode once C exceeds roughly 103).
    weights = torch.softmax(log_weights, dim=0)

    dev = torch.max(torch.abs(weights.sum(dim=0) - 1.0)).item()
    print(f"Max deviation after normalization: {dev:.3e}")

    return weights

def get_center_bbox(img, threshold=0.0):
    """Return the bounding box of all elements of ``img`` above ``threshold``.

    The box is measured over the last two axes (rows, columns); any leading
    axes (channel, batch, ...) are pooled together. The original indexed
    coordinate columns 2 and 3, which only worked for 4-D input despite the
    documented (H, W) / (C, H, W) shapes.

    Args:
        img: torch.Tensor with at least two dimensions, e.g. (H, W),
            (C, H, W) or (B, C, H, W).
        threshold: Elements strictly greater than this value are included.

    Returns:
        Tuple ``(center_y, center_x, height, width)`` of Python ints, where
        the centre is the floor of the midpoint between the extreme indices.

    Note:
        If no element exceeds ``threshold`` there is nothing to reduce and
        the min/max calls fail.
    """
    mask = img > threshold

    # One row per selected element, one column per axis of `img`.
    coords = mask.nonzero(as_tuple=False)  # (N, img.ndim)
    rows = coords[:, -2]
    cols = coords[:, -1]

    y_min, y_max = rows.min().item(), rows.max().item()
    x_min, x_max = cols.min().item(), cols.max().item()

    height = y_max - y_min + 1
    width = x_max - x_min + 1
    center_y = (y_min + y_max) // 2
    center_x = (x_min + x_max) // 2

    return center_y, center_x, height, width
    
def crop_from_center_bboxes(img, bboxes):
    """Crop one window per batch element and stack the results.

    Args:
        img: torch.Tensor of shape (B, C, H, W).
        bboxes: Per-sample ``(center_y, center_x, height, width)``; each
            entry is converted with ``int``.

    Returns:
        torch.Tensor of shape (B, C, h, w). Windows start at
        ``center - size // 2`` (clamped to 0) and are clipped at the
        bottom/right image edge; clipped windows are zero-padded on the
        bottom/right. The target ``(h, w)`` is the height/width of the
        *last* bbox, so all boxes are expected to share one size.
    """
    _, _, img_h, img_w = img.shape
    windows = []
    for idx in range(img.shape[0]):
        center_y, center_x, height, width = (int(v) for v in bboxes[idx])
        top = max(center_y - height // 2, 0)
        left = max(center_x - width // 2, 0)
        bottom = min(top + height, img_h)
        right = min(left + width, img_w)
        windows.append(img[idx:idx + 1, :, top:bottom, left:right])  # keep batch dim

    # Bring every window to the size of the last bbox read in the loop above.
    padded = []
    for window in windows:
        missing_w = width - window.shape[-1]
        missing_h = height - window.shape[-2]
        padded.append(torch.nn.functional.pad(window, (0, missing_w, 0, missing_h)))
    return torch.cat(padded, dim=0)

def crop_from_center_bbox(img, bbox):
    """Crop the window described by ``bbox`` from the last two axes of ``img``.

    Args:
        img: torch.Tensor with at least two dimensions, e.g. (H, W),
            (C, H, W) or (B, C, H, W). The original unpacked ``img.shape``
            into four names and therefore rejected the documented 2-D/3-D
            inputs.
        bbox: ``(center_y, center_x, height, width)``.

    Returns:
        ``img[..., y1:y2, x1:x2]`` where the window starts at
        ``center - size // 2`` and is clipped to the image bounds (so the
        result can be smaller than ``height`` x ``width`` near an edge).
    """
    center_y, center_x, height, width = bbox
    img_h, img_w = img.shape[-2:]

    y1 = center_y - height // 2
    x1 = center_x - width // 2
    # The far edges are derived from the unclamped start, so a window that
    # overhangs the top/left edge is shortened rather than shifted.
    y2 = min(img_h, y1 + height)
    x2 = min(img_w, x1 + width)
    y1 = max(0, y1)
    x1 = max(0, x1)

    return img[..., y1:y2, x1:x2]
