import os
import sys
import tempfile

import cv2
import numpy
from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torch
import torchvision.utils as tu

# Directory of the baseline FER CNN demo. It provides the `model` module
# imported below and the models/ folder (trained weights, Haar cascade).
project_root = os.path.abspath(
    os.path.join(
        os.path.dirname(__file__),
        '..', '..',
        'demos', 'emotion_reading_baseline_cnn', 'PyTorch',
    )
)
# Make the demo's `model` module importable.
# NOTE(review): this appends rather than prepends, so any other module named
# `model` that is earlier on sys.path (or already in sys.modules) wins.
sys.path.append(project_root)
from model import *


def specs() -> dict:
    '''
    Build the plugin description for the baseline FER emotion CNN.

    A fresh ``Face_Emotion_CNN`` is created and the weights stored in
    ``<project_root>/models/FER_trained_model.pt`` are loaded into it. Every
    storage is deserialised onto the CPU. The network is then put in eval mode.

    Returns a dict with:
      'model'          -- the loaded network
      'input'          -- expected input shape, (1, 48, 48)
      'output'         -- output shape descriptor, (1,)
      'interpretation' -- mapping from class index to emotion label
    '''
    weights_file = os.path.join(project_root, 'models', 'FER_trained_model.pt')

    network = Face_Emotion_CNN()
    # The identity map_location keeps storages on the CPU, so checkpoints
    # saved from a GPU still load on CPU-only hosts.
    state = torch.load(weights_file, map_location=lambda storage, loc: storage)
    # strict=False: mismatched checkpoint keys are ignored instead of raising.
    network.load_state_dict(state, strict=False)
    network.eval()

    # Class index -> label; the tuple order defines the indices 0..6.
    emotion_names = (
        'neutral',
        'happiness',
        'surprise',
        'sadness',
        'anger',
        'disgust',
        'fear',
    )
    interpretation = dict(enumerate(emotion_names))

    return {
        'model': network,
        'input': (1, 48, 48),
        'output': (1,),
        'interpretation': interpretation,
    }


def run(model, minput: torch.Tensor) -> int:
    '''
    Classify the emotion of the face in ``minput``.

    The input should contain a single face.

    minput is CxHxW. Values are cast straight to uint8, so they are assumed
    to be on a 0-255 scale. C may be 1 (grayscale) or 3/4 (colour, which
    OpenCV interprets as BGR(A)).

    A Haar cascade locates the face. If no face is found, or the cascade
    cannot be used, a centred 48x48 crop is taken instead. The crop is
    converted to grayscale and resized to 48x48. It is then fed to ``model``
    on the CPU as a (1, 1, 48, 48) float tensor of raw 0-255 values.

    Returns the index of the highest-scoring class, i.e. a key of the
    'interpretation' dict returned by ``specs()``.
    '''
    # OpenCV works on HxWxC uint8 arrays.
    img = minput.cpu().numpy().transpose((1, 2, 0)).astype(numpy.uint8)

    if img.shape[2] == 1:
        # Already grayscale; cvtColor(BGR2GRAY) would reject one channel.
        img_gray = numpy.ascontiguousarray(img[:, :, 0])
    else:
        # NOTE(review): tensors built from PIL (as in __main__) are RGB, but
        # this conversion assumes BGR, so the R/B weights are swapped. Kept
        # as-is; confirm which channel order callers actually supply.
        img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    classifier_path = os.path.join(project_root, 'models', 'haarcascade_frontalface_default.xml')
    face_finder = cv2.CascadeClassifier(classifier_path)
    try:
        # Detection runs on the gray image. The cascade converts colour input
        # with BGR2GRAY internally anyway. With no detections [0] raises
        # IndexError. A cascade that failed to load raises cv2.error.
        # If several faces are found, the first one reported is used.
        x, y, w, h = face_finder.detectMultiScale(img_gray)[0]
    except (IndexError, cv2.error):
        # Best-effort fallback: a centred 48x48 crop, clamped to the image
        # origin when the image is smaller than 48 px in a dimension.
        img_h, img_w = img_gray.shape[:2]
        x = max(0, (img_w - 48) // 2)
        y = max(0, (img_h - 48) // 2)
        w = h = 48

    face = cv2.resize(img_gray[y:y + h, x:x + w], (48, 48))
    # (48, 48) -> (1, 1, 48, 48): a batch of one single-channel image.
    batch = torch.as_tensor(face, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        # Note: .cpu() on an nn.Module moves its parameters in place.
        log_ps = model.cpu()(batch)
        # exp is monotonic, so the top class is the same as on the raw output.
        ps = torch.exp(log_ps)
        _, top_class = ps.topk(1, dim=1)

    # .item() avoids NumPy's deprecated ndim>0 array -> scalar conversion.
    return int(top_class.item())

if __name__ == '__main__':
    # Quick manual check: python <this file> IMAGE
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} IMAGE')
    print('Run for quick testing purpose...')
    plug = specs()
    print(plug)
    # convert('RGB') normalises grayscale, palette and RGBA files to three
    # channels, so the HxWxC -> CxHxW transpose below always applies. The
    # context manager closes the file handle.
    with Image.open(sys.argv[1]) as img:
        pixels = numpy.array(img.convert('RGB'))
    t = torch.as_tensor(pixels.transpose((2, 0, 1)), dtype=torch.float32)
    results = run(plug['model'], t)
    print(results)
