import matplotlib.pyplot as plt
import numpy as np
import math
from mpl_toolkits.mplot3d import Axes3D
import torch
import torch.nn as nn
import sys
import random
from pathlib import Path
import argparse
import glob
from scipy import signal
import json
import os.path
import cv2
from PIL import Image
import subprocess
from math import pi, sqrt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from matplotlib.cbook import get_sample_data
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from scipy.signal import find_peaks
import scipy.ndimage as ndimage
import shutil

def Plot3D(P):
    """Show a 3-D scatter plot of one skeleton frame with its bones drawn.

    Args:
        P: array indexed as ``P[joint, xyz]``. Every joint is scattered and
           labelled with its index; joints 0-11 are joined by the bone list
           below, and joint 12 (the head) is drawn unconnected.
    """
    fig = plt.figure()
    # add_subplot(projection='3d') works across matplotlib versions; the
    # Axes3D(fig, auto_add_to_figure=False) form was deprecated and the keyword
    # has been removed in recent matplotlib releases.
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(P[:, 0], P[:, 1], P[:, 2], c='r', marker='o')
    for i in range(len(P)):
        ax.text(P[i, 0], P[i, 1], P[i, 2], '%s' % (str(i)), size=10, zorder=1, color='k')

    left = [[9, 11], [7, 9], [1, 3], [3, 5]]      # bones on the left
    right = [[0, 2], [2, 4], [8, 10], [6, 8]]     # bones on the right
    torso = [[4, 5], [10, 11]]                    # hip and shoulder bars
    for p in left + right + torso:
        ax.plot(P[p][:, 0], P[p][:, 1], P[p][:, 2], 'k-')

    ax.set_xlabel('x'); ax.set_ylabel('y'); ax.set_zlabel('z')
    ax.view_init(90, -90)
    # Fixed 2-unit cube around the mean joint position; z limits are given in
    # reverse order, so the z axis is drawn inverted.
    m = np.mean(P, axis=0)
    ax.set_xlim3d(m[0] - 1, m[0] + 1)
    ax.set_ylim3d(m[1] - 1, m[1] + 1)
    ax.set_zlim3d(m[2] + 1, m[2] - 1)
    plt.show()


def norm(v):
    """Scale ``v`` to unit Euclidean length (a zero vector yields NaNs)."""
    magnitude = np.linalg.norm(v)
    return v / magnitude

def anglefromvec(v1, v2):
    """Return the angle between ``v1`` and ``v2`` in degrees.

    Both inputs are normalised with ``norm``. The cosine is clipped into
    [-1, 1] so rounding cannot push ``arccos`` out of its domain. When the
    result is NaN (e.g. a zero-length input made normalisation divide by
    zero), 0 is returned instead.
    """
    unit_a = norm(v1)
    unit_b = norm(v2)
    cosine = np.clip(np.dot(unit_a, unit_b), -1.0, 1.0)
    degrees = np.arccos(cosine) / np.pi * 180
    if np.isnan(degrees):
        return 0
    return degrees

def smooth(y, box_pts):
    """Moving-average filter of width ``box_pts``; output has ``len(y)`` samples.

    ``box_pts == 1`` is an identity filter that returns a float copy of ``y``.
    """
    kernel = np.full(box_pts, 1.0 / box_pts)
    return np.convolve(y, kernel, mode='same')

def proj(x, y):
    """Scalar projection of ``x`` onto the direction of ``y`` (NaN/inf if ``y`` is zero)."""
    length = np.linalg.norm(y)
    return np.dot(x, y) / length

