from eval.unconstrained.models.stgcn import STGCN
import pandas as pd
import os.path as osp
import os
import datetime

import torch

from torch.utils.data import DataLoader
import numpy as np
import sys as _sys
from eval.a2m.action2motion.fid import calculate_fid
from eval.a2m.action2motion.diversity import calculate_diversity
from eval.unconstrained.metrics.kid import calculate_kid
from eval.unconstrained.metrics.precision_recall import precision_and_recall
from matplotlib import pyplot as plt

TEST = False


def initialize_model(device, modelpath):
    num_classes = 12
    model = STGCN(in_channels=3,
                  num_class=num_classes,
                  graph_args={"layout": 'openpose', "strategy": "spatial"},
                  edge_importance_weighting=True,
                  device=device)
    model = model.to(device)
    state_dict = torch.load(modelpath, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model

def calculate_activation_statistics(activations):
    activations = activations.cpu().detach().numpy()
    mu = np.mean(activations, axis=0)
    sigma = np.cov(activations, rowvar=False)
    return mu, sigma


def compute_features(model, iterator, device):
    activations = []
    predictions = []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            batch_for_model = {}
            batch_for_model['x'] = batch.to(device).float()
            model(batch_for_model)
            activations.append(batch_for_model['features'])
            predictions.append(batch_for_model['yhat'])
            # labels.append(batch_for_model['y'])
        activations = torch.cat(activations, dim=0)
        predictions = torch.cat(predictions, dim=0)
    return activations, predictions


def evaluate_unconstrained_metrics(generated_motions, device, fast):

    act_rec_model_path = './assets/actionrecognition/humanact12_gru_modi_struct.pth.tar'
    dataset_path = './dataset/HumanAct12Poses/humanact12_unconstrained_modi_struct.npy'

    # initialize model
    act_rec_model = initialize_model(device, act_rec_model_path)

    generated_motions -= generated_motions[:, 8:9, :, :]  # locate root joint of all frames at origin

    iterator_generated = DataLoader(generated_motions, batch_size=64, shuffle=False, num_workers=8)

    # compute features of generated motions
    generated_features, generated_predictions = compute_features(act_rec_model, iterator_generated, device=device)
    generated_stats = calculate_activation_statistics(generated_features)


    # dataset motions
    motion_data_raw = np.load(dataset_path, allow_pickle=True)
    motion_data = motion_data_raw[:, :15]  # data has 16 joints for back compitability with older formats
    motion_data -= motion_data[:, 8:9, :, :]  # locate root joint of all frames at origin
    iterator_dataset = DataLoader(motion_data, batch_size=64, shuffle=False, num_workers=8)

    # compute features of dataset motions
    dataset_features, dataset_predictions = compute_features(act_rec_model, iterator_dataset, device=device)
    real_stats = calculate_activation_statistics(dataset_features)

    print("evaluation resutls:\n")

    fid = calculate_fid(generated_stats, real_stats)
    print(f"FID score: {fid}\n")

    print("calculating KID...")
    kid = calculate_kid(dataset_features.cpu(), generated_features.cpu())
    (m, s) = kid
    print('KID : %.3f (%.3f)\n' % (m, s))

    dataset_diversity = calculate_diversity(dataset_features)
    generated_diversity = calculate_diversity(generated_features)
    print(f"Diversity of generated motions: {generated_diversity}")
    print(f"Diversity of dataset motions: {dataset_diversity}\n")

    if fast:
        print("Skipping precision-recall calculation\n")
        precision = recall = None
    else:
        print("calculating precision recall...")
        precision, recall = precision_and_recall(generated_features, dataset_features)
        print(f"precision: {precision}")
        print(f"recall: {recall}\n")

    metrics = {'fid': fid, 'kid': kid[0], 'diversity_gen': generated_diversity.cpu().item(), 'diversity_gt':  dataset_diversity.cpu().item(),
                 'precision': precision, 'recall':recall}
    return metrics

