import matplotlib.pyplot as plt
import numpy as np
import math
from mpl_toolkits.mplot3d import Axes3D
import torch
import torch.nn as nn
import sys
import random
from pathlib import Path
import argparse
import glob
from scipy import signal
import json
import os.path
import cv2
from PIL import Image
import subprocess
from math import pi, sqrt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from matplotlib.cbook import get_sample_data
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from scipy.signal import find_peaks
import scipy.ndimage as ndimage
import shutil
from scipy.spatial import ConvexHull, convex_hull_plot_2d
from sklearn.decomposition import PCA
import warnings
# Install a blanket "ignore" filter: every warning raised through the
# `warnings` module after this point is suppressed for the whole process.
warnings.filterwarnings("ignore") 
# NOTE: from here on the module-level name `logging` refers to the
# transformers logging utilities, not the standard-library `logging` module.
from transformers import logging
# Limit transformers' log output to error-level messages.
logging.set_verbosity_error()


def Plot3D(P):
    """Display one skeleton pose as a 3D scatter plot with numbered joints and bones.

    Args:
        P: Array indexed as ``P[joint, axis]`` with columns 0, 1, 2 used as
            x, y, z. Rows 0..12 are labelled with their index and rows 0..11
            are connected by the hard-coded bone list below.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    figure = plt.figure()
    # The axes is created detached and then attached to the figure explicitly.
    axes = Axes3D(figure, auto_add_to_figure=False)
    figure.add_axes(axes)

    axes.scatter(P[:, 0], P[:, 1], P[:, 2], c='r', marker='o')
    for idx in range(13):
        axes.text(P[idx, 0], P[idx, 1], P[idx, 2], '%s' % (str(idx)),
                  size=10, zorder=1, color='k')

    bones = [
        [9, 11], [7, 9], [1, 3], [3, 5],  # bones on the left
        [0, 2], [2, 4], [8, 10], [6, 8],  # bones on the right
        [4, 5], [10, 11],                 # bones on the torso
    ]
    for bone in bones:
        segment = P[bone]  # the two end points of this bone
        axes.plot(segment[:, 0], segment[:, 1], segment[:, 2], 'k-')

    axes.set_xlabel('x')
    axes.set_ylabel('y')
    axes.set_zlabel('z')
    axes.view_init(90, -90)

    # Centre a 2-unit-wide viewing cube on the mean joint position.
    center = np.mean(P, axis=0)
    axes.set_xlim3d(center[0]-1, center[0]+1)
    axes.set_ylim3d(center[1]-1, center[1]+1)
    # The z limits are passed high-to-low, which reverses the z axis.
    axes.set_zlim3d(center[2]+1, center[2]-1)
    plt.show()

def norm(v):
    """Divide ``v`` by its norm as computed by ``np.linalg.norm``.

    Args:
        v: Array-like of numbers. For a 1-D input this yields the unit vector;
            for a 2-D input ``np.linalg.norm`` returns a single scalar, so the
            whole array is scaled by that one value.

    Returns:
        ``v`` divided by ``np.linalg.norm(v)``. There is no guard for a zero
        norm; callers that care (e.g. ``anglefromvec``) test the result for NaN.
    """
    magnitude = np.linalg.norm(v)
    return v / magnitude

def anglefromvec(v1,v2):
    """Return the angle between two vectors in degrees.

    Both vectors are normalised with ``norm`` and the cosine is clipped to
    [-1, 1] before ``arccos`` so rounding error cannot push it out of range.

    Args:
        v1: First vector.
        v2: Second vector.

    Returns:
        The angle in degrees, or 0 when the computed angle is NaN (which is
        what a zero-length input produces after normalisation).
    """
    unit_a = norm(v1)
    unit_b = norm(v2)
    cosine = np.clip(np.dot(unit_a, unit_b), -1.0, 1.0)
    degrees = np.arccos(cosine) / np.pi * 180
    if np.isnan(degrees):
        return 0
    return degrees

def smooth(y, box_pts):
    """Smooth a 1-D sequence with a moving-average (box) filter.

    Args:
        y: 1-D sequence of samples.
        box_pts: Width of the box window in samples.

    Returns:
        Array of the same length as ``y`` (``mode='same'`` convolution); the
        samples near both ends are averaged against implicit zeros.
    """
    window = np.ones(box_pts)
    window /= box_pts  # uniform weights that sum to one
    return np.convolve(y, window, mode='same')

def proj(x,y):
    """Return the scalar projection of ``x`` onto the direction of ``y``.

    Args:
        x: Vector being projected.
        y: Vector defining the direction; it does not need to be unit length.

    Returns:
        ``dot(x, y) / |y|``, or 0 when that value is NaN.
    """
    length = np.dot(x, y) / np.linalg.norm(y)
    if np.isnan(length):
        return 0
    return length

def avgpose2d(pose2d, a,b):
    """Return the midpoint between joints ``a`` and ``b`` for every frame.

    Args:
        pose2d: Array indexed as ``pose2d[frame, joint, coordinate]``.
        a: Index of the first joint.
        b: Index of the second joint.

    Returns:
        Array indexed as ``[frame, coordinate]`` holding the mean of the two
        joint positions.
    """
    first_joint = pose2d[:, a, :]
    second_joint = pose2d[:, b, :]
    return (first_joint + second_joint) / 2.0

def calc_elements(skel):
    """Compute the per-frame movement descriptors of one skeleton sequence.

    Two virtual joints are appended to the 13 input joints: the mid-shoulder
    point (mean of joints 10 and 11) becomes joint 13 and the mid-hip point
    (mean of joints 4 and 5) becomes joint 14.

    Args:
        skel (array): Joint positions indexed as ``skel[frame, joint, xyz]``
            with 13 joints per frame.

    Returns:
        torch.Tensor: The descriptors stacked row-wise, one column per frame,
        in this order: space effort (2 rows), weight effort (1), time effort
        (1), flow effort (1), shape shaping (3), shape directional (one per
        end joint, 13), shape flow (1), body angles (one per vector pair, 14).
    """
    shoulder_center = avgpose2d(skel, 10, 11)[:, None, :]  # becomes joint 13
    hip_center = avgpose2d(skel, 4, 5)[:, None, :]  # becomes joint 14
    skel = np.concatenate([skel, shoulder_center, hip_center], axis=1)
    num_frame = len(skel)

    end_joints = [6, 8, 10, 7, 9, 11, 0, 2, 4, 1, 3, 5, 12]
    # Each entry is a pair of joint-index pairs whose difference vectors are
    # compared in `body`; the torso entry uses the midpoint of joints 2 and 3.
    vectors = [
        [[7, 9], [11, 9]], [[9, 11], [10, 11]],        # left lower/upper arm
        [[6, 8], [10, 8]], [[8, 10], [11, 10]],        # right lower/upper arm
        [[1, 3], [5, 3]], [[3, 5], [4, 5]],            # left lower/upper leg
        [[0, 2], [4, 2]], [[2, 4], [5, 4]],            # right lower/upper leg
        [[12, 13], [14, 13]], [[13, 14], [[2, 3], 14]],  # head; torso
        [[11, 13], [14, 13]], [[10, 13], [14, 13]],    # left shoulder; right shoulder
        [[5, 14], [13, 14]], [[4, 14], [13, 14]],      # left hip; right hip
    ]

    # Per-frame torso frame: x from joint 5 to joint 4, y from mid-hip to
    # mid-shoulder, z as the normalised cross product y × x.
    torso_axis_x = np.array([norm(frame[4] - frame[5]) for frame in skel])
    torso_axis_y = np.array([norm(frame[13] - frame[14]) for frame in skel])
    torso_axis_z = np.array([
        norm(np.cross(axis_y, axis_x))
        for axis_y, axis_x in zip(torso_axis_y, torso_axis_x)
    ])
    origins = [torso_axis_x, torso_axis_y, torso_axis_z]

    # Frame-to-frame differences; the first difference is duplicated so both
    # arrays keep one entry per frame.
    vel = np.diff(skel, axis=0)
    vel = np.concatenate([vel[0].reshape(1, 15, 3), vel], axis=0)
    acl = np.diff(vel, axis=0)
    acl = np.concatenate([acl[0].reshape(1, 15, 3), acl], axis=0)

    feature_rows = [
        space_eff(origins, num_frame),
        weight_eff(vel[:, end_joints]),
        time_eff(acl[:, end_joints]),
        flow_eff(acl[:, end_joints]),
        shape_shaping(skel),
        shape_directional(vel, skel, end_joints, num_frame),
        shape_flow(skel),
        body(skel, vectors),
    ]
    return torch.tensor(np.vstack(feature_rows))


def space_eff(origins, num):
    """Compute the two smoothed space-effort signals from the torso axes.

    Args:
        origins: Sequence of per-frame axis arrays; entry 0 and entry 2 are
            used, each indexed as ``[frame, xyz]``.
        num: Number of frames (rows) to evaluate.

    Returns:
        Array with two rows: the projection of ``origins[0]`` onto (-1, 0, 0)
        and the projection of ``origins[2]`` onto (0, 0, 1), each smoothed
        with a 5-sample box filter.
    """
    ref_neg_x = np.tile(np.array([-1, 0, 0]), (num, 1))
    ref_pos_z = np.tile(np.array([ 0, 0, 1]), (num, 1))

    raw_x = [proj(axis, ref) for axis, ref in zip(origins[0], ref_neg_x)]
    raw_y = [proj(axis, ref) for axis, ref in zip(origins[2], ref_pos_z)]

    return np.array([smooth(raw_x, 5), smooth(raw_y, 5)])

def weight_eff(vel):
    """Return the per-frame sum of squared velocity components.

    Args:
        vel: Array indexed as ``vel[frame, joint, xyz]``.

    Returns:
        1-D array with one value per frame, summed over joints and coordinates.
    """
    squared = vel * vel
    return np.sum(squared, axis=(1, 2))

def time_eff(acl):
    """Return the per-frame sum of absolute acceleration components.

    Args:
        acl: Array indexed as ``acl[frame, joint, xyz]``.

    Returns:
        1-D array with one value per frame, summed over joints and coordinates.
    """
    magnitudes = abs(acl)
    return np.sum(magnitudes, axis=(1, 2))

def flow_eff(acl):
    """Return the per-frame sum of absolute changes in acceleration.

    Args:
        acl: Array indexed as ``acl[frame, joint, xyz]`` with at least two
            frames.

    Returns:
        1-D array with one value per frame. The first frame-to-frame
        difference is duplicated at the front so the output has as many
        entries as ``acl`` has frames.
    """
    delta = np.diff(acl, axis=0)
    leading = delta[0].reshape(1, -1, 3)  # repeat the first difference as padding
    padded = np.concatenate([leading, delta], axis=0)
    return np.sum(abs(padded), axis=(1, 2))

def shape_shaping(skel):
    """Return per-frame convex-hull sizes of the pose projected onto three planes.

    For every frame the joints are projected onto the coordinate pairs (0, 1),
    (0, 2) and (1, 2) and the 2-D convex hull's ``volume`` attribute is taken.
    A frame whose projected coordinates are all one identical value gets 0
    instead of a hull.

    Args:
        skel: Array indexed as ``skel[frame, joint, xyz]``.

    Returns:
        Array with three rows (one per plane, in the order above) and one
        column per frame.
    """
    plane_axes = ((0, 1), (0, 2), (1, 2))  # vertical, horizontal, sagittal
    per_plane = []
    for axes in plane_axes:
        projected = skel[:, :, axes]
        sizes = []
        for points in projected:
            if len(set(points.reshape(-1))) != 1:
                sizes.append(ConvexHull(points).volume)
            else:
                sizes.append(0)
        per_plane.append(sizes)
    return np.vstack(per_plane)

def shape_directional(vel, skel, end_joints, num):
    """Return the per-frame path curvature of each end joint.

    For every joint in ``end_joints`` a 2-component PCA is fitted to that
    joint's velocities. The joint positions are projected onto the two
    directions ``mean + component`` and the curvature of the resulting planar
    curve is computed from finite-difference derivatives.

    Args:
        vel: Velocities indexed as ``vel[frame, joint, xyz]``.
        skel: Positions indexed as ``skel[frame, joint, xyz]``.
        end_joints: Joint indices to evaluate.
        num: Number of frames.

    Returns:
        Array with one row per entry of ``end_joints`` and one column per
        frame; NaN curvature values are replaced by 0.
    """
    rows = []
    for joint in end_joints:
        model = PCA(n_components=2)
        model.fit(vel[:, joint])
        center = model.mean_
        principal = model.components_
        plane_axes = norm([center + principal[0], center + principal[1]])

        joint_path = skel[:, joint]
        coord_u = [proj(point, axis)
                   for point, axis in zip(joint_path, np.tile(plane_axes[0], (num, 1)))]
        coord_v = [proj(point, axis)
                   for point, axis in zip(joint_path, np.tile(plane_axes[1], (num, 1)))]

        # First and second finite-difference derivatives of both coordinates.
        dx_dt = np.gradient(coord_u)
        dy_dt = np.gradient(coord_v)
        d2x_dt2 = np.gradient(dx_dt)
        d2y_dt2 = np.gradient(dy_dt)

        # Planar curvature: |x'' y' - x' y''| / (x'^2 + y'^2)^(3/2).
        speed_sq = dx_dt * dx_dt + dy_dt * dy_dt
        rows.append(np.abs(d2x_dt2 * dy_dt - dx_dt * d2y_dt2) / speed_sq**1.5)

    curvature = np.vstack(rows)
    # A stationary joint gives 0/0; report that as zero curvature.
    curvature[np.isnan(curvature)] = 0
    return curvature
    
    # import matplotlib.pyplot as plt
    # from mpl_toolkits.mplot3d import Axes3D
    # fig = plt.figure()
    # ax = fig.add_subplot(111, projection='3d')
    # ax.plot([0, 0], [0, 0], [0, 1], color='green')
    # ax.plot([0, 0], [0, 1], [1, 1], color='green')
    # ax.plot([0, 0], [1, 1], [1, 2], color='green')
    # ax.plot([0, 0], [1, 2], [2, 1], color='green')
    # ax.set_xlabel('X axis')
    # ax.set_ylabel('Y axis')
    # ax.set_zlabel('Z axis')

    # a = np.array([[0,0,0],[0,0,1],[0,1,1],[0,1,2],[0,2,1]])
    # dev_a = np.array([[0,0,1],[0,1,0],[0,0,1],[0,1,-1]])
    # pca = PCA(n_components=2)
    # pca.fit(dev_a)
    # c = pca.components_
    # m = pca.mean_
    # ax.plot([m[0], m[0]+c[0][0]], [m[1], m[1]+c[0][1]], [m[2], m[2]+c[0][2]])
    # ax.plot([m[0], m[0]+c[1][0]], [m[1], m[1]+c[1][1]], [m[2], m[2]+c[1][2]])

def shape_flow(skel):
    """Return the per-frame volume of the convex hull of all joints.

    Args:
        skel: Array indexed as ``skel[frame, joint, xyz]``.

    Returns:
        1-D array with one hull volume per frame. A frame whose coordinates
        are all one identical value gets 0 instead of a hull.
    """
    volumes = []
    for frame in skel:
        if len(set(frame.reshape(-1))) != 1:
            volumes.append(ConvexHull(frame).volume)
        else:
            volumes.append(0)
    return np.array(volumes)

def body(skel, vectors):
    """Return the per-frame angle (in degrees) for every vector pair.

    Args:
        skel: Array indexed as ``skel[frame, joint, xyz]``.
        vectors: Sequence of ``(v1, v2)`` entries. ``v1`` is ``[a, b]`` and
            describes the vector ``skel[a] - skel[b]``. ``v2`` is either
            ``[c, d]`` (vector ``skel[c] - skel[d]``) or ``[[c1, c2], d]``,
            where the start point is the midpoint of joints ``c1`` and ``c2``.

    Returns:
        Array with one row per entry of ``vectors`` and one column per frame.
    """
    angle_rows = []
    for first, second in vectors:
        start = second[0]
        if isinstance(start, int):
            reference = skel[:, start] - skel[:, second[1]]
        else:
            # The start point is the midpoint of two joints.
            reference = 0.5*(skel[:, start[0]] + skel[:, start[1]]) - skel[:, second[1]]
        limb = skel[:, first[0]] - skel[:, first[1]]
        angles = [anglefromvec(ref_vec, limb_vec)
                  for ref_vec, limb_vec in zip(reference, limb)]
        angle_rows.append(np.array(angles))
    return np.vstack(angle_rows)
    

# Batch driver: convert every skeleton file under <root_dir>/Skeleton into a
# tensor of movement descriptors saved under <root_dir>/Elements.
root_dir = '../Data'
skeleton_dir = os.path.join(root_dir, 'Skeleton')
files = os.listdir(skeleton_dir)
save_path = os.path.join(root_dir, 'Elements')
# torch.save does not create missing directories, so make sure the output
# folder exists before the first file is written.
os.makedirs(save_path, exist_ok=True)

for file in sorted(files):

    # NOTE(review): allow_pickle=True unpickles arbitrary Python objects from
    # the file; only run this on skeleton files from a trusted source.
    # The stored object is unwrapped with .item() and read as a mapping of
    # persons, each of which is itself a mapping whose values are joint data.
    skel_raw = np.load(os.path.join(skeleton_dir, file), allow_pickle=True).item()
    elements_list = []
    for person in skel_raw.values():
        # Flatten this person's values into (frames, 13 joints, xyz).
        # Assumes the mapping iterates in frame order -- TODO confirm.
        skel = np.array(list(person.values())).reshape(-1, 13, 3)
        elements = calc_elements(skel)
        elements_list.append(elements)

    # Strip the extension whatever its length (the old [:-4] slice assumed a
    # three-character extension such as ".npy").
    vid = os.path.splitext(os.path.basename(file))[0]
    # Average the descriptors over all persons in the file. torch.stack needs
    # every person's tensor to have the same shape.
    torch.save(torch.mean(torch.stack(elements_list), dim=0), os.path.join(save_path, vid))
