import os
import cmath
import torch
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
from torch.optim.lr_scheduler import LambdaLR


# Skeleton edge list: each tuple is a pair of joint indices in the range 0-24.
# Presumably the bone connectivity of the 25-joint NTU skeleton, intended to be
# passed as the `limbs` argument of the visualize_* helpers -- confirm against callers.
ntu_limbs = [(0, 1), (1, 20), (2, 20), (3, 2), (4, 20), (5, 4), (6, 5), (7, 6), (8, 20), (9, 8), \
            (10, 9), (11, 10), (12, 0), (13, 12), (14, 13), (15, 14), (16, 0), (17, 16), (18, 17), (19, 18), (21, 22), (22, 7), (23, 24), (24, 11)]

# Skeleton edge list: each tuple is a pair of joint indices in the range 0-23.
# Presumably the bone connectivity of the 24-joint HumanAct12 skeleton -- confirm against callers.
humanact12_limbs = [(0, 1), (1, 4), (4, 7), (7, 10), (0, 2), (2, 5), (5, 8), (8, 11), (0, 3), (3, 6), (6, 9), \
                    (9, 12), (12, 15), (9, 13), (13, 16), (16, 18), (18, 20), (20, 22), (9, 14), (14, 17), (17, 19), (19, 21), (21, 23)]

def lr_decay_mine(optimizer, lr_now, gamma):
    """
    Decay the learning rate once and write it into every parameter group.

    :param optimizer: object exposing ``param_groups`` (an iterable of dicts)
    :param lr_now: learning rate before the decay
    :param gamma: multiplicative decay factor
    :return: the decayed learning rate, ``lr_now * gamma``
    """
    decayed = lr_now * gamma
    for group in optimizer.param_groups:
        group['lr'] = decayed
    return decayed


def orth_project(cam, pts):
    """
    Project points onto the image plane with a scale-and-translation camera.

    Only the first two coordinates of ``pts`` are used; the result is
    ``s * (pts[..., :2] + [tx, ty])`` for each batch element.

    :param cam: b*[s,tx,ty]
    :param pts: b*k*3
    :return: b*k*2 projected points
    """
    # Insert a keypoint axis so the per-sample camera broadcasts over all k points.
    scale = cam[:, None, 0:1]
    shift = cam[:, None, 1:]
    planar = pts[:, :, :2]

    return scale * (planar + shift)


def opt_cam(x, x_target):
    """
    Closed-form fit of a scale-and-translation camera that aligns ``x`` to ``x_target``.

    Only the first two coordinates of both inputs are fitted. When ``x_target`` has
    three channels, a keypoint counts as visible when its first channel is > 0;
    with two channels every keypoint is visible. The first two keypoints are always
    excluded from the fit. The returned camera ``[s, tx, ty]`` is such that
    ``s * (x[..., :2] + [tx, ty])`` approximates ``x_target[..., :2]``.

    The computation runs on the device of the inputs (no hard-coded CUDA transfer).

    :param x: N K 3 or  N K 2
    :param x_target: N K 3 or  N K 2
    :return: N*3 tensor ``[s, tx, ty]``; the scale is clamped to [0.7, 10].
        If a sample has no visible keypoint the means divide by zero (NaN/inf result).
    """
    if x_target.shape[2] == 2:
        vis = torch.ones_like(x_target[:, :, :1])
    else:
        vis = (x_target[:, :, :1] > 0).float()
    # The first two keypoints never take part in the fit.
    vis[:, :2] = 0
    xxt = x_target[:, :, :2]
    xx = x[:, :, :2]
    x_vis = vis * xx
    xt_vis = vis * xxt
    num_vis = torch.sum(vis, dim=1, keepdim=True)
    # Centroids of the visible keypoints in source and target.
    mu1 = torch.sum(x_vis, dim=1, keepdim=True) / num_vis
    mu2 = torch.sum(xt_vis, dim=1, keepdim=True) / num_vis
    xmu = vis * (xx - mu1)
    xtmu = vis * (xxt - mu2)

    # Small ridge term keeps the 2x2 normal matrix invertible; build it on the
    # same device as the data instead of assuming CUDA.
    eps = 1e-6 * torch.eye(2, device=xmu.device).float()
    Ainv = torch.inverse(torch.matmul(xmu.transpose(1, 2), xmu) + eps.unsqueeze(0))
    B = torch.matmul(xmu.transpose(1, 2), xtmu)
    tmp_s = torch.matmul(Ainv, B)
    # Isotropic scale: average of the two diagonal entries of the least-squares map.
    scale = ((tmp_s[:, 0, 0] + tmp_s[:, 1, 1]) / 2.0).unsqueeze(1)

    scale = torch.clamp(scale, 0.7, 10)
    trans = mu2.squeeze(1) / scale - mu1.squeeze(1)
    return torch.cat([scale, trans], dim=1)


