import os
import json
import pickle
import argparse
import numpy as np

from matplotlib import pyplot as plt

# Command-line interface. `parser` stays module-level so later code can report
# usage errors through it.
parser = argparse.ArgumentParser()

# Input / output locations.
parser.add_argument(
    '--source_dir',
    type=str,
    default='visualization/static/uestc',
    help='data directory',
)
parser.add_argument(
    '--save_dir',
    type=str,
    default='saved_anchor_data',
)
parser.add_argument(
    '--save_vis_dir',
    type=str,
    default='vis_test',
)

# What to process: one or more concept sub-directories and the dataset layout.
parser.add_argument(
    '--concept_list',
    type=str,
    nargs='+',
)
parser.add_argument(
    '--dataset_name',
    type=str,
    default='uestc',
)

args = parser.parse_args()


def _mirror_table(num_joints, swapped_pairs):
    """Return a ``{str(joint): str(joint)}`` map that swaps each listed pair.

    Joints not mentioned in ``swapped_pairs`` map to themselves; keys are
    inserted in ascending joint order.
    """
    table = {str(k): str(k) for k in range(num_joints)}
    for left, right in swapped_pairs:
        table[str(left)], table[str(right)] = str(right), str(left)
    return table


# Joint correspondences used when building mirrored (x-reflected) samples.
humanact12_mirror_match = _mirror_table(
    24,
    [(1, 2), (4, 5), (7, 8), (10, 11), (13, 14), (16, 17), (18, 19), (20, 21), (22, 23)],
)
uestc_mirror_match = _mirror_table(
    18,
    [(2, 5), (3, 6), (4, 7), (9, 12), (10, 13), (11, 14), (15, 16)],
)

# Bones drawn by the visualizers, as (joint_a, joint_b) index pairs.
humanact12_limbs = [
    (0, 1), (1, 4), (4, 7), (7, 10),
    (0, 2), (2, 5), (5, 8), (8, 11),
    (0, 3), (3, 6), (6, 9), (9, 12), (12, 15),
    (9, 13), (13, 16), (16, 18), (18, 20), (20, 22),
    (9, 14), (14, 17), (17, 19), (19, 21), (21, 23),
]
uestc_limbs = [
    (0, 1), (0, 9), (9, 10), (10, 11), (11, 16),
    (0, 12), (12, 13), (13, 14), (14, 15),
    (1, 2), (2, 3), (3, 4),
    (1, 5), (5, 6), (6, 7),
    (1, 8), (8, 17),
]


def anchor_fit_square(keypoints):
    """Fit an independent quadratic in t = 0..N-1 to each of x, y and z.

    Args:
        keypoints: sequence of points; only ``point[0]``, ``point[1]`` and
            ``point[2]`` are read.

    Returns:
        ``(coeffs, line_error, span_error)``:
        - ``coeffs``: ``[[ax, bx, cx], [ay, by, cy], [az, bz, cz]]`` as Python
          floats, highest power first (``np.polyfit`` order).
        - ``line_error``: sum over samples of the Euclidean distance between
          each input point and the fitted curve.
        - ``span_error``: Euclidean distance between the fitted curve's first
          and last sample.
    """
    axes = [np.array([point[k] for point in keypoints]) for k in range(3)]
    t = np.arange(0, len(keypoints))

    coeffs = [np.polyfit(t, values, 2) for values in axes]
    fitted = [np.poly1d(c)(t) for c in coeffs]

    line_error = np.sqrt(sum((values - curve) ** 2 for values, curve in zip(axes, fitted))).sum()
    span_error = np.sqrt(sum((curve[-1] - curve[0]) ** 2 for curve in fitted))

    return [[float(c) for c in axis_coeffs] for axis_coeffs in coeffs], line_error, span_error


def anchor_fit_spline(keypoints):
    """Fit an independent cubic in t = 0..N-1 to each of x, y and z.

    Args:
        keypoints: sequence of points; only ``point[0]``, ``point[1]`` and
            ``point[2]`` are read.

    Returns:
        ``(coeffs, line_error, span_error)``:
        - ``coeffs``: three lists (x, y, z) of four Python floats each, highest
          power first (``np.polyfit`` order).
        - ``line_error``: sum over samples of the Euclidean distance between
          each input point and the fitted curve.
        - ``span_error``: Euclidean distance between the fitted curve's first
          and last sample.
    """
    axes = [np.array([point[k] for point in keypoints]) for k in range(3)]
    t = np.arange(0, len(keypoints))

    coeffs = [np.polyfit(t, values, 3) for values in axes]
    fitted = [np.poly1d(c)(t) for c in coeffs]

    line_error = np.sqrt(sum((values - curve) ** 2 for values, curve in zip(axes, fitted))).sum()
    span_error = np.sqrt(sum((curve[-1] - curve[0]) ** 2 for curve in fitted))

    return [[float(c) for c in axis_coeffs] for axis_coeffs in coeffs], line_error, span_error


