#!/usr/bin/env python3
import platform


import timm
import torch as tt
import torch.nn.functional as F
from PIL import Image  # type: ignore
from rex_xai.input.input_data import Data
from rex_xai.responsibility.prediction import from_pytorch_tensor
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as T
import torch.nn as nn


# Backbone from timm. The tag 'in21k_ft_in1k' is timm's naming for ImageNet-21k
# pretraining fine-tuned on ImageNet-1k; pretrained weights are downloaded.
model = timm.create_model('tf_efficientnetv2_s.in21k_ft_in1k', pretrained=True)

# Replace the stock classifier with an MLP head:
# in_features -> 1024 -> 1024 -> 20 logits, ReLU between the linear layers.
# (20 outputs presumably correspond to the PASCAL VOC classes — TODO confirm.)
num_features = model.classifier.in_features
model.classifier = nn.Sequential(
    nn.Linear(num_features, 1024), nn.ReLU(),
    nn.Linear(1024, 1024), nn.ReLU(),
    nn.Linear(1024, 20),
)

# Load fine-tuned weights. This must run after the classifier head above has
# been replaced, so that state-dict keys and tensor shapes line up.
# Loading is best-effort: on any failure a warning is printed and the model
# keeps its timm-pretrained backbone with the randomly initialised head.
# The path is relative, so it resolves against the current working directory.
_checkpoint_path = "PASCAL-VOC/checkpoints-efn_v2/final_model.pth"
try:
    # NOTE(review): tt.load unpickles the file; only load trusted checkpoints.
    _checkpoint_obj = tt.load(_checkpoint_path, map_location="cpu")
    # Accept both a training checkpoint ({"model_state_dict": ..., ...}) and a
    # bare state dict saved via tt.save(model.state_dict(), path).
    if isinstance(_checkpoint_obj, dict) and "model_state_dict" in _checkpoint_obj:
        _state_dict = _checkpoint_obj["model_state_dict"]
    else:
        _state_dict = _checkpoint_obj
    model.load_state_dict(_state_dict)
except Exception as _load_exc:
    print(f"Warning: failed to load checkpoint from {_checkpoint_path}: {_load_exc}")
    
