import os
import sys
import tempfile

import cv2
import torch
import torchvision.transforms as transforms
import torchvision.utils as tu
from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(project_root)
from loaders.shapes import ShapesDataset
classifier_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'models', 'shape'))
sys.path.append(classifier_root)
from classifier import ShapeClassifier
import run as shapes

# Target device for the model and its inputs: the current CUDA device when
# CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def specs(dataset_path: str = './shape_dataset.pickle') -> dict:
    """Build the shape classifier together with its label interpretation.

    The dataset is loaded only to obtain the number of classes and the label
    map. The classifier weights come from the latest snapshot reported by
    ``shapes.last_snapshot_for()``.

    Args:
        dataset_path: Path to the pickled shapes dataset. The default,
            ``./shape_dataset.pickle``, is relative and therefore resolved
            against the current working directory.

    Returns:
        A dict with two entries:

        * ``'model'``: the ``ShapeClassifier``, with the checkpoint's
          ``'model'`` state dict loaded, moved to ``device`` and put in eval
          mode.
        * ``'interpretation'``: the dataset's ``label_map``.

    Raises:
        FileNotFoundError: If no classifier snapshot is available. This is a
            subclass of ``Exception``, so callers catching ``Exception`` still
            work.
    """
    ds = ShapesDataset(dataset_path)

    model_path = shapes.last_snapshot_for()
    if not model_path:
        raise FileNotFoundError(f'No shape classifier model found in {shapes.output_home}')
    print(f"Shape classifier will load from {model_path}")

    model = ShapeClassifier(ds.class_count())
    # map_location remaps stored tensors onto `device`, so a checkpoint saved
    # on a GPU can still be loaded on a CPU-only host.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    # nn.Module.to() moves parameters in place; eval() switches layers such as
    # dropout/batch-norm to inference behaviour.
    model.to(device)
    model.eval()

    return {
        'model': model,
        'interpretation': ds.label_map,
    }

def run(model, minput: torch.Tensor) -> int:
    """Classify a single sample and return the index of the highest score.

    A batch dimension of size 1 is added here, so pass ONE un-batched sample
    (for example the output of ``transforms.ToTensor()``), not a batch.

    Args:
        model: Callable mapping a batched tensor to class scores, e.g. the
            ``'model'`` entry returned by ``specs()``.
        minput: One sample without a batch dimension. It is moved to
            ``device`` before inference.

    Returns:
        The position of the maximum value in the flattened model output, as a
        Python ``int``. This equals the class index assuming the model returns
        shape ``(1, num_classes)`` for a batch of one (TODO confirm against
        ``ShapeClassifier``).
    """
    with torch.no_grad():
        batch = minput.unsqueeze(0).to(device)
        predictions = model(batch)
        class_idx = torch.argmax(predictions).item()
    return class_idx

if __name__ == '__main__':
    # Quick manual check: classify one image with the latest classifier snapshot.
    # Usage: python <this script> <model-arg> <image-path>
    # sys.argv[1] is a plain string and cannot act as the model, so it is not
    # used; the model is resolved by specs(). The image path stays at argv[2]
    # so existing invocations keep working.
    print('Run for quick testing purpose...')
    if len(sys.argv) < 3:
        sys.exit(f'usage: {sys.argv[0]} <model-arg (unused)> <image-path>')

    model = specs()['model']
    to_tensor = transforms.Compose([transforms.ToTensor()])
    # NOTE(review): the image is not converted (e.g. to RGB); confirm its mode
    # matches the channel count ShapeClassifier expects.
    with Image.open(sys.argv[2]) as img:
        # run() adds the batch dimension itself, so pass the un-batched tensor.
        minput = to_tensor(img)
    results = run(model, minput)
    print(results)