def anchor_fit_quad(keypoints):
    """Fit an independent quartic (degree-4) polynomial in t = 0..N-1 to x, y and z.

    Args:
        keypoints: sequence of points; only ``point[0]``, ``point[1]`` and
            ``point[2]`` are read. At least five points are needed for a
            well-determined degree-4 fit.

    Returns:
        ``(coeffs, line_error, span_error)``:
        - ``coeffs``: three lists (x, y, z) of five Python floats each,
          highest power first (``np.polyfit`` order).
        - ``line_error``: sum over samples of the Euclidean distance between
          each input point and the fitted curve.
        - ``span_error``: Euclidean distance between the fitted curve's first
          and last sample.
    """
    axes = [np.array([point[k] for point in keypoints]) for k in range(3)]
    t = np.arange(0, len(keypoints))

    # A degree-4 fit yields FIVE coefficients per axis. Unpacking them into
    # four names raised ValueError on every call.
    coeffs = [np.polyfit(t, values, 4) for values in axes]
    fitted = [np.poly1d(c)(t) for c in coeffs]

    line_error = np.sqrt(sum((values - curve) ** 2 for values, curve in zip(axes, fitted))).sum()
    span_error = np.sqrt(sum((curve[-1] - curve[0]) ** 2 for curve in fitted))

    return [[float(c) for c in axis_coeffs] for axis_coeffs in coeffs], line_error, span_error


def visualize(save_vis_dir, total_locations, concept_name, idx, limbs=None):
    """Render every frame of a keypoint sequence as a 3-D skeleton PNG.

    Frames are written to ``save_vis_dir/concept_name/idx/<frame>.png``.

    Args:
        save_vis_dir: root output directory.
        total_locations: array indexed ``[frame, joint, coord]`` with at least
            three coordinates.
        concept_name: sub-directory for the concept.
        idx: instance index, used as the next sub-directory.
        limbs: ``(joint_a, joint_b)`` pairs drawn as bones. Defaults to
            ``humanact12_limbs``. Pass ``uestc_limbs`` for 18-joint UESTC data;
            the HumanAct12 list indexes joint 23.
    """
    if limbs is None:
        limbs = humanact12_limbs

    save_path = os.path.join(save_vis_dir, concept_name, str(idx))
    os.makedirs(save_path, exist_ok=True)

    for i in range(total_locations.shape[0]):
        # Display axes are (x, z, -y). Presumably this puts the body's vertical
        # axis on the plot's Z axis; TODO confirm the source coordinate convention.
        x, y, z = total_locations[i, :, 0], total_locations[i, :, 2], -total_locations[i, :, 1]

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        ax.scatter(x, y, z, c='b', s=50, label='Synthesis keypoints')
        for a, b in limbs:
            ax.plot([x[a], x[b]], [y[a], y[b]], [z[a], z[b]], c='b')

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title('3D Keypoints Visualization')
        ax.legend()
        ax.set_box_aspect((1, 1.5, 2))

        fig.savefig(os.path.join(save_path, str(i) + '.png'), dpi=300)
        # Close each figure; otherwise every frame leaks an open figure.
        plt.close(fig)


