import os
import sys

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from skimage import io
from skimage.transform import resize

# Absolute path of the "demos/emotion_reading_another_cnn" directory, two levels
# up from this file. specs() loads its checkpoint from here, and the directory is
# presumably where the `models` / `transforms` modules imported just below live.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'demos', 'emotion_reading_another_cnn'))
# Appended (not prepended) to sys.path so that `from models import *` and
# `import transforms` below can resolve against that directory.
sys.path.append(project_root)
from models import *
import transforms as transforms


def specs() -> dict:
    """Build the emotion classifier and describe its interface.

    Instantiates ``VGG('VGG19')``, loads the ``'net'`` state dict from
    ``FER2013_VGG19/PrivateTest_model.t7`` under ``project_root`` (mapped
    onto the CPU) and switches the network to evaluation mode.

    Returns:
        dict with keys ``'model'`` (the loaded network), ``'input'`` (the
        tuple ``(1, 48, 48)``), ``'output'`` (the tuple ``(1,)``) and
        ``'interpretation'`` (class index -> emotion name).
    """
    classifier = VGG('VGG19')
    weights_path = os.path.join(project_root, 'FER2013_VGG19', 'PrivateTest_model.t7')
    state = torch.load(weights_path, map_location=torch.device('cpu'))
    classifier.load_state_dict(state['net'])
    classifier.eval()

    # Index i of this tuple is the name of class i.
    emotion_names = (
        'anger',
        'disgust',
        'fear',
        'happiness',
        'sadness',
        'surprise',
        'neutral',
    )

    return {
        'model': classifier,
        'input': (1, 48, 48),
        'output': (1,),
        'interpretation': dict(enumerate(emotion_names)),
    }


def rgb2gray(rgb):
    """Collapse the colour axis of an RGB(A) array into one luminance value.

    Only the first three entries of the last axis are used, weighted with the
    ITU-R BT.601 luma coefficients (0.299, 0.587, 0.114); any further channels
    (e.g. alpha) are ignored. The returned array drops the last axis.
    """
    luma_weights = [0.299, 0.587, 0.114]
    color_channels = rgb[..., :3]
    return numpy.dot(color_channels, luma_weights)


def _crop_to_tensor(crop) -> torch.Tensor:
    """Convert one RGB PIL crop into a float32 (C, H, W) tensor scaled to [0, 1]."""
    # Mirrors torchvision's ToTensor() for uint8 RGB images. NOTE(review): the
    # FER2013_VGG19 checkpoint appears to come from an upstream project whose
    # training/inference pipelines use ToTensor (inputs in [0, 1]); the previous
    # code fed raw 0-255 values to the network. Confirm against the training
    # transforms if predictions look off.
    chw = numpy.array(crop).transpose((2, 0, 1))
    return torch.as_tensor(chw, dtype=torch.float32) / 255.0


def run(model, minput: torch.Tensor) -> int:
    '''
    Classify the facial expression of the face in ``minput``.

    The input should contain a single face.

    Args:
        model: network as returned by ``specs()['model']``. If ``None``, a
            model is loaded via ``specs()``.
        minput: image tensor laid out as (channels, height, width) with pixel
            values expected in 0-255 (they are cast to uint8); the first three
            channels are treated as RGB, any extra channels are ignored.

    Returns:
        Index of the predicted class; ``specs()['interpretation']`` maps it to
        an emotion name.
    '''
    transform_test = transforms.Compose([
        transforms.TenCrop(44),
        transforms.Lambda(lambda crops: torch.stack([_crop_to_tensor(crop) for crop in crops])),
        ])

    # (C, H, W) tensor -> (H, W, C) uint8 array -> 48x48 grayscale.
    # detach() so tensors that require grad can still be converted to numpy.
    pixels = minput.detach().cpu().numpy().astype(numpy.uint8).transpose((1, 2, 0))
    gray = rgb2gray(pixels)
    gray = resize(gray, (48, 48), mode='symmetric').astype(numpy.uint8)

    # TenCrop operates on PIL images; replicate the gray plane into 3 channels.
    img = Image.fromarray(numpy.stack((gray, gray, gray), axis=2))
    inputs = transform_test(img)  # stacked crops: (ncrops, c, h, w)
    ncrops = inputs.shape[0]

    # Use the model the caller passed in instead of reloading the checkpoint
    # from disk on every call.
    net = model if model is not None else specs()['model']

    # Feed the crops to whatever device the model's parameters live on.
    first_param = next(net.parameters(), None)
    if first_param is not None:
        inputs = inputs.to(first_param.device)

    # Inference only: torch.no_grad() replaces `Variable(..., volatile=True)`,
    # which no longer has any effect in current PyTorch.
    with torch.no_grad():
        outputs = net(inputs)

    outputs_avg = outputs.view(ncrops, -1).mean(0)  # average network outputs over the crops
    return int(outputs_avg.argmax().item())

if __name__ == '__main__':
    # Usage: python <this file> IMAGE_PATH
    # Fail fast, before the checkpoint is loaded, when no image path is given.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} IMAGE_PATH")
    print('Run for quick testing purpose...')
    plug = specs()
    print(plug)
    # convert('RGB') guarantees three channels, so grayscale/palette files no
    # longer crash on the (2, 0, 1) transpose; `with` closes the file handle.
    with Image.open(sys.argv[1]) as img:
        rgb = numpy.array(img.convert('RGB'))
    t = torch.as_tensor(rgb.transpose((2, 0, 1)), dtype=torch.float32)
    results = run(plug['model'], t)
    print(f"Class: {results}")
    print(f"Inter: {plug['interpretation'][results]}")