def get_dct_matrix(N):
    """
    Build the orthonormal DCT-II matrix of size N and its inverse.

    Row ``k`` holds ``w_k * cos(pi * (i + 1/2) * k / N)`` with
    ``w_0 = sqrt(1/N)`` and ``w_k = sqrt(2/N)`` otherwise.

    :param N: transform length
    :return: (dct_m, idct_m), both N*N arrays; idct_m is the numerical inverse of dct_m
    """
    forward = np.eye(N)
    for row in range(N):
        # The DC row uses a different normalisation from all other rows.
        weight = np.sqrt(1 / N) if row == 0 else np.sqrt(2 / N)
        for col in range(N):
            forward[row, col] = weight * np.cos(np.pi * (col + 1 / 2) * row / N)
    return forward, np.linalg.inv(forward)


def get_lr_scheduler(lr_policy, optimizer, max_iter=None):
    """
    Create a learning-rate scheduler from a policy description.

    Only the "Poly" policy is supported: the base learning rate of every
    parameter group is multiplied by ``(1 - cur_iter / max_iter) ** power``.

    :param lr_policy: dict with key 'name' and, for "Poly", key 'power'
    :param optimizer: optimizer whose parameter groups are scheduled
    :param max_iter: total number of iterations; must be > 0 for "Poly"
    :return: a ``LambdaLR`` scheduler
    :raises NotImplementedError: for any policy name other than "Poly"
    """
    if lr_policy['name'] != "Poly":
        raise NotImplementedError("lr policy not supported")

    assert max_iter > 0

    def poly_factor(cur_iter):
        return (1 - (cur_iter * 1.0) / max_iter) ** lr_policy['power']

    # One multiplier per parameter group, all sharing the same polynomial decay.
    per_group = [poly_factor for _ in optimizer.param_groups]
    return LambdaLR(optimizer, lr_lambda=per_group)


