#!/usr/bin/env python3
import sys
# sys.path.append('/home/ubuntu/akchunya/Activation-Deactitvation/scripts')
# sys.path.append('/Users/akchunya/Projects/activation_deactivation_masking/scripts')
sys.path.append('scripts')
import platform
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image  # type: ignore
from rex_xai.input.input_data import Data
from rex_xai.responsibility.prediction import from_pytorch_tensor
import matplotlib.pyplot as plt
from scripts.archive.interpretable_resnet_spectral import InterpretableResNet
import numpy as np

# Two-class spectral ResNet used for explanation. The checkpoint path is
# relative, so it resolves against the current working directory.
model = InterpretableResNet()
# Deserialize onto the CPU first so a checkpoint saved from a CUDA (or MPS)
# session can still be loaded on a machine without that backend; the model
# is moved to the selected device further down.
model.load_state_dict(
    torch.load(
        'ECG-Spectral/Spectral/two_class_resnet/resnet_binary_invitro.pt',
        map_location="cpu",
    )
)
model.eval()

# Reference (unperturbed) input tensor: written by preprocess() and read by
# prediction_function() to build the explanation mask.
ACTUAL_INPUT = None


def _select_device() -> torch.device:
    """Return MPS on macOS or CUDA elsewhere, falling back to CPU.

    The preferred backend is only chosen when PyTorch reports it as
    available, so the script still runs (on CPU) on machines without it.
    """
    if platform.uname().system == "Darwin":
        if torch.backends.mps.is_available():
            return torch.device("mps")
    elif torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


device = _select_device()
model.to(device)


def preprocess(path, shape, device, mode) -> Data:
    """Load a spectrum saved with NumPy and wrap it in a ReX ``Data`` object.

    The array at ``path`` is converted to a float32 tensor with one leading
    dimension added; that tensor is what ``Data`` is constructed from. The
    object's ``.data`` is then replaced with a copy that has a second leading
    dimension and lives on ``device``. The same tensor is also stored in the
    module-level ``ACTUAL_INPUT``, which ``prediction_function`` compares
    mutants against.

    Args:
        path: File passed to ``np.load``.
        shape: Forwarded unchanged to ``Data``.
        device: Target device for the stored input tensor.
        mode: Ignored; ``Data`` is always built with ``mode='spectral'``.

    Returns:
        The constructed ``Data`` instance.
    """
    global ACTUAL_INPUT
    loaded = np.load(path)
    row = torch.from_numpy(loaded).float()
    row = row.unsqueeze(0)
    wrapped = Data(row, shape, device, mode='spectral')
    # Add one more leading dim and move to the device before handing it to ReX.
    wrapped.data = row.unsqueeze(0).to(device)  # type: ignore
    ACTUAL_INPUT = wrapped.data
    return wrapped


def prediction_function(mutants, masks_objects=None,
                        target=None, raw=False, binary_threshold=None):
    """Score a batch of mutants with the interpretable spectral ResNet.

    Rather than feeding ``mutants`` to the model directly, this builds a
    binary mask that is 1.0 where a mutant equals the reference input stored
    by ``preprocess`` and 0.0 where it differs. It then runs the model on the
    reference input in explanation mode with that mask.

    Args:
        mutants: Tensor of perturbed inputs, compared elementwise (with
            broadcasting) against ``ACTUAL_INPUT``.
        masks_objects: Unused in this function.
        target: Forwarded to ``from_pytorch_tensor``.
        raw: If True, return softmax probabilities over dim 1 instead of the
            ``from_pytorch_tensor`` result.
        binary_threshold: Unused in this function.

    Returns:
        ``F.softmax(logits, dim=1)`` when ``raw`` is True, otherwise
        ``from_pytorch_tensor(logits, target=target)``.

    Raises:
        RuntimeError: If ``preprocess`` has not yet set ``ACTUAL_INPUT``.
    """
    if ACTUAL_INPUT is None:
        # Fail loudly instead of with an opaque torch.where / model error.
        raise RuntimeError(
            "prediction_function called before preprocess(): "
            "no reference input (ACTUAL_INPUT) is set"
        )
    with torch.no_grad():
        # NOTE(review): a position whose mutated value happens to equal the
        # original value is treated as unmasked (1.0) here; confirm that the
        # masking value used by ReX cannot collide with real input values.
        binary_masks = torch.where(ACTUAL_INPUT != mutants, 0.0, 1.0)
        logits = model(ACTUAL_INPUT, explanation_mask=binary_masks,
                       explanation_mode=True)
        if raw:
            return F.softmax(logits, dim=1)
        return from_pytorch_tensor(logits, target=target)


def model_shape():
    """Return the model input shape as a fresh list: a symbolic batch dim "N", then 1, then 1356.

    NOTE(review): 1356 is presumably the spectrum length loaded by
    ``preprocess`` — confirm against the data files.
    """
    batch, channels, length = "N", 1, 1356
    return [batch, channels, length]