def draw_text(img, text, uv_top_left, fontScale=0.5, thickness=1, 
              fontFace=cv2.FONT_HERSHEY_SIMPLEX, line_spacing=1.5):
    """
    Draw multi-line ``text`` onto ``img`` in place, over translucent white panels.

    Two 500-px-wide panels of height ``int(22 * 24 * line_spacing)`` are
    lightened first: one whose top-left corner is ``uv_top_left`` and one at
    x=1400 sharing the same top edge. Lines are then written top-down from
    ``uv_top_left``, switching colour every four lines between (0, 0, 255) and
    (255, 0, 0) (red/blue if ``img`` is BGR). After the 24th line the cursor
    jumps to (1400, 50), so further lines go into the second column.

    Args:
        img: image array, modified in place.
        text: newline-separated string.
        uv_top_left: (x, y) pixel position of the first line's top-left corner.
        fontScale, thickness, fontFace: forwarded to cv2.getTextSize / cv2.putText.
        line_spacing: vertical advance per line, as a multiple of the text height.
    """
    cursor = np.array(uv_top_left, dtype=float)
    palette = [(0, 0, 255), (255, 0, 0)]

    panel_top = int(cursor[1])
    panel_w = 500
    panel_h = int(22 * 24 * line_spacing)

    def _whiten(left):
        # 50/50 blend of the panel region with white (plus gamma 1.0).
        region = img[panel_top:panel_top + panel_h, left:left + panel_w]
        white = np.ones(region.shape, dtype=np.uint8) * 255
        img[panel_top:panel_top + panel_h, left:left + panel_w] = cv2.addWeighted(
            region, 0.5, white, 0.5, 1.0)

    for left in (int(cursor[0]), 1400):
        _whiten(left)

    for idx, line in enumerate(text.splitlines()):
        (_, line_h), _ = cv2.getTextSize(text=line, fontFace=fontFace,
                                         fontScale=fontScale, thickness=thickness)
        origin = tuple((cursor + [0, line_h]).astype(int))
        cv2.putText(img, text=line, org=origin, fontFace=fontFace, fontScale=fontScale,
                    color=palette[(idx // 4) % 2], thickness=thickness,
                    lineType=cv2.LINE_4)
        cursor += [0, line_h * line_spacing]
        if idx == 23:
            # Column break: continue in the right-hand panel.
            cursor = np.array((1400, 50), dtype=float)

def write2video(file, strings):
    """Write frames of a video, each annotated with its movement labels, as PNGs.

    The input video is ``root_dir + '/RGB_VIDEO_v2/' + file[:-4] + '_color.avi'``,
    where ``root_dir`` is the module-level global. Frame ``i`` gets
    ``strings[i]`` drawn on it via ``draw_text`` and is saved as
    ``./outputs_movement/frames/%05d.png``. The output directory is emptied
    first. Processing stops at the end of the video or of ``strings``,
    whichever comes first.

    Args:
        file: skeleton file name; its last four characters (e.g. ``.npy``)
            are stripped to build the video name.
        strings: per-frame, newline-separated label text.

    Raises:
        OSError: if a frame cannot be written.
    """
    outpath = './outputs_movement/frames/'
    # Always (re)create the directory: cv2.imwrite does not raise when the
    # target directory is missing, it just returns False.
    if os.path.isdir(outpath):
        shutil.rmtree(outpath)
    os.makedirs(outpath, exist_ok=True)

    cap = cv2.VideoCapture(root_dir + '/RGB_VIDEO_v2/' + file[:-4] + '_color.avi')
    try:
        counter = 0
        # Bounded by len(strings) so a video longer than the label sequence
        # does not raise IndexError.
        while counter < len(strings):
            ret, frame = cap.read()
            if not ret:
                break

            draw_text(frame, strings[counter], (50, 50), fontScale=1,
                      thickness=2, fontFace=cv2.FONT_HERSHEY_SIMPLEX, line_spacing=1.5)

            target = os.path.join(outpath, '%05d.png' % counter)
            if not cv2.imwrite(target, frame):
                raise OSError('could not write frame to %s' % target)
            counter += 1
    finally:
        # Release the capture even if drawing/writing fails.
        cap.release()
    cv2.destroyAllWindows()

def write2videoSkel(file, strings, skel_seq=None):
    """Save, per video frame, a figure with the 3-D skeleton beside the labelled frame.

    The input video is ``root_dir + '/RGB_VIDEO_v2/' + file[:-4] + '_color.avi'``
    (``root_dir`` is the module-level global). Figures are saved as
    ``./outputs_movement/frames_ske/%05d.png``; the directory is emptied first.
    Processing stops at the end of the video, the labels, or the skeleton
    sequence, whichever comes first.

    Args:
        file: skeleton file name; its last four characters are stripped to
            build the video name.
        strings: per-frame, newline-separated label text.
        skel_seq: sequence of per-frame joint arrays ``[joint, xyz]``. Defaults
            to the module-level ``skel`` set by the script loop, which is what
            this function always used before the parameter existed.
    """
    if skel_seq is None:
        skel_seq = skel

    # 17-joint topology (the bone list this function originally hard-coded).
    pairs_17 = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], [5, 6], [0, 7], [7, 8],
                [8, 9], [9, 10], [8, 11], [11, 12], [12, 13],
                [8, 14], [14, 15], [15, 16]]
    # 13-joint topology matching Plot3D: limbs plus hip and shoulder bars.
    pairs_13 = [[9, 11], [7, 9], [1, 3], [3, 5], [0, 2], [2, 4], [8, 10], [6, 8],
                [4, 5], [10, 11]]

    outpath = './outputs_movement/frames_ske/'
    # Always (re)create the directory; savefig fails if it is missing.
    if os.path.isdir(outpath):
        shutil.rmtree(outpath)
    os.makedirs(outpath, exist_ok=True)

    cap = cv2.VideoCapture(root_dir + '/RGB_VIDEO_v2/' + file[:-4] + '_color.avi')
    try:
        counter = 0
        limit = min(len(strings), len(skel_seq))
        while counter < limit:
            ret, frame = cap.read()
            if not ret:
                break

            draw_text(frame, strings[counter], (50, 50), fontScale=1,
                      thickness=2, fontFace=cv2.FONT_HERSHEY_SIMPLEX, line_spacing=1.5)

            P = np.asarray(skel_seq[counter])
            n_joints = len(P)
            pairs = pairs_17 if n_joints >= 17 else pairs_13
            pairs = [p for p in pairs if max(p) < n_joints]

            # Left panel: 3-D skeleton.
            fig = plt.figure(figsize=(30, 15))
            ax = fig.add_subplot(1, 2, 1, projection='3d')
            ax.scatter(P[:, 0], P[:, 1], P[:, 2], c='r', marker='o')
            for j in range(n_joints):
                ax.text(P[j, 0], P[j, 1], P[j, 2], '%s' % (str(j)), size=10, zorder=1, color='k')
            for p in pairs:
                ax.plot(P[p][:, 0], P[p][:, 1], P[p][:, 2], 'k-')
            ax.set_xlabel('x'); ax.set_ylabel('y'); ax.set_zlabel('z')
            ax.view_init(90, -90)
            m = np.mean(P, axis=0)
            ax.set_xlim3d(m[0] - 1, m[0] + 1)
            ax.set_ylim3d(m[1] - 1, m[1] + 1)
            ax.set_zlim3d(m[2] + 1, m[2] - 1)  # reversed limits: inverted z axis

            # Right panel: the annotated frame. OpenCV frames are BGR while
            # matplotlib expects RGB, so convert before imshow.
            ax_img = fig.add_subplot(1, 2, 2)
            ax_img.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

            fig.savefig(os.path.join(outpath, '%05d.png' % counter), bbox_inches='tight')
            plt.close(fig)
            counter += 1
    finally:
        cap.release()
    cv2.destroyAllWindows()