def transform_anchor_to_keypoint(trans_data, duration):
    """
    Sample per-axis quadratic polynomials at integer time steps.

    ``trans_data[s][i][axis]`` stores consecutive coefficient triples
    ``(a, b, c)``; each triple is evaluated as ``a*t**2 + b*t + c`` for
    ``t = 0 .. int(duration[s]) - 1``. Triples are processed one after another,
    so the time axis of the result concatenates all segments.

    :param trans_data: array indexed as [sample, i, axis(0..2), coefficient];
        the second axis is presumably the joint/anchor index -- confirm with callers
    :param duration: per-sample number of time steps
    :return: ``np.array`` over samples of arrays shaped (time, i, 3)
    """
    per_sample = []
    for sample_idx in range(trans_data.shape[0]):
        sample = trans_data[sample_idx]
        tracks = []
        for row_idx in range(trans_data.shape[1]):
            axes = sample[row_idx]
            track = [
                [
                    axes[axis][seg * 3] * t ** 2 + axes[axis][seg * 3 + 1] * t + axes[axis][seg * 3 + 2]
                    for axis in range(3)
                ]
                for seg in range(sample.shape[2] // 3)
                for t in range(int(duration[sample_idx]))
            ]
            tracks.append(track)
        # (i, time, 3) -> (time, i, 3)
        per_sample.append(np.transpose(np.array(tracks), (1, 0, 2)))

    return np.array(per_sample)


def transform_trans_to_keypoint(trans_data, duration):
    """
    Sample per-axis cubic polynomials at integer time steps.

    ``trans_data[s][i][axis]`` stores consecutive coefficient quadruples
    ``(a, b, c, d)``; each quadruple is evaluated as
    ``a*t**3 + b*t**2 + c*t + d`` for ``t = 0 .. int(duration[s]) - 1``.
    Quadruples are processed one after another, so the time axis of the result
    concatenates all segments.

    :param trans_data: array indexed as [sample, i, axis(0..2), coefficient];
        the second axis is presumably the joint/anchor index -- confirm with callers
    :param duration: per-sample number of time steps
    :return: ``np.array`` over samples of arrays shaped (time, i, 3)
    """
    per_sample = []
    for sample_idx in range(trans_data.shape[0]):
        sample = trans_data[sample_idx]
        tracks = []
        for row_idx in range(trans_data.shape[1]):
            axes = sample[row_idx]
            track = [
                [
                    axes[axis][seg * 4] * t ** 3
                    + axes[axis][seg * 4 + 1] * t ** 2
                    + axes[axis][seg * 4 + 2] * t
                    + axes[axis][seg * 4 + 3]
                    for axis in range(3)
                ]
                for seg in range(sample.shape[2] // 4)
                for t in range(int(duration[sample_idx]))
            ]
            tracks.append(track)
        # (i, time, 3) -> (time, i, 3)
        per_sample.append(np.transpose(np.array(tracks), (1, 0, 2)))

    return np.array(per_sample)


def transform_torch(trans_data, duration):
    """
    Evaluate per-joint cubic trajectories on a batch and zero out padded steps.

    For every time step ``j`` in ``0 .. max(duration) - 1`` and each of the three
    axes, the value is ``a*j**3 + b*j**2 + c*j + d`` with
    ``(a, b, c, d) = trans_data[b, k, axis, 0:4]``. Steps at or beyond a sample's
    own duration are set to zero.

    The output is allocated on ``trans_data``'s device rather than unconditionally
    on CUDA, so the function also works on CPU-only machines.

    :param trans_data: tensor indexed as [batch, k, axis(0..2), coefficient(0..3)]
    :param duration: tensor indexed as [batch, 0] holding each sample's length
    :return: tensor of shape (batch, max_duration, k, 3) in the default float dtype
    """
    device = trans_data.device
    bs, num_tracks = trans_data.shape[0], trans_data.shape[1]
    durations = duration[:, 0].to(device)
    max_duration = int(durations.max())

    generated_trans_keypoint = torch.zeros((bs, max_duration, num_tracks, 3), device=device)
    # All three axes share the same polynomial form, so evaluate them together.
    coeffs = trans_data[:, :, :3, :4]
    for j in range(max_duration):
        generated_trans_keypoint[:, j] = (
            coeffs[..., 0] * j ** 3 + coeffs[..., 1] * j ** 2 + coeffs[..., 2] * j + coeffs[..., 3]
        )

    # (batch, time) mask of steps that lie past each sample's duration.
    steps = torch.arange(max_duration, device=device)
    padded = steps.unsqueeze(0) >= durations.unsqueeze(1)
    generated_trans_keypoint[padded] = 0

    return generated_trans_keypoint


def visualize_single_video(save_vis_dir, limbs, keypoints_opt, sample_idx):
    """
    Animate a keypoint sequence as a 3D skeleton and save it to disk.

    One frame is drawn per entry along the first axis of ``keypoints_opt``. The
    animation is written to ``<save_vis_dir>/<sample_idx>.gif`` through
    matplotlib's ``FFMpegWriter`` at 5 fps.

    :param save_vis_dir: output directory, created if it does not exist
    :param limbs: iterable of (joint_a, joint_b) index pairs drawn as line segments
    :param keypoints_opt: sequence of frames; each frame is indexed as
        ``frame[..., 0..2]`` and then by joint (assumes frames*joints*3 -- TODO confirm)
    :param sample_idx: converted with ``str()`` and used as the file stem
    """
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(save_vis_dir, exist_ok=True)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    num_frame = keypoints_opt.shape[0]

    def init():
        # ax.clear() resets the limits, so they are re-applied for every frame.
        ax.set_xlim(-100, 100)
        ax.set_ylim(-100, 100)
        ax.set_zlim(-100, 100)
        return ax

    def plot_frame(frame_data, frame_idx=0):
        # Plot axes are (-x, z, -y) of the stored coordinates.
        x_opt, y_opt, z_opt = -frame_data[..., 0], frame_data[..., 2], -frame_data[..., 1]
        ax.scatter(x_opt, y_opt, z_opt, c='b', s=50)
        ax.set_title(str(frame_idx))
        for edge in limbs:
            ax.plot([x_opt[edge[0]], x_opt[edge[1]]], [y_opt[edge[0]], y_opt[edge[1]]], [z_opt[edge[0]], z_opt[edge[1]]], c='b')

    def update(frame):
        ax.clear()
        init()
        plot_frame(keypoints_opt[frame], frame_idx=frame)

    try:
        ani = animation.FuncAnimation(fig, update, frames=num_frame, interval=50)
        writer = animation.FFMpegWriter(fps=5, bitrate=1800)
        ani.save(os.path.join(save_vis_dir, str(sample_idx)+'.gif'), writer=writer)
    finally:
        # Close this specific figure even if saving fails. The previous
        # plt.close(); plt.clf() pair left a fresh empty figure behind, because
        # plt.clf() creates a figure when none is current.
        plt.close(fig)


def visualize_sequence(save_vis_dir, limbs, keypoints, sample_idx):
    """
    Draw every frame of a keypoint sequence into one 3D plot and save it as a PNG.

    The image is written to ``<save_vis_dir>/<sample_idx>.png`` at 300 dpi.

    :param save_vis_dir: output directory, created if it does not exist
    :param limbs: iterable of (joint_a, joint_b) index pairs drawn as line segments
    :param keypoints: indexed as ``keypoints[frame, joint, 0..2]``
    :param sample_idx: file stem; converted with ``str()`` so integer ids work too
    """
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(save_vis_dir, exist_ok=True)

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, projection='3d')
        for i in range(keypoints.shape[0]):
            # Plot axes are (-x, z, -y) of the stored coordinates.
            x, y, z = -keypoints[i, :, 0], keypoints[i, :, 2], -keypoints[i, :, 1]

            ax.scatter(x, y, z, c='b', s=50, label='transition gt')

            for edge in limbs:
                ax.plot([x[edge[0]], x[edge[1]]], [y[edge[0]], y[edge[1]]], [z[edge[0]], z[edge[1]]], c='b')

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title('3D Keypoints Visualization')

        # str() keeps this consistent with the other visualize_* helpers; plain
        # concatenation raised TypeError for integer sample ids.
        plt.savefig(os.path.join(save_vis_dir, str(sample_idx)+'.png'), dpi=300)
    finally:
        # Close this specific figure; plt.clf() after plt.close() would create
        # a fresh empty figure and leave it open.
        plt.close(fig)


def visualize_single(save_vis_dir, limbs, keypoints, concept_name, sample_idx):
    """
    Save one 3D skeleton image per frame of a keypoint sequence.

    Frame ``i`` is written to
    ``<save_vis_dir>/<concept_name>/<sample_idx>/<i>.png`` at 300 dpi, with all
    three axes fixed to [-100, 100].

    :param save_vis_dir: root output directory
    :param limbs: iterable of (joint_a, joint_b) index pairs drawn as line segments
    :param keypoints: indexed as ``keypoints[frame, joint, 0..2]``
    :param concept_name: sub-directory name under ``save_vis_dir``
    :param sample_idx: converted with ``str()`` and used as the innermost directory name
    """
    out_dir = os.path.join(save_vis_dir, concept_name, str(sample_idx))
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(out_dir, exist_ok=True)

    for i in range(keypoints.shape[0]):
        # Plot axes are (-x, z, -y) of the stored coordinates.
        x_init, y_init, z_init = -keypoints[i, :, 0], keypoints[i, :, 2], -keypoints[i, :, 1]

        fig = plt.figure()
        try:
            ax = fig.add_subplot(111, projection='3d')
            ax.set_xlim(-100, 100)
            ax.set_ylim(-100, 100)
            ax.set_zlim(-100, 100)

            ax.scatter(x_init, y_init, z_init, c='g', s=50, label='transition gt')

            for edge in limbs:
                ax.plot([x_init[edge[0]], x_init[edge[1]]], [y_init[edge[0]], y_init[edge[1]]], [z_init[edge[0]], z_init[edge[1]]], c='g')

            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_zlabel('Z')
            ax.set_title('3D Keypoints Visualization')

            ax.legend()
            ax.set_box_aspect((1, 1, 1))

            plt.savefig(os.path.join(out_dir, str(i) + '.png'), dpi=300)
        finally:
            # Close this frame's figure explicitly; plt.clf() after plt.close()
            # would create a fresh empty figure and leave it open.
            plt.close(fig)
        
        
def visualize_anchor(save_vis_dir, limbs, keypoints, sample_idx):
    """
    Save one 3D skeleton image per entry of ``keypoints``.

    Entry ``i`` is written to ``<save_vis_dir>/<sample_idx>/<i>.png`` at 300 dpi,
    with all three axes fixed to [-100, 100].

    :param save_vis_dir: root output directory
    :param limbs: iterable of (joint_a, joint_b) index pairs drawn as line segments
    :param keypoints: indexed as ``keypoints[i, joint, 0..2]``
    :param sample_idx: converted with ``str()`` and used as the sub-directory name
    """
    out_dir = os.path.join(save_vis_dir, str(sample_idx))
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(out_dir, exist_ok=True)

    for i in range(keypoints.shape[0]):
        # Plot axes are (-x, z, -y) of the stored coordinates.
        x_init, y_init, z_init = -keypoints[i, :, 0], keypoints[i, :, 2], -keypoints[i, :, 1]

        fig = plt.figure()
        try:
            ax = fig.add_subplot(111, projection='3d')
            ax.set_xlim(-100, 100)
            ax.set_ylim(-100, 100)
            ax.set_zlim(-100, 100)

            ax.scatter(x_init, y_init, z_init, c='g', s=50, label='transition gt')

            for edge in limbs:
                ax.plot([x_init[edge[0]], x_init[edge[1]]], [y_init[edge[0]], y_init[edge[1]]], [z_init[edge[0]], z_init[edge[1]]], c='g')

            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_zlabel('Z')
            ax.set_title('3D Keypoints Visualization')

            ax.legend()
            ax.set_box_aspect((1, 1, 1))

            plt.savefig(os.path.join(out_dir, str(i) + '.png'), dpi=300)
        finally:
            # Close this entry's figure explicitly; plt.clf() after plt.close()
            # would create a fresh empty figure and leave it open.
            plt.close(fig)


def visualize_together2(save_vis_dir, limbs, keypoints_init, keypoints_opt, concept_name, sample_idx, start_idx):
    """
    Save per-frame 3D plots overlaying two keypoint sequences.

    For each frame ``i`` of ``keypoints_init`` the matching frame of
    ``keypoints_opt`` is drawn in the same axes (``keypoints_init`` in red,
    ``keypoints_opt`` in green) and written to
    ``<save_vis_dir>/<concept_name>/<sample_idx>/<start_idx + i>.png`` at 300 dpi.

    :param save_vis_dir: root output directory
    :param limbs: iterable of (joint_a, joint_b) index pairs drawn as line segments
    :param keypoints_init: indexed as ``[frame, joint, 0..2]``; sets the frame count
    :param keypoints_opt: indexed like ``keypoints_init``; needs at least as many frames
    :param concept_name: sub-directory name under ``save_vis_dir``
    :param sample_idx: converted with ``str()`` and used as the innermost directory name
    :param start_idx: offset added to the frame index in the file name
    """
    out_dir = os.path.join(save_vis_dir, concept_name, str(sample_idx))
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(out_dir, exist_ok=True)

    for i in range(keypoints_init.shape[0]):
        # Plot axes are (-x, z, -y) of the stored coordinates.
        x_init, y_init, z_init = -keypoints_init[i, :, 0], keypoints_init[i, :, 2], -keypoints_init[i, :, 1]
        x_opt, y_opt, z_opt = -keypoints_opt[i, :, 0], keypoints_opt[i, :, 2], -keypoints_opt[i, :, 1]

        fig = plt.figure()
        try:
            ax = fig.add_subplot(111, projection='3d')

            ax.scatter(x_init, y_init, z_init, c='r', s=50, label='transition gt')
            ax.scatter(x_opt, y_opt, z_opt, c='g', s=50, label='transition init')

            for edge in limbs:
                ax.plot([x_init[edge[0]], x_init[edge[1]]], [y_init[edge[0]], y_init[edge[1]]], [z_init[edge[0]], z_init[edge[1]]], c='r')
                ax.plot([x_opt[edge[0]], x_opt[edge[1]]], [y_opt[edge[0]], y_opt[edge[1]]], [z_opt[edge[0]], z_opt[edge[1]]], c='g')

            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_zlabel('Z')
            ax.set_title('3D Keypoints Visualization')

            ax.legend()
            ax.set_box_aspect((1, 1, 1))

            plt.savefig(os.path.join(out_dir, str(start_idx+i) + '.png'), dpi=300)
        finally:
            # Close this frame's figure explicitly; plt.clf() after plt.close()
            # would create a fresh empty figure and leave it open.
            plt.close(fig)


def calculate_MPJPE(predictions, ground_truth):
    """
    Mean per-joint position error between two keypoint sequences.

    The Euclidean distance is taken along the last axis, averaged over joints
    within each frame, then averaged over frames.

    :param predictions: array of shape (frames, joints, dims); must be 3-D
    :param ground_truth: array broadcastable against ``predictions``
    :return: scalar error
    """
    # Unpacking doubles as a rank check: anything but a 3-D input raises here.
    _frames, _joints, _dims = predictions.shape

    joint_dist = np.linalg.norm(predictions - ground_truth, axis=2)
    frame_mean = np.mean(joint_dist, axis=1)
    return np.mean(frame_mean, axis=0)


def anchor_rotate(anchor, theta):
    """
    Rotate points about their second coordinate axis by ``theta``.

    The last two coordinates are swapped so that the rotation can be expressed
    as a standard rotation about the third axis, then swapped back afterwards.

    :param anchor: tensor whose last dimension has 3 components
    :param theta: tensor of shape (n, 1) holding one angle (radians) per rotation matrix
    :return: rotated tensor, same component order as the input
    """
    swap_last_two = [0, 2, 1]
    angles = theta.squeeze(1)
    cos_t = torch.cos(angles)
    sin_t = torch.sin(angles)

    # One rotation matrix per angle, created in the default dtype on anchor's device.
    rot = torch.zeros((angles.shape[0], 3, 3)).to(anchor.device)
    rot[:, 0, 0] = cos_t
    rot[:, 0, 1] = -sin_t
    rot[:, 1, 0] = sin_t
    rot[:, 1, 1] = cos_t
    rot[:, 2, 2] = 1

    # Row-vector convention: p @ R^T is the same as R @ p for each point.
    turned = torch.matmul(anchor[..., swap_last_two], rot.transpose(1, 2))
    return turned[..., swap_last_two]


def anchor_denormalize(anchor_norm, position, iput_orientation, out_delt_orientation):
    """
    Undo a rotate-translate-rotate normalisation of anchor keypoints.

    Steps, in order: rotate by ``-out_delt_orientation``, add ``position``
    (broadcast over the joint axis), rotate by ``-iput_orientation``.

    :param anchor_norm: normalised anchors accepted by ``anchor_rotate``
    :param position: per-sample offset; gets a joint axis via ``unsqueeze(1)``
    :param iput_orientation: angle tensor passed (negated) to ``anchor_rotate``
    :param out_delt_orientation: angle tensor passed (negated) to ``anchor_rotate``
    :return: denormalised anchors
    """
    unrotated = anchor_rotate(anchor_norm, -out_delt_orientation)
    shifted = unrotated + position.unsqueeze(1)
    return anchor_rotate(shifted, -iput_orientation)


def rotated_sequence(anchor_data):
    """
    Rotate an anchor sequence by the polar angle of its last anchor's joint 0.

    :param anchor_data: indexed as [anchor, joint, 0..2]
        (presumably a torch tensor, since it is passed to ``anchor_rotate`` -- confirm)
    :return: (rotated_anchor_data, rotated_position, delt_rotated_position, -theta).
        NOTE(review): ``rotated_position`` and ``delt_rotated_position`` are the same
        object, so both returned values hold the frame-to-frame deltas -- confirm intended.
    """
    anchor_num = anchor_data.shape[0]
    # Swap the last two coordinate components.
    # NOTE(review): anchor_rotate performs the same swap again on entry -- confirm intended.
    anchor_data = anchor_data[..., [0, 2, 1]]
    # Polar angle of joint 0 of the last anchor, from its first two (swapped) components.
    rho, theta = cmath.polar(complex(anchor_data[-1, 0, 0], anchor_data[-1, 0, 1]))
    # Expand theta to one angle per anchor, shape (anchor_num, 1).
    theta = theta * torch.ones((anchor_num, 1))
    rotated_anchor_data = anchor_rotate(anchor_data, -theta)
    # NOTE(review): for a torch tensor this slice is a view, so the in-place updates
    # below also modify rotated_anchor_data[:, 0] -- confirm intended.
    rotated_position = rotated_anchor_data[:, 0]
    # In-place shift relative to the first anchor; the subtrahend is a view of the
    # tensor being written, so the result depends on evaluation order -- TODO verify.
    rotated_position -= rotated_position[0]
    # Plain assignment: no copy is made, both names refer to the same tensor.
    delt_rotated_position = rotated_position
    # The right-hand side is evaluated first, then written back in place.
    delt_rotated_position[1:] = rotated_position[1:] - rotated_position[:-1]
    
    return rotated_anchor_data, rotated_position, delt_rotated_position, -theta


def norm_sequence(anchor_data, theta):
    """
    Re-centre and rotate each sample of a batch of anchor sequences.

    Every sample is shifted so that joint 0 of its second-to-last anchor sits
    at the origin, then rotated with ``anchor_rotate``. Note that the shift is
    written back into ``anchor_data``, i.e. the caller's tensor is modified.

    :param anchor_data: indexed as [sample, anchor, joint, 0..2]; modified in place
    :param theta: per-sample angle; ``theta[i].unsqueeze(0)`` is passed to ``anchor_rotate``
    :return: (rotated anchors stacked over samples,
        per-sample difference between joint 0 of the last two rotated anchors)
    """
    rotated_samples = []
    last_step_deltas = []
    for sample_idx in range(anchor_data.shape[0]):
        # In-place re-centring on joint 0 of the second-to-last anchor.
        anchor_data[sample_idx] = anchor_data[sample_idx] - anchor_data[sample_idx, -2, 0]
        rotated = anchor_rotate(anchor_data[sample_idx], theta[sample_idx].unsqueeze(0))
        rotated_samples.append(rotated)
        last_step_deltas.append(rotated[-1, 0] - rotated[-2, 0])

    target_device = anchor_data.device
    stacked_anchor = torch.stack(rotated_samples, dim=0).to(target_device)
    stacked_delta = torch.stack(last_step_deltas, dim=0).to(target_device)

    return stacked_anchor, stacked_delta


def calculate_activation_statistics(activations):
    """
    Mean vector and covariance matrix of a batch of activations.

    :param activations: tensor of shape (samples, features); it is detached and
        moved to the CPU before the statistics are computed with NumPy
    :return: (mu, sigma) -- per-feature mean and feature covariance (rows are samples)
    """
    samples = activations.detach().cpu().numpy()
    return np.mean(samples, axis=0), np.cov(samples, rowvar=False)
    

def get_orientation(keypoint_sequence):
    """
    Estimate a per-frame rotation angle and rotate each frame by it.

    Per frame: joints are centred on joint 0, the last two coordinate components
    are swapped, and the cross product of (joint 1 - joint 3) and
    (joint 2 - joint 3) is taken. The angle is ``arcsin`` of the cross product's
    first component divided by its norm. If the second component is positive,
    the first two components of the frame are negated first and the frame is
    recorded as reversed; otherwise the angle's sign is flipped. The frame is then
    rotated by that angle about the third (swapped) axis.

    The input array is not modified. A zero-length cross product yields NaN.

    :param keypoint_sequence: NumPy array indexed as [frame, joint, 0..2], at least 4 joints
    :return: (rotated keypoints in the original component order and centred on joint 0,
        array of per-frame angles in radians, list of per-frame reverse flags 0/1)
    """
    # Centre on joint 0; the subtraction already allocates a new array.
    centered = keypoint_sequence - keypoint_sequence[:, 0:1]
    swapped = centered[..., [0, 2, 1]]
    edge_a = swapped[:, 1] - swapped[:, 3]
    edge_b = swapped[:, 2] - swapped[:, 3]

    orientation_gather = []
    rotated_keypoint_sequence = []
    reverse_record = []
    for i in range(edge_a.shape[0]):
        normal = np.cross(edge_a[i], edge_b[i])
        sin_theta = normal[0] / np.linalg.norm(normal, axis=0)
        if normal[1] > 0:
            reverse_record.append(1)
            # Negate the first two components of this frame before rotating.
            swapped[i, :, 0] = -swapped[i, :, 0]
            swapped[i, :, 1] = -swapped[i, :, 1]
            theta = np.arcsin(sin_theta)
        else:
            reverse_record.append(0)
            theta = -np.arcsin(sin_theta)

        orientation_gather.append(theta)

        rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
                                    [np.sin(theta), np.cos(theta), 0],
                                    [0, 0, 1]])

        rotated_keypoint_sequence.append(np.dot(rotation_matrix, swapped[i].T).T)

    # Undo the component swap and re-centre on joint 0.
    rotated_keypoint_sequence = np.array(rotated_keypoint_sequence)[..., [0, 2, 1]]
    rotated_keypoint_sequence = rotated_keypoint_sequence - rotated_keypoint_sequence[:, 0:1]
    orientation_gather = np.array(orientation_gather)

    return rotated_keypoint_sequence, orientation_gather, reverse_record
            
            