# Inference-time preprocessing attached to the model object: bicubic
# (antialiased) resize to 300x300, conversion to a float tensor, then
# per-channel normalisation with mean 0.5 / std 0.5.
model.transforms = T.Compose(
    [
        T.Resize((300, 300), interpolation=T.InterpolationMode.BICUBIC, antialias=True),
        # Same size as the resize output, so this crop leaves the image unchanged.
        T.CenterCrop((300, 300)),
        T.ToTensor(),
        T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
# Switch layers such as BatchNorm/Dropout to inference behaviour.
model.eval()
# Choose the compute device: MPS on macOS, CUDA elsewhere. Each accelerator is
# used only if the torch build reports it as available. Otherwise the code falls
# back to the CPU, so importing this module does not fail on machines without
# that backend.
if platform.uname().system == "Darwin" and tt.backends.mps.is_available():
    device = tt.device("mps")
elif tt.cuda.is_available():
    device = tt.device("cuda")
else:
    device = tt.device("cpu")
model.to(device)


def preprocess(path, shape, device, mode=None) -> Data:
    """Load an image from disk and wrap it in a ``Data`` object.

    Args:
        path: Image file path; opened with PIL and converted to RGB.
        shape: Passed through to ``Data`` as its shape argument.
        device: Passed to ``Data``. The preprocessed tensor is also moved to it.
        mode: Ignored. ``Data`` is always built with ``mode='RGB'`` because
            the image is converted to RGB before use.

    Returns:
        A ``Data`` instance whose ``.data`` is ``model.transforms(img)`` with a
        leading batch dimension of size 1, on ``device``.
    """
    # Use a context manager so the file handle is released deterministically.
    # convert() returns a new, fully loaded image that stays valid after close.
    with Image.open(path) as src:
        img = src.convert("RGB")
    data = Data(img, shape, device, mode='RGB')
    data.data = model.transforms(img).unsqueeze(0).to(device)  # type: ignore
    return data


def prediction_function(mutants, masks_objects=None,
                        target=None, raw=False, binary_threshold=None):
    """Run the classifier on a batch of inputs without tracking gradients.

    Args:
        mutants: Batch tensor. It is moved to the module-level ``device``
            before being passed to ``model``.
        masks_objects: Unused here (presumably kept for the caller's expected
            signature).
        target: Forwarded to ``from_pytorch_tensor`` when ``raw`` is false.
        raw: When true, return softmax probabilities over dim 1 instead of
            the ``from_pytorch_tensor`` result.
        binary_threshold: Unused here.

    Returns:
        ``F.softmax(logits, dim=1)`` if ``raw`` is true, otherwise
        ``from_pytorch_tensor(logits, target=target)``.
    """
    with tt.no_grad():
        logits = model(mutants.to(device))
        if not raw:
            return from_pytorch_tensor(logits, target=target)
        return F.softmax(logits, dim=1)


def model_shape():
    """Return the expected model input shape as ``["N", C, H, W]``.

    ``"N"`` is a placeholder for the batch dimension. The remaining entries
    match the 3-channel, 300x300 tensors produced by ``model.transforms``.
    A new list is returned on every call.
    """
    channels, height, width = 3, 300, 300
    return ["N", channels, height, width]

# model_shape = ("N", 3, 232, 232)


def _batch_to_numpy(batch_tensor):
    """Convert a PyTorch tensor to a NumPy array; pass other objects through.

    Each step (``cpu``, ``detach``, ``numpy``) runs only if the object has that
    method. NumPy arrays have none of them, so they pass straight through
    instead of raising ``AttributeError``.
    """
    if hasattr(batch_tensor, 'cpu'):
        batch_tensor = batch_tensor.cpu()
    if hasattr(batch_tensor, 'detach'):
        batch_tensor = batch_tensor.detach()
    if hasattr(batch_tensor, 'numpy'):
        batch_tensor = batch_tensor.numpy()
    return batch_tensor


def _image_to_display_layout(img):
    """Rearrange one image array into a layout ``imshow`` accepts.

    The CHW test applies to a 3-D array whose first axis has size 1, 3 or 4
    and whose other two axes are both larger than that and larger than 4.
    Such an array is treated as (C, H, W) and transposed to (H, W, C). A
    trailing singleton channel is then squeezed, so grayscale images come out
    as (H, W). Anything else is returned unchanged.
    """
    if img.ndim == 3:
        c, h, w = img.shape
        # Ambiguous shapes such as (3, 3, 224) or (3, 224, 3) are left as-is.
        if c in (1, 3, 4) and h > c and w > c and h > 4 and w > 4:
            img = np.transpose(img, (1, 2, 0))
    if img.ndim == 3 and img.shape[2] == 1:
        img = img.squeeze(axis=2)
    return img


def plot_batch_images(batch_tensor, figsize=(12, 12), cmap='viridis', title=None):
    """
    Plots a batch of images in a single figure using subplots.

    Args:
        batch_tensor: A PyTorch tensor or NumPy array of images.
                      Expected shapes:
                      - (N, H, W) for grayscale images
                      - (N, H, W, C) for color images (C=3 for RGB, C=4 for RGBA)
                      - (N, C, H, W) for color or grayscale (PyTorch convention)
                      N is the number of images in the batch.
        figsize (tuple): The figure size for plt.subplots.
        cmap (str): Colormap to use for grayscale images. Ignored for RGB/RGBA.
        title (str, optional): Overall title for the figure.

    Returns:
        ``(fig, axes_list)``, where ``axes_list`` is a flat NumPy array of Axes.
        For an empty batch, returns ``(fig, ax)`` with a single Axes.

    Raises:
        TypeError: If the input is neither a NumPy array nor convertible via
            ``.numpy()``.
        ValueError: If the array has fewer than 3 dimensions.
    """
    batch_tensor = _batch_to_numpy(batch_tensor)

    if not isinstance(batch_tensor, np.ndarray):
        raise TypeError(f"batch_tensor must be a NumPy array or a PyTorch tensor, got {type(batch_tensor)}")

    if batch_tensor.ndim < 3:  # Expecting at least (N, H, W)
        raise ValueError(f"batch_tensor must have at least 3 dimensions (N, H, W), got shape {batch_tensor.shape}")

    num_images = batch_tensor.shape[0]

    if num_images == 0:
        print("No images to plot in the batch.")
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, "No images in batch", ha='center', va='center', fontsize=12)
        ax.axis('off')
        if title:
            fig.suptitle(title)
        plt.show()
        return fig, ax

    # Near-square grid: ceil(sqrt(N)) columns, then as many rows as needed.
    grid_cols = int(np.ceil(np.sqrt(num_images)))
    grid_rows = int(np.ceil(num_images / float(grid_cols)))

    fig, axes = plt.subplots(grid_rows, grid_cols, figsize=figsize)
    # plt.subplots returns a bare Axes for a 1x1 grid; normalise to a flat array.
    axes_list = np.atleast_1d(axes).flatten()

    # zip stops after the last image; any extra axes are hidden below.
    for i, (ax, raw_img) in enumerate(zip(axes_list, batch_tensor)):
        img = _image_to_display_layout(raw_img.copy())  # work on a copy

        # Float pixel values are passed through without min-max rescaling.
        # matplotlib clips RGB(A) float data outside [0, 1].
        if img.ndim == 2:  # grayscale: apply the caller's colormap
            ax.imshow(img, cmap=cmap, interpolation='nearest')
        elif img.ndim == 3:  # RGB/RGBA: imshow needs no colormap
            ax.imshow(img, cmap=None, interpolation='nearest')
        else:
            ax.text(0.5, 0.5, f"Unsupported\nimg shape\n{img.shape}", ha='center', va='center')

        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(f"Image {i+1}", fontsize=10)

    # Hide any unused subplots in the last row.
    for unused_ax in axes_list[num_images:]:
        unused_ax.axis('off')

    if title:
        fig.suptitle(title, fontsize=16)

    try:
        if title:
            # Leave headroom at the top for the suptitle.
            fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        else:
            fig.tight_layout()
    except ValueError:
        # tight_layout can fail for some subplot arrangements or very small
        # figures; the plot is still usable, just less neatly spaced.
        print("Warning: tight_layout failed. Layout may not be optimal.")

    plt.show()
    return fig, axes_list
