#!/usr/bin/env python3
import platform
import sys
sys.path.append('scripts')

import torch
import torch.nn.functional as F
from PIL import Image  # type: ignore
from rex_xai.input.input_data import Data
from rex_xai.responsibility.prediction import from_pytorch_tensor
import matplotlib.pyplot as plt
from spectral_conv import ConvNet
import numpy as np
import torchvision.transforms as T

# Build the classifier and restore its trained weights.
model = ConvNet()
# map_location="cpu" makes the load independent of the device the
# checkpoint was saved from (e.g. a CUDA-saved file on a macOS/MPS host);
# the weights are moved to the selected device below.
model.load_state_dict(
    torch.load(
        'ECG-Spectral/Spectral/three_class_conv/threeclass_combined_DNA_model.pt',
        map_location="cpu",
    )
)
model.eval()

# Select the accelerator: Apple's MPS backend on macOS, otherwise the
# CUDA device with index 1 (hard-coded; requires at least two GPUs).
if platform.uname().system == "Darwin":
    device = torch.device("mps")
else:
    device = torch.device("cuda:1")
model.to(device)


def preprocess(path, shape, device, mode) -> Data:
    """Load a saved NumPy spectrum from *path* and wrap it in a ``Data``.

    The array is converted to float32 and given a leading batch axis before
    ``Data`` is built. Afterwards ``Data.data`` is replaced with a copy that
    has one more leading axis and lives on *device*.

    NOTE(review): *mode* is accepted but ignored; ``'spectral'`` is always
    passed to ``Data``. Presumably the file holds a 1-D spectrum, so the
    stored tensor would be (1, 1, L); confirm against ``model_shape``.
    """
    raw_array = np.load(path)
    batched = torch.from_numpy(raw_array).float().unsqueeze(0)
    wrapped = Data(batched, shape, device, mode='spectral')
    # Extra leading axis, then move the tensor onto the target device.
    wrapped.data = batched.unsqueeze(0).to(device)  # type: ignore
    return wrapped


def prediction_function(mutants, masks_objects=None,
                        target=None, raw=False, binary_threshold=None):
    """Run the module-level ``model`` on a batch of mutants without gradients.

    Args:
        mutants: tensor batch, moved to the module-level ``device`` first.
        masks_objects: unused (presumably required by the caller's callback
            signature).
        target: forwarded to ``from_pytorch_tensor`` when *raw* is false.
        raw: if true, return the softmax over dim 1 of the model output.
        binary_threshold: unused.

    Returns:
        The softmax tensor when *raw* is true, otherwise the result of
        ``from_pytorch_tensor`` applied to the raw model output.
    """
    with torch.no_grad():
        logits = model(mutants.to(device))
        if not raw:
            return from_pytorch_tensor(logits, target=target)
        return F.softmax(logits, dim=1)


def model_shape():
    """Return the model's input shape as a new list: batch placeholder, 1, 1356."""
    batch_axis, channel_axis, length_axis = "N", 1, 1356
    return [batch_axis, channel_axis, length_axis]