def calculate_classifier_metrics(predictions, labels):
    """
    Count confusion-matrix entries for binary predictions against soft labels.

    Labels equal to 1, 0.8 or 0.6 count as positive and labels equal to 0 as
    negative; any other label value is ignored. Predictions must be exactly
    0 or 1 to be counted.

    Counting uses ``Tensor.sum(dim=0)`` instead of the builtin ``sum``, which
    avoids a Python-level loop over elements and makes empty inputs return zero
    counts rather than failing.

    :param predictions: tensor of 0/1 predictions
    :param labels: tensor of labels, same shape as ``predictions``
    :return: (TP, TN, FP, FN) as float tensors, reduced over the first dimension
    """
    predicted_pos = predictions == 1
    predicted_neg = predictions == 0
    # The three positive label values are mutually exclusive, so OR-ing the masks
    # gives the same count as adding the three separate counts.
    label_pos = (labels == 1) | (labels == 0.8) | (labels == 0.6)
    label_neg = labels == 0

    TP = (predicted_pos & label_pos).sum(dim=0).float()
    TN = (predicted_neg & label_neg).sum(dim=0).float()
    FP = (predicted_pos & label_neg).sum(dim=0).float()
    FN = (predicted_neg & label_pos).sum(dim=0).float()

    return TP, TN, FP, FN