def visualize_anchor_fit(concept, keypoints1, keypoints2, instance_idx, limbs=None,
                         save_root='vis_anchor_fit'):
    """Overlay original (blue) and generated (red) skeletons for every frame.

    Frames are written to ``save_root/concept/instance_idx/<frame>.png``.

    Args:
        concept: concept name, used as a sub-directory.
        keypoints1: original keypoints indexed ``[joint, frame, coord]``.
        keypoints2: generated keypoints with the same shape as ``keypoints1``.
        instance_idx: instance index, used as the next sub-directory.
        limbs: ``(joint_a, joint_b)`` pairs drawn as bones. Defaults to
            ``humanact12_limbs``. Pass ``uestc_limbs`` for 18-joint UESTC data.
        save_root: root output directory (default ``'vis_anchor_fit'``).

    Raises:
        AssertionError: if the two keypoint arrays differ in shape.
    """
    assert keypoints1.shape == keypoints2.shape
    if limbs is None:
        limbs = humanact12_limbs

    save_path = os.path.join(save_root, concept, str(instance_idx))
    os.makedirs(save_path, exist_ok=True)

    for i in range(keypoints1.shape[1]):
        x1, y1, z1 = keypoints1[:, i, 0], keypoints1[:, i, 1], keypoints1[:, i, 2]
        x2, y2, z2 = keypoints2[:, i, 0], keypoints2[:, i, 1], keypoints2[:, i, 2]

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        ax.scatter(x1, y1, z1, c='b', s=50, label='Original keypoints')
        ax.scatter(x2, y2, z2, c='r', s=50, label='Generated keypoints')
        for a, b in limbs:
            ax.plot([x1[a], x1[b]], [y1[a], y1[b]], [z1[a], z1[b]], c='b')
            ax.plot([x2[a], x2[b]], [y2[a], y2[b]], [z2[a], z2[b]], c='r')

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title('3D Keypoints Visualization')
        ax.legend()
        ax.set_box_aspect((1, 1, 1))

        fig.savefig(os.path.join(save_path, str(i) + '.png'), dpi=300)
        # plt.clf() only cleared the figure; it stayed open and a new one was
        # created every frame. Closing it releases the memory.
        plt.close(fig)


def transform_square_to_keypoint(prim_data, duration):
    """Sample piecewise-quadratic joint trajectories into keypoint frames.

    Each primitive ``p`` of joint ``i`` on axis ``a`` is
    ``c0 * t**2 + c1 * t + c2`` with ``(c0, c1, c2) = prim[i][a][3p:3p+3]``,
    evaluated at ``t = 0 .. int(duration) - 1``. Primitives are concatenated
    in order.

    Args:
        prim_data: either a 3-D array indexed ``[joint][axis][coef]`` or a
            4-D array with a leading sample axis.
        duration: number of steps per primitive. It is a scalar for 3-D input
            and indexable per sample (``duration[s]``) for 4-D input.

    Returns:
        For 3-D input, an array of shape ``(primitives * duration, joints, 3)``.
        For 4-D input, ``np.array`` of the per-sample results (presumably all
        samples need equal durations to stack into a regular array). Any other
        rank returns ``None``.
    """
    def _sample_quadratics(coeffs, n_steps):
        # Evaluate every joint's primitives and collect points in time order.
        n_prims = coeffs.shape[2] // 3
        per_joint = []
        for joint in coeffs:
            points = []
            for p in range(n_prims):
                base = 3 * p
                for t in range(n_steps):
                    points.append([
                        axis[base] * t ** 2 + axis[base + 1] * t + axis[base + 2]
                        for axis in (joint[0], joint[1], joint[2])
                    ])
            per_joint.append(points)
        # (joints, steps, 3) -> (steps, joints, 3)
        return np.transpose(np.array(per_joint), (1, 0, 2))

    rank = len(prim_data.shape)
    if rank == 4:
        return np.array([
            _sample_quadratics(prim_data[s], int(duration[s]))
            for s in range(prim_data.shape[0])
        ])
    if rank == 3:
        return _sample_quadratics(prim_data, int(duration))
    return None


def transform_spline_to_keypoint(prim_data, duration):
    """Sample piecewise-cubic joint trajectories into keypoint frames.

    Each primitive ``p`` of joint ``i`` on axis ``a`` is
    ``c0 * t**3 + c1 * t**2 + c2 * t + c3`` with
    ``(c0, c1, c2, c3) = prim[i][a][4p:4p+4]``, evaluated at
    ``t = 0 .. int(duration) - 1``. Primitives are concatenated in order.

    Args:
        prim_data: either a 3-D array indexed ``[joint][axis][coef]`` or a
            4-D array with a leading sample axis.
        duration: number of steps per primitive. It is a scalar for 3-D input
            and indexable per sample (``duration[s]``) for 4-D input.

    Returns:
        For 3-D input, an array of shape ``(primitives * duration, joints, 3)``.
        For 4-D input, ``np.array`` of the per-sample results (presumably all
        samples need equal durations to stack into a regular array). Any other
        rank returns ``None``.
    """
    def _sample_cubics(coeffs, n_steps):
        # Evaluate every joint's primitives and collect points in time order.
        n_prims = coeffs.shape[2] // 4
        per_joint = []
        for joint in coeffs:
            points = []
            for p in range(n_prims):
                k = 4 * p
                for t in range(n_steps):
                    points.append([
                        axis[k] * t ** 3 + axis[k + 1] * t ** 2 + axis[k + 2] * t + axis[k + 3]
                        for axis in (joint[0], joint[1], joint[2])
                    ])
            per_joint.append(points)
        # (joints, steps, 3) -> (steps, joints, 3)
        return np.transpose(np.array(per_joint), (1, 0, 2))

    rank = len(prim_data.shape)
    if rank == 4:
        return np.array([
            _sample_cubics(prim_data[s], int(duration[s]))
            for s in range(prim_data.shape[0])
        ])
    if rank == 3:
        return _sample_cubics(prim_data, int(duration))
    return None


