#!/usr/bin/env python3
import platform


import torch 
import torch.nn.functional as F
from PIL import Image  # type: ignore
from rex_xai.input.input_data import Data
from rex_xai.responsibility.prediction import from_pytorch_tensor
import matplotlib.pyplot as plt
from torchvision.models import ResNet50_Weights, get_model
import numpy as np
import torchvision.transforms as T
import torch.nn as nn


# Build a ResNet-50 backbone (torchvision "DEFAULT" pretrained weights) and
# replace its classifier with a two-layer head: in_features -> 512 -> ReLU -> 257.
# NOTE(review): 257 outputs presumably match the Caltech-256 label set the
# checkpoint below was trained on -- confirm.
model = get_model('resnet50', weights="DEFAULT")
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 512),
    nn.ReLU(),
    nn.Linear(512, 257),
)

# Load fine-tuned weights. This must run after the head is replaced so the
# state-dict keys and shapes line up with the new ``fc``. Loading is
# best-effort: on failure only a warning is printed, and the model may be left
# with the torchvision default backbone weights and an untrained head, so
# predictions will not be meaningful.
_checkpoint_path = (
    "CalTech-256/checkpoints_two_layer/best_model_epoch_24_acc_82.84.pth"
)
try:
    _checkpoint_obj = torch.load(_checkpoint_path, map_location="cpu")
    model.load_state_dict(_checkpoint_obj["model_state_dict"])
except Exception as _load_exc:
    print(f"Warning: failed to load checkpoint from {_checkpoint_path}: {_load_exc}")

# Input preprocessing. The spatial output is 232x232 (center-crop to 224, then
# resize up to 232), which is the size ``model_shape`` advertises.
# NOTE(review): cropping before resizing is the reverse of the usual
# torchvision ImageNet recipe; presumably it mirrors the checkpoint's training
# pipeline -- confirm before changing.
model.transforms = T.Compose([
    T.CenterCrop((224, 224)),
    T.Resize((232, 232)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
model.eval()

# Pick the accelerator the script was written for (MPS on macOS, CUDA
# elsewhere), but fall back to the CPU when that backend is unavailable
# instead of crashing when the model is moved.
if platform.uname().system == "Darwin" and torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model.to(device)


def preprocess(path, shape, device, mode = None) -> Data:
    """Load an image and wrap it in a ReX ``Data`` object for explanation.

    Args:
        path: Path of the image file to load.
        shape: Passed through unchanged to the ``Data`` constructor.
        device: Torch device the model-input tensor is moved to.
        mode: Accepted for interface compatibility but ignored; the image is
            always converted to RGB and ``Data`` is built with ``mode='RGB'``.

    Returns:
        A ``Data`` whose ``data`` attribute is ``model.transforms`` applied to
        the image, with a leading batch dimension added and moved to
        ``device``. Its ``input`` attribute is the RGB PIL image after only the
        center-crop (224x224) and resize (232x232) steps, i.e. without tensor
        conversion or normalisation.
    """
    # Decode the file once. The ``with`` closes the underlying file handle
    # deterministically; ``convert`` yields a new, fully loaded image, so
    # ``img`` stays usable after the file is closed.
    with Image.open(path) as src:
        img = src.convert("RGB")

    # Un-normalised copy with the same crop/resize geometry as the first two
    # steps of ``model.transforms`` -- keep the two in sync. It is built before
    # ``img`` is handed to anything else, so it comes from the pristine decode,
    # just as re-reading the file would.
    cropped = T.functional.center_crop(img, (224, 224))
    original = T.functional.resize(cropped, (232, 232))

    data = Data(img, shape, device, mode='RGB')
    data.data = model.transforms(img).unsqueeze(0).to(device)  # type: ignore
    data.input = original
    return data


@torch.no_grad()
def prediction_function(mutants, masks_objects=None,
                        target=None, raw=False, binary_threshold=None):
    """Run the model on a batch of (mutated) inputs with gradients disabled.

    Args:
        mutants: Batch tensor fed to the model; it is first moved to the
            module-level ``device``.
        masks_objects: Unused; accepted for interface compatibility.
        target: Forwarded to ``from_pytorch_tensor`` when ``raw`` is false.
        raw: When true, return the softmax over dim 1 of the model output
            instead of the ``from_pytorch_tensor`` conversion.
        binary_threshold: Unused; accepted for interface compatibility.

    Returns:
        ``F.softmax(output, dim=1)`` when ``raw`` is true, otherwise whatever
        ``from_pytorch_tensor(output, target=target)`` returns.
    """
    logits = model(mutants.to(device))
    if not raw:
        return from_pytorch_tensor(logits, target=target)
    return F.softmax(logits, dim=1)


# Model input shape: a symbolic batch dimension "N", 3 colour channels and a
# 232x232 spatial size. That size is what the CenterCrop(224) -> Resize(232)
# pipeline in ``model.transforms`` produces.
# NOTE(review): exposed as a plain tuple attribute. If the ReX loader in use
# calls ``model_shape()`` instead of reading the attribute, turn this into a
# function returning the same values.
model_shape = ("N", 3, 232, 232)