def denoise_curve(curve, pct=0.7, *, plot=False):
    """Low-pass filter ``curve`` with a cutoff picked from its energy spectrum.

    The cutoff is the FFT bin, within the first half of the spectrum, at which
    the cumulative spectral energy is closest to ``pct`` of the half-spectrum
    total. The sampling rate is set to ``len(curve)``, so a bin index equals
    its frequency. A first-order Butterworth low-pass is then applied causally
    with ``scipy.signal.sosfilt``.

    Args:
        curve: 1-D sequence of samples.
        pct: target cumulative-energy fraction. This argument used to be
            ignored in favour of a hard-coded 0.7; the default is now 0.7, so
            calls relying on the default behave exactly as before.
        plot: if True, plot spectra and signals before/after filtering
            (debugging aid).

    Returns:
        The filtered curve, or ``curve`` itself when no usable cutoff exists:
        too short, zero energy, or a cutoff at bin 0.
    """
    spectrum = np.fft.fft(curve)
    power = np.abs(spectrum) ** 2
    half = power[:len(power) // 2]
    total = half.sum()
    if half.size == 0 or total == 0:
        return curve

    fn_idx = int(np.argmin(np.abs(np.cumsum(half) / total - pct)))
    if fn_idx == 0:
        return curve

    # Low-pass (not band-pass); fn_idx < len(curve)/2 keeps Wn below Nyquist.
    sos = signal.butter(1, fn_idx, 'lowpass', fs=len(curve), output='sos')
    filtered = signal.sosfilt(sos, curve)

    if plot:
        fig, axs = plt.subplots(2, 1)
        axs[0].plot(np.abs(spectrum))
        axs[0].plot(np.abs(np.fft.fft(filtered)))
        axs[1].plot(curve)
        axs[1].plot(filtered)

    return filtered
    


def cumsum(value, length, pct=0.9):
    """Index where the running sum of ``value[:length]`` is closest to ``pct`` of its total.

    Args:
        value: 1-D sequence. ``find_peaks_by_area`` passes run areas sorted by
            decreasing magnitude, so they are positive.
        length: number of leading entries to consider.
        pct: target cumulative fraction. This argument used to be ignored in
            favour of a hard-coded 0.9; the default is now 0.9, so calls
            relying on the default behave exactly as before.

    Returns:
        An int in ``[0, length)``; 0 when ``length`` is 0, the slice is empty,
        or its total is zero.
    """
    if length == 0:
        return 0
    head = np.asarray(value[:length], dtype=float)
    total = head.sum()
    if total == 0:
        return 0
    return int(np.argmin(np.abs(np.cumsum(head) / total - pct)))

def find_peaks_by_area(curve, limb_l, pos_bound=0.9, neg_bound=0.9):
    """Select the dominant same-sign runs of ``curve`` by area.

    ``curve`` is split into maximal runs of strictly positive or strictly
    negative samples (zeros separate runs), and each run's area is its sample
    sum. Only runs with area above ``limb_l / 4`` (positive side) or below
    ``-limb_l / 4`` (negative side) are eligible. On each side, the
    largest-magnitude eligible runs are kept, up to and including the one at
    which their cumulative area is closest to ``pos_bound`` / ``neg_bound``
    of the eligible total (see ``cumsum``). If no run on a side is eligible,
    nothing is selected for that side.

    Args:
        curve: 1-D sequence of samples; NaNs are treated as zeros.
        limb_l: limb length used to derive the area threshold.
        pos_bound, neg_bound: cumulative-area fractions for each side.

    Returns:
        List of ``[start, end)`` index pairs: positive runs first (largest
        first), then negative runs (most negative first).
    """
    # A NaN is neither == 0, > 0 nor < 0 and would stall the scanner forever;
    # treat it as "no motion".
    curve = np.asarray(curve, dtype=float)
    curve = np.where(np.isnan(curve), 0.0, curve)
    n = len(curve)

    areas, spans = [], []
    j = 0
    while j < n:
        while j < n and curve[j] == 0:
            j += 1
        if j == n:
            break
        start = j
        if curve[j] > 0:
            while j < n and curve[j] > 0:
                j += 1
        else:
            while j < n and curve[j] < 0:
                j += 1
        areas.append(curve[start:j].sum())
        spans.append([start, j])

    areas = np.array(areas)
    order = list(np.argsort(areas)[::-1])  # run indices by decreasing area
    threshold = limb_l / 4

    selected = []
    n_pos = int(np.sum(areas > threshold))
    if n_pos > 0:
        k = cumsum(np.sort(areas)[::-1], n_pos, pct=pos_bound)
        selected += order[:k + 1]
    n_neg = int(np.sum(areas < -threshold))
    if n_neg > 0:
        k = cumsum(-np.sort(areas), n_neg, pct=neg_bound)
        selected += order[::-1][:k + 1]
    return [spans[i] for i in selected]

def cal_vel(skel, num_frame, vec1, vec2):
    """Per-frame limb vector and its angle (degrees) to the parent vector.

    Despite the name, this returns positions and angles rather than
    velocities; ``num_frame`` is accepted but unused.

    Args:
        skel: array indexed ``skel[frame, joint, xyz]``.
        vec1: ``[a, b]``; the limb vector is ``joint a - joint b``.
        vec2: ``[c, d]`` giving the parent vector ``joint c - joint d``; when
            ``c`` is itself a pair ``[c0, c1]``, the midpoint of joints c0 and
            c1 is used instead of joint c.

    Returns:
        ``(loc, ang)``: the limb vectors and, per frame, ``anglefromvec(parent, limb)``.
    """
    limb = skel[:, vec1[0]] - skel[:, vec1[1]]
    head, tail = vec2[0], vec2[1]
    if isinstance(head, int):
        parent = skel[:, head] - skel[:, tail]
    else:
        parent = 0.5 * (skel[:, head[0]] + skel[:, head[1]]) - skel[:, tail]
    angles = np.array([anglefromvec(p, c) for p, c in zip(parent, limb)])
    return limb, angles

def dics2label(dics):
    """Invert a nested ``{group: {code: label}}`` mapping into ``{label: code}``.

    Groups and codes are visited in insertion order, and a label that appears
    more than once keeps the code from its last occurrence (e.g. the shared
    blank label ``' '``).
    """
    return {label: code for mapping in dics.values() for code, label in mapping.items()}

def avgpose2d(pose2d, a, b):
    """Midpoint of joints ``a`` and ``b`` for every frame of ``pose2d[frame, joint, coord]``.

    Works for any coordinate dimension (this script passes 3-D skeletons).
    """
    first = pose2d[:, a, :]
    second = pose2d[:, b, :]
    return (first + second) / 2.0

def move_rcg(skel):
    ''' Recognise per-frame movements of all body parts.

        Args:
            skel (array): 3D skeleton sequence of one person, indexed as
                ``skel[frame, joint, xyz]`` with 13 joints (the script below
                reshapes to ``(-1, 13, 3)``).

        Returns:
            strings (list of str): one entry per frame; each entry joins, with
                newlines, the label of every (body part, movement) channel that
                ``limb_rcg`` produced for that frame.
            whole_dicts (dict): maps every label text to its code (0 or 1 for
                the two directions, 2 for the blank label ' ').
    '''
    # Append two virtual joints: 13 = mid-shoulder (joints 10/11),
    # 14 = mid-hip (joints 4/5).
    midshoulder = avgpose2d(skel, 10, 11)[:, None, :]
    midhip = avgpose2d(skel, 4, 5)[:, None, :]
    skel = np.concatenate([skel, midshoulder, midhip], axis=1)
    num_frame = len(skel)

    # (target limb, parent limb) joint pairs; a parent given as [[a, b], c]
    # uses the midpoint of a and b.
    vectors = [[[7, 9], [11, 9]], [[9, 11], [10, 11]],    # left lower/upper arm
               [[6, 8], [10, 8]], [[8, 10], [11, 10]],    # right lower/upper arm
               [[1, 3], [5, 3]], [[3, 5], [4, 5]],        # left lower/upper leg
               [[0, 2], [4, 2]], [[2, 4], [5, 4]],        # right lower/upper leg
               [[12, 13], [14, 13]], [[13, 14], [[2, 3], 14]],  # head; torso
               [[11, 13], [14, 13]], [[10, 13], [14, 13]],  # left shoulder; right shoulder
               [[5, 14], [13, 14]], [[4, 14], [13, 14]]]    # left hip; right hip
    prefixes = ['left lower arm ', 'left upper arm ', 'right lower arm ', 'right upper arm ',
                'left lower leg ', 'left upper leg ', 'right lower leg ', 'right upper leg ',
                'head ', 'torso ', 'left shoulder ', 'right shoulder ', 'left hip ', 'right hip ']

    # Per-frame body-centred unit axes: x from joint 5 to joint 4 (across the
    # hips), y from mid-hip to mid-shoulder, z = y cross x.
    torso_axis_x = np.array([norm(v) for v in skel[:, 4] - skel[:, 5]])
    torso_axis_y = np.array([norm(v) for v in skel[:, 13] - skel[:, 14]])
    torso_axis_z = np.array([norm(np.cross(y, x)) for y, x in zip(torso_axis_y, torso_axis_x)])
    origins = [torso_axis_x, torso_axis_y, torso_axis_z]

    # Reference frame for limb lengths: skip frames whose coordinates sum to
    # exactly 0 (presumably missing detections - TODO confirm) plus a fixed
    # offset, falling back to frame 0 for short clips.
    # NOTE(review): every such frame is counted, not only leading ones.
    n_static = sum(1 for frame in skel if np.sum(frame) == 0)
    offset = 20
    pivot = n_static + offset
    piv_skel = skel[pivot] if pivot < num_frame else skel[0]

    channels = []
    whole_dicts = {}
    for prefix, (limb, parent) in zip(prefixes, vectors):
        # Insertion order (ud, lr, fb, ang) fixes the channel order in the output.
        dics = {
            'ud': {0: prefix + 'move up', 1: prefix + 'move down', 2: ' '},
            'lr': {0: prefix + 'move right', 1: prefix + 'move left', 2: ' '},
            'fb': {0: prefix + 'move forward', 1: prefix + 'move backward', 2: ' '},
            'ang': {0: prefix + 'extension', 1: prefix + 'flexion', 2: ' '},
        }
        whole_dicts.update(dics2label(dics))
        limb_l = np.sqrt(np.sum((piv_skel[limb[0]] - piv_skel[limb[1]]) ** 2))
        channels.extend(limb_rcg(origins, skel, num_frame, limb, parent, dics, limb_l))

    # Transpose channel-major label rows into one newline-joined string per frame.
    # (For visual inspection, these strings can be passed to write2video /
    # write2videoSkel and the frames encoded with ffmpeg.)
    strings = ["\n".join(frame_labels) for frame_labels in zip(*channels)]
    return strings, whole_dicts

def limb_rcg(origins, skel, num_frame, vec1, vec2, dics, limb_l):
    ''' Recognise the movements of one body part, frame by frame.

        Args:
            origins (list): three per-frame unit-axis arrays; they are used
                for the 'fb', 'ud' and 'lr' projections respectively.
            skel (array): skeleton sequence ``skel[frame, joint, xyz]``.
            num_frame (int): number of frames (forwarded to ``cal_vel``).
            vec1 (list): joint pair of the target limb.
            vec2 (list): joint pair of the parent limb connected to the target.
            dics (dict): per movement kind ('ud', 'lr', 'fb', 'ang'), a map
                from code (0, 1, 2) to label text.
            limb_l (float): limb length, used for the area threshold in
                ``find_peaks_by_area``.

        Returns:
            list of label lists, one per movement kind that is not skipped for
            this body part; each holds one label per frame.
    '''
    loc, ang = cal_vel(skel, num_frame, vec1, vec2)

    # Frame-to-frame differences; the first difference is repeated so every
    # series keeps one value per frame. A single-frame clip has no motion.
    if len(loc) > 1:
        vel = loc[1:] - loc[:-1]
        vel = np.vstack([vel[0], vel])
        ang_vel = ang[1:] - ang[:-1]
        ang_vel = np.hstack([ang_vel[0], ang_vel])
    else:
        vel = np.zeros_like(loc, dtype=float)
        ang_vel = np.zeros(len(ang))

    def _project(axes):
        # Per-frame scalar projection of the velocity onto a body axis;
        # NaNs (from degenerate axes) count as no motion.
        series = smooth(list(map(proj, vel, axes)), 1)
        series[np.isnan(series)] = 0
        return series

    # NOTE(review): move_rcg builds origins[0] from the hip-to-hip vector (a
    # lateral axis), yet it feeds 'fb', while origins[2] (the cross product)
    # feeds 'lr'. This looks swapped - confirm the intended axis convention
    # before relying on fb/lr labels.
    vels = {
        'fb': _project(origins[0]),
        'ud': _project(origins[1]),
        'lr': _project(origins[2]),
        'ang': ang_vel,
    }

    # Movement kinds not reported for particular body parts, matched by
    # substring of the part's label text.
    skip_rules = (
        (('torso',), ('ud',)),
        (('head',), ('fb', 'ang')),
        (('left shoulder', 'right shoulder', 'left hip', 'right hip'), ('lr', 'ang')),
    )

    strings_list = []
    for k, dic in dics.items():
        label = dic[0]
        if any(k in kinds and any(part in label for part in parts)
               for parts, kinds in skip_rules):
            continue
        curve = vels[k]
        peaks_idx = find_peaks_by_area(curve, limb_l, pos_bound=0.9, neg_bound=0.9)
        strings = [dic[2]] * len(curve)
        for s, e in peaks_idx:
            # Every sample in a run has the same sign, so the midpoint decides direction.
            peak_v = curve[(s + e) // 2]
            strings[s:e] = [dic[0] if peak_v > 0 else dic[1]] * (e - s)
        strings_list.append(strings)
    return strings_list




# Module-level paths; write2video / write2videoSkel read root_dir as a global.
root_dir = '../Data'
save_path = root_dir + '/Movements'

if __name__ == '__main__':
    # Runs only as a script, not on import. Names assigned here (e.g. skel)
    # stay module globals, which write2videoSkel falls back to.
    os.makedirs(save_path, exist_ok=True)
    # Only .npy files can be loaded with .item() below, and file[:-4] strips that suffix.
    files = sorted(f for f in os.listdir(root_dir + '/Skeleton') if f.endswith('.npy'))

    for file in files:
        # NOTE: allow_pickle=True unpickles arbitrary objects - only run on trusted data.
        skel_raw = np.load(os.path.join(root_dir, 'Skeleton', file), allow_pickle=True).item()
        labels_list = []

        for person in skel_raw.values():
            skel = np.array(list(person.values())).reshape(-1, 13, 3)
            strings, whole_dicts = move_rcg(skel)
            # Rows = movement channels, columns = frames.
            labels = torch.tensor([[whole_dicts[k] for k in s.split('\n')] for s in strings]).T
            labels_list.append(labels)

        if not labels_list:
            print('Skipping %s: no skeletons found' % file, file=sys.stderr)
            continue

        vid = file[:-4]
        # Average label codes across all persons in the video.
        torch.save(torch.mean(torch.stack(labels_list).float(), dim=0), os.path.join(save_path, vid))