def calculate_MPJPE(predictions, ground_truth):
    """Return the mean per-joint position error between two keypoint arrays.

    The Euclidean distance over the last axis is averaged first over axis 1,
    then over axis 0.

    Args:
        predictions: 3-D array, e.g. ``(joints, frames, coords)``.
        ground_truth: array broadcast-compatible with ``predictions``.

    Raises:
        ValueError: if ``predictions`` is not 3-D (from the shape unpack).
    """
    # The unpack exists only to enforce a 3-D prediction array.
    _joints, _frames, _coords = predictions.shape
    distances = np.linalg.norm(predictions - ground_truth, axis=2)
    return distances.mean(axis=1).mean()


# Resolve paths and per-dataset settings from the command line.
source_dir = args.source_dir
save_anchor_dir = args.save_dir
os.makedirs(save_anchor_dir, exist_ok=True)
save_vis_anchor_dir = args.save_vis_dir
concept_list = args.concept_list

# dataset name -> (left/right joint mirror map, number of joints)
_DATASET_SPECS = {
    'humanact12': (humanact12_mirror_match, 24),
    'uestc': (uestc_mirror_match, 18),
}
if args.dataset_name not in _DATASET_SPECS:
    # Fail fast. Otherwise mirror_match / num_joints stay unbound and the
    # script dies later with an unhelpful NameError.
    parser.error(
        f"unsupported --dataset_name {args.dataset_name!r}; "
        f"expected one of: {', '.join(_DATASET_SPECS)}"
    )
mirror_match, num_joints = _DATASET_SPECS[args.dataset_name]

def _build_instance_record(primitives, input_keypoints, prim_num, joint_count, mirror_map=None):
    """Collect primitives, anchors, durations and keypoint segments for one instance.

    Args:
        primitives: the ``base.json`` ``primitives`` dict, indexed
            ``[str(joint)][str(prim)] -> [coeffs, length]``. ``coeffs`` must
            already be a float array whose column 3 holds the anchor.
        input_keypoints: offset-corrected keypoints indexed ``[joint][frame]``.
        prim_num: number of primitives per joint.
        joint_count: number of joint slots to fill.
        mirror_map: optional ``{str(joint): str(mirrored_joint)}``. When given,
            slot ``i`` is filled entirely from joint ``mirror_map[str(i)]``
            (coefficients, durations, keypoints, anchors). Every x value is
            reflected about joint 0 as ``x' = 2 * x_root - x``.

    Returns:
        ``(prims, anchors, durations, keypoint_segments)`` with one entry per
        joint slot. The first three are stacked with ``np.array``.
    """
    prim_record, anchor_record, duration_record, keypoints_record = [], [], [], []
    for i in range(joint_count):
        src = int(mirror_map[str(i)]) if mirror_map is not None else i
        joint_prims, joint_anchors, joint_durations, joint_keypoints = [], [], [], []
        offset = 0
        for j in range(prim_num):
            coeffs, length = primitives[str(src)][str(j)][0], primitives[str(src)][str(j)][1]
            # The last primitive's length includes its closing frame. That frame
            # is recorded separately as the final anchor after this loop.
            duration = length - 1 if j == prim_num - 1 else length
            prim = np.array(coeffs)
            # Slicing clamps at the end of the sequence.
            segment = input_keypoints[src][offset:offset + duration + 1].copy()
            if mirror_map is not None:
                root_prim = primitives['0'][str(j)][0]
                prim[0] = 2 * root_prim[0] - prim[0]
                root_segment = input_keypoints[0][offset:offset + duration + 1]
                segment[:, 0] = 2 * root_segment[:, 0] - segment[:, 0]
            joint_prims.append(prim)
            # The anchor is column 3 of the (possibly reflected) coefficients.
            joint_anchors.append(prim[:, 3].copy())
            joint_durations.append(duration)
            joint_keypoints.append(segment)
            offset += duration

        # Closing anchor: the keypoint at the end of the last primitive. It must
        # be reflected too. It was previously left unmirrored, unlike every
        # other mirrored value.
        last_anchor = input_keypoints[src][offset].copy()
        if mirror_map is not None:
            last_anchor[0] = 2 * input_keypoints[0][offset][0] - last_anchor[0]
        joint_anchors.append(last_anchor)

        prim_record.append(np.array(joint_prims))
        anchor_record.append(np.array(joint_anchors))
        duration_record.append(np.array(joint_durations))
        keypoints_record.append(joint_keypoints)
    return np.array(prim_record), np.array(anchor_record), np.array(duration_record), keypoints_record