def spline_fit(points):
    """
    Fit an independent cubic polynomial in time to each coordinate of a 3D path.

    The time parameter is the point index ``0 .. len(points) - 1``.

    :param points: sequence of points; components 0, 1 and 2 of each point are used
    :return: (coefficients, line_error, span_error) where
        coefficients is ``[[a, b, c, d]_x, [..]_y, [..]_z]`` as Python floats,
        highest power first;
        line_error is the sum over points of the distance between each point
        and the fitted curve;
        span_error is the distance between the fitted curve's first and last samples
    """
    t = np.arange(0, len(points))
    x, y, z = (np.array([point[axis] for point in points]) for axis in range(3))

    coeff_x, coeff_y, coeff_z = (np.polyfit(t, series, 3) for series in (x, y, z))
    fit_x, fit_y, fit_z = (np.poly1d(list(coeffs))(t) for coeffs in (coeff_x, coeff_y, coeff_z))

    residual = np.sqrt((x - fit_x)**2 + (y - fit_y)**2 + (z - fit_z)**2)
    line_error = residual.sum()
    span_error = np.sqrt(
        (fit_x[-1] - fit_x[0])**2 + (fit_y[-1] - fit_y[0])**2 + (fit_z[-1] - fit_z[0])**2
    )

    coefficients = [[float(c) for c in coeffs] for coeffs in (coeff_x, coeff_y, coeff_z)]
    return coefficients, line_error, span_error
            