import os
import sys
import tempfile
from typing import Union

import numpy
from PIL import Image
import torch
import torchvision.utils as tu

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(project_root)
from loaders.ravdess_stills import RAVDESSStillsDataset
classifier_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'models', 'ec'))
sys.path.append(classifier_root)
from classifier import EmotionClassifier
import run as ec

# Inference device: the default CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def specs(precompute_path: str = '/tmp/ravdess_stills_dataset_precompute.bin') -> dict:
    '''
    Load the latest emotion-classifier snapshot and describe how to use it.

    The RAVDESS stills dataset is instantiated only to obtain the number of
    classes (to size the classifier) and the emotion labels.

    :param precompute_path: path passed to ``RAVDESSStillsDataset``.
    :return: dict with keys
        ``'model'``          -- ``EmotionClassifier`` with the snapshot's
                                weights, moved to ``device`` and in eval mode;
        ``'interpretation'`` -- ``ds.emotions``, indexed by the class index
                                that ``run`` returns;
        ``'preprocessing'``  -- the ``preprocess`` function.
    :raises FileNotFoundError: if ``ec.last_snapshot_for()`` returns no path.
    '''
    ds = RAVDESSStillsDataset(precompute_path)

    model_path = ec.last_snapshot_for()
    if not model_path:
        # FileNotFoundError subclasses Exception, so existing handlers still catch it.
        raise FileNotFoundError(f'No EC classifier model found in {ec.output_home}')
    print(f"Emotion classifier will load from {model_path}")

    m = EmotionClassifier(ds.class_count())
    # The snapshot is a mapping whose 'model' entry holds the state dict.
    checkpoint = torch.load(model_path, map_location=device)
    m.load_state_dict(checkpoint['model'])
    m.to(device)  # Module.to moves parameters in place and returns the same module
    m.eval()

    return {
        'model': m,
        'interpretation': ds.emotions,
        'preprocessing': preprocess,
    }


def preprocess(t: torch.Tensor, size: tuple = (320, 180)) -> torch.Tensor:
    '''
    Resize an image tensor to ``size`` with PIL and return it on ``t``'s device.

    The tensor is copied to the CPU, converted to an 8-bit HWC array, resized
    by PIL, then converted back to a float32 CHW tensor and moved to the
    device ``t`` came from. Pixel values are not rescaled, so the output
    stays in the 0-255 range.

    TODO: the CPU/PIL round trip is slow; a pure-torch resize would avoid it,
    at the cost of slightly different resampling results.

    :param t: image tensor, assumed (3, H, W) RGB with values in 0-255. The
        transpose below requires rank 3, and the 'RGB' mode requires 3 channels.
    :param size: target (width, height), in PIL's ordering. Defaults to 320x180.
    :return: float32 tensor of shape (3, height, width) on ``t.device``.
    '''
    # detach() so tensors that require grad can still be converted to numpy.
    arr = t.detach().cpu().numpy()
    # Clip before the uint8 cast: out-of-range values would otherwise wrap
    # around or produce platform-dependent results.
    hwc = numpy.clip(arr, 0, 255).astype(numpy.uint8).transpose((1, 2, 0))
    img = Image.fromarray(hwc, 'RGB').resize(size)
    chw = numpy.array(img).transpose((2, 0, 1))
    return torch.as_tensor(chw, dtype=torch.float32).to(t.device)


def run(model, minput: Union[torch.Tensor, numpy.ndarray]) -> int:
    '''
    Classify a single image and return the index of the highest-scoring class.

    :param model: callable, such as the module in ``specs()['model']``, that
        accepts a batched float32 tensor on ``device``.
    :param minput: one image with values in 0-255. It is presumably (C, H, W)
        (TODO confirm against the model); a batch dimension of 1 is added here.
    :return: index of the maximum over all model outputs, as a Python int.
    '''
    with torch.no_grad():
        # Scale to [0, 1] and add a leading batch dimension before inference.
        batch = torch.as_tensor(minput / 255.0, dtype=torch.float32).unsqueeze(0).to(device)
        predictions = model(batch)
        # argmax without `dim` works on the flattened output; with a batch of
        # one this is the class index.
        class_idx = torch.argmax(predictions).item()
    return class_idx

if __name__ == '__main__':
    # Quick manual check: python <this file> path/to/image
    print('Run for quick testing purpose...')
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} IMAGE')
    # `with` closes the file handle; convert('RGB') guarantees 3 channels
    # for grayscale, RGBA and palette images as well.
    with Image.open(sys.argv[1]) as img:
        chw = numpy.array(img.convert('RGB')).transpose((2, 0, 1))
    s = specs()
    # Apply the spec's own preprocessing (resize) so any image size works.
    minput = s['preprocessing'](torch.as_tensor(chw))
    result = run(s['model'], minput)
    print(s['interpretation'][result])