if not os.path.exists(os.path.join(save_anchor_dir, 'concept_anchor_specific.pkl')):
    print("concept_anchor_specific.pkl not found, regnerating...")
    if not concept_list:
        parser.error("--concept_list is required to regenerate concept_anchor_specific.pkl")

    concept_anchor_record = {}

    for concept in concept_list:
        concept_path = os.path.join(source_dir, concept)
        record = {'anchors': [], 'durations': [], 'prims': [], 'input_keypoints': []}
        concept_anchor_record[concept] = record
        instance_names = os.listdir(concept_path)
        total_instance_num = len(instance_names)
        filtered_instance_num = 0

        for instance_name in instance_names:
            instance_path = os.path.join(concept_path, instance_name, instance_name)
            with open(os.path.join(instance_path, 'base.json'), 'rb') as f:
                base_data = json.load(f)

            primitives = base_data['primitives']
            prim_num = len(primitives['0'].keys())
            # Keep only instances with at least two primitives (>= 3 anchors).
            if prim_num < 2:
                continue

            # Remove the 255.5 coordinate offset from keypoints and from the
            # anchor column of every primitive. dtype=float prevents silent
            # truncation when a coefficient list happens to be all integers.
            input_keypoints = np.array(list(base_data['input_keypoints'].values()))
            input_keypoints = input_keypoints - [255.5, 255.5, 255.5]
            for i in range(num_joints):
                for j in range(prim_num):
                    coeffs = np.array(primitives[str(i)][str(j)][0], dtype=float)
                    coeffs[:, 3] = coeffs[:, 3] - 255.5
                    primitives[str(i)][str(j)][0] = coeffs

            filtered_instance_num += 1
            # Store the instance as-is, then its left/right mirrored copy.
            for mirror_map in (None, mirror_match):
                prims, anchors, durations, keypoints = _build_instance_record(
                    primitives, input_keypoints, prim_num, num_joints, mirror_map)
                record["prims"].append(prims)
                record["anchors"].append(anchors)
                record["durations"].append(durations)
                record["input_keypoints"].append(keypoints)

        # Guard against an empty concept directory.
        ratio = filtered_instance_num / total_instance_num if total_instance_num else 0.0
        print(concept, "Anchor num over 3 instance ratio:", filtered_instance_num, total_instance_num, ratio)

    with open(os.path.join(save_anchor_dir, 'concept_anchor_specific.pkl'), 'wb') as f:
        pickle.dump(concept_anchor_record, f)
    print("concept_anchor_specific.pkl saved")

else:
    # Cache written by this script. pickle can execute code, so only load
    # trusted files from save_dir.
    with open(os.path.join(save_anchor_dir, 'concept_anchor_specific.pkl'), 'rb') as f:
        concept_anchor_record = pickle.load(f)

# For every concept, flatten each instance's primitive coefficients per
# (joint, primitive) and append the duration as the last feature. The lists in
# concept_anchor_record are updated in place. The result is pickled as
# concept_anchor_fit_param.pkl.
concept_anchor_fit_param = {}
for concept, record in concept_anchor_record.items():
    prims_list = record["prims"]
    durations_list = record["durations"]
    for idx, prim_arr in enumerate(prims_list):
        dur_arr = durations_list[idx]
        flat_prims = prim_arr.reshape(prim_arr.shape[0], prim_arr.shape[1], -1)
        flat_durs = dur_arr.reshape(dur_arr.shape[0], dur_arr.shape[1], -1)
        durations_list[idx] = flat_durs
        prims_list[idx] = np.concatenate((flat_prims, flat_durs), axis=2)

    concept_anchor_fit_param[concept] = {
        'prims': prims_list,
        'anchors': record["anchors"],
        'input_keypoints': record["input_keypoints"],
    }

with open(os.path.join(save_anchor_dir, 'concept_anchor_fit_param.pkl'), 'wb') as f:
    pickle.dump(concept_anchor_fit_param, f)
