import os
import torch
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
from torch.optim.lr_scheduler import LambdaLR


# Skeleton edges as (joint_a, joint_b) index pairs; the visualisation helpers
# in this file draw one line segment per pair.
humanact12_limbs = [
    (0, 1), (1, 4), (4, 7), (7, 10),
    (0, 2), (2, 5), (5, 8), (8, 11),
    (0, 3), (3, 6), (6, 9), (9, 12), (12, 15),
    (9, 13), (13, 16), (16, 18), (18, 20), (20, 22),
    (9, 14), (14, 17), (17, 19), (19, 21), (21, 23),
]

# The same edges grouped into five chains of consecutive joint indices;
# each chain is drawn as a single polyline by visualize_a2m.
humanact12_kinematic_chain = [
    [0, 1, 4, 7, 10],
    [0, 2, 5, 8, 11],
    [0, 3, 6, 9, 12, 15],
    [9, 13, 16, 18, 20, 22],
    [9, 14, 17, 19, 21, 23],
]

def lr_decay_mine(optimizer, lr_now, gamma):
    """
    Decay the learning rate by a multiplicative factor and apply it.

    :param optimizer: object exposing ``param_groups`` (a list of dicts)
    :param lr_now: current learning rate
    :param gamma: multiplicative decay factor
    :return: the new learning rate, ``lr_now * gamma``
    """
    decayed_lr = lr_now * gamma
    # Every parameter group receives the same decayed value.
    for group in optimizer.param_groups:
        group['lr'] = decayed_lr
    return decayed_lr


def orth_project(cam, pts):
    """
    Project points with a scale-and-translate camera, dropping the third
    coordinate.

    :param cam: b*[s,tx,ty]
    :param pts: b*k*3
    :return: b*k*2 tensor equal to ``s * (pts[..., :2] + [tx, ty])``
    """
    # Insert a keypoint axis so both terms broadcast against pts.
    scale = cam[:, None, 0:1]   # b*1*1
    shift = cam[:, None, 1:]    # b*1*2
    return scale * (pts[:, :, :2] + shift)


def opt_cam(x, x_target):
    """
    Fit, per sample, the scale and 2D translation that best map ``x`` onto
    ``x_target`` in the least-squares sense.

    The result uses the same ``[s, tx, ty]`` layout as ``orth_project``:
    ``s * (x[..., :2] + [tx, ty])`` approximates ``x_target[..., :2]``.

    Visibility: when ``x_target`` has 2 channels every keypoint is used;
    otherwise a keypoint is used when the first channel of ``x_target`` is
    positive. The first two keypoints are always excluded from the fit.

    :param x: N K 3 or  N K 2
    :param x_target: N K 3 or  N K 2
    :return: N*3 tensor ``[s, tx, ty]`` with ``s`` clamped to [0.7, 10]
    """
    if x_target.shape[2] == 2:
        vis = torch.ones_like(x_target[:, :, :1])
    else:
        vis = (x_target[:, :, :1] > 0).float()
    # NOTE(review): keypoints 0 and 1 are unconditionally masked out; the
    # reason is not visible in this file - confirm against the caller.
    vis[:, :2] = 0

    xxt = x_target[:, :, :2]
    xx = x[:, :, :2]

    # Centroids over the visible keypoints only.
    num_vis = torch.sum(vis, dim=1, keepdim=True)
    mu1 = torch.sum(vis * xx, dim=1, keepdim=True) / num_vis
    mu2 = torch.sum(vis * xxt, dim=1, keepdim=True) / num_vis
    xmu = vis * (xx - mu1)
    xtmu = vis * (xxt - mu2)

    # Small ridge term keeps the 2x2 normal matrix invertible. It is created
    # on the same device/dtype as the data so the function also works on CPU
    # (the previous hard-coded .cuda() failed without a GPU).
    eps = 1e-6 * torch.eye(2, dtype=xmu.dtype, device=xmu.device)
    Ainv = torch.inverse(torch.matmul(xmu.transpose(1, 2), xmu) + eps.unsqueeze(0))
    B = torch.matmul(xmu.transpose(1, 2), xtmu)
    tmp_s = torch.matmul(Ainv, B)
    # Isotropic scale: average of the two diagonal entries of the 2x2 solution.
    scale = ((tmp_s[:, 0, 0] + tmp_s[:, 1, 1]) / 2.0).unsqueeze(1)

    scale = torch.clamp(scale, 0.7, 10)
    # Translation is applied before scaling: s * (mu1 + t) == mu2.
    trans = mu2.squeeze(1) / scale - mu1.squeeze(1)
    return torch.cat([scale, trans], dim=1)


def get_dct_matrix(N):
    """
    Build the N x N DCT matrix and its inverse.

    Entry ``[k, i]`` is ``w_k * cos(pi * (i + 1/2) * k / N)`` with
    ``w_0 = sqrt(1/N)`` and ``w_k = sqrt(2/N)`` for ``k > 0``.

    :param N: size of the transform
    :return: ``(dct_m, idct_m)`` where ``idct_m`` is ``np.linalg.inv(dct_m)``
    """
    dct_m = np.eye(N)
    for k in np.arange(N):
        # The normalisation weight depends only on the row index.
        weight = np.sqrt(1 / N) if k == 0 else np.sqrt(2 / N)
        for i in np.arange(N):
            dct_m[k, i] = weight * np.cos(np.pi * (i + 1 / 2) * k / N)
    return dct_m, np.linalg.inv(dct_m)


def get_lr_scheduler(lr_policy, optimizer, max_iter=None):
    """
    Create a learning-rate scheduler from a policy description.

    Only the ``"Poly"`` policy is implemented: the base learning rate of every
    parameter group is multiplied by ``(1 - cur_iter / max_iter) ** power``.

    :param lr_policy: dict with key ``'name'`` and, for "Poly", ``'power'``
    :param optimizer: optimizer whose parameter groups are scheduled
    :param max_iter: total number of iterations; must be positive for "Poly"
    :return: a ``LambdaLR`` scheduler
    :raises NotImplementedError: for any policy name other than "Poly"
    """
    if lr_policy['name'] != "Poly":
        raise NotImplementedError("lr policy not supported")

    assert max_iter > 0

    def poly_factor(cur_iter):
        progress = (cur_iter * 1.0) / max_iter
        return (1 - progress) ** lr_policy['power']

    # LambdaLR takes one multiplier function per parameter group.
    per_group = [poly_factor for _ in optimizer.param_groups]
    return LambdaLR(optimizer, lr_lambda=per_group)


def transform_anchor_to_keypoint(prim_data, duration):
    """
    Expand per-keypoint quadratic coefficients into sampled 3D locations.

    For each sample, each of the 25 keypoints and each axis, the last
    dimension of ``prim_data`` holds consecutive triples ``(a, b, c)``; every
    triple is evaluated as ``a * t**2 + b * t + c`` for
    ``t = 0 .. duration[s] - 1``.

    :param prim_data: array indexed ``[sample][keypoint][axis][coefficient]``
    :param duration: per-sample number of time steps (converted with ``int``)
    :return: ``np.array`` of the per-sample results; each sample is laid out
        as ``(num_triples * duration, 25, 3)``, triples first, then time
    """
    all_locations = []
    for sample_idx in range(prim_data.shape[0]):
        sample = prim_data[sample_idx]
        num_segments = sample.shape[2] // 3
        per_keypoint = []
        for keypoint in range(25):
            coeffs = sample[keypoint]
            track = [
                [
                    coeffs[axis][seg * 3] * t ** 2 + coeffs[axis][seg * 3 + 1] * t + coeffs[axis][seg * 3 + 2]
                    for axis in range(3)
                ]
                for seg in range(num_segments)
                for t in range(int(duration[sample_idx]))
            ]
            per_keypoint.append(track)
        # (keypoint, time, xyz) -> (time, keypoint, xyz)
        all_locations.append(np.transpose(np.array(per_keypoint), (1, 0, 2)))

    return np.array(all_locations)


def transform_prim_to_keypoint(prim_data, duration):
    """
    Expand per-keypoint cubic coefficients into sampled 3D locations.

    For each sample, each of the 25 keypoints and each axis, the last
    dimension of ``prim_data`` holds consecutive quadruples ``(a, b, c, d)``;
    every quadruple is evaluated as ``a * t**3 + b * t**2 + c * t + d`` for
    ``t = 0 .. duration[s] - 1``.

    :param prim_data: array indexed ``[sample][keypoint][axis][coefficient]``
    :param duration: per-sample number of time steps (converted with ``int``)
    :return: ``np.array`` of the per-sample results; each sample is laid out
        as ``(num_quadruples * duration, 25, 3)``, quadruples first, then time
    """
    all_locations = []
    for sample_idx in range(prim_data.shape[0]):
        sample = prim_data[sample_idx]
        num_segments = sample.shape[2] // 4
        per_keypoint = []
        for keypoint in range(25):
            coeffs = sample[keypoint]
            track = [
                [
                    coeffs[axis][seg * 4] * t ** 3 + coeffs[axis][seg * 4 + 1] * t ** 2 + coeffs[axis][seg * 4 + 2] * t + coeffs[axis][seg * 4 + 3]
                    for axis in range(3)
                ]
                for seg in range(num_segments)
                for t in range(int(duration[sample_idx]))
            ]
            per_keypoint.append(track)
        # (keypoint, time, xyz) -> (time, keypoint, xyz)
        all_locations.append(np.transpose(np.array(per_keypoint), (1, 0, 2)))

    return np.array(all_locations)


def visualize_input_keypoint(save_vis_dir, total_locations, adjustment_scale_list):
    """
    Render a keypoint sequence as an animated 3D skeleton and save it as
    ``<save_vis_dir>/adjustment.gif``.

    The skeleton edges come from ``humanact12_limbs``. Data axes are remapped
    for display: plotted ``x = -data[..., 0]``, ``y = data[..., 2]``,
    ``z = -data[..., 1]``.

    :param save_vis_dir: output directory (created if missing)
    :param total_locations: array indexed ``[frame][joint][xyz]``
    :param adjustment_scale_list: per-frame value shown as the legend label
    :return: None

    Encoding goes through ``FFMpegWriter``, so an ``ffmpeg`` executable must
    be available to Matplotlib.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_vis_dir, exist_ok=True)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    num_frames = total_locations.shape[0]

    def init():
        ax.set_xlim(0, 100)
        ax.set_ylim(200, 300)
        ax.set_zlim(-125, 65)
        # ax.set_box_aspect((1, 1, 1))
        return ax

    def plot_frame(frame_data, adjustment_scale):
        x, y, z = -frame_data[..., 0], frame_data[..., 2], -frame_data[..., 1]
        ax.scatter(x, y, z, c='b', s=50, label=adjustment_scale)
        for edge in humanact12_limbs:
            ax.plot([x[edge[0]], x[edge[1]]], [y[edge[0]], y[edge[1]]], [z[edge[0]], z[edge[1]]], c='b')
        ax.legend()

    def update(frame):
        frame_data = total_locations[frame]
        adjustment_scale = adjustment_scale_list[frame]
        # clear() also resets the axis limits, so they are re-applied each frame.
        ax.clear()
        init()
        plot_frame(frame_data, adjustment_scale)

    # create animation
    ani = animation.FuncAnimation(fig, update, frames=num_frames, interval=100)
    # bitrate is a writer argument; it was previously misspelled ("birate")
    # inside the metadata dict, where it had no effect on encoding.
    writer = animation.FFMpegWriter(fps=5, metadata=dict(artist='Me'), bitrate=1800)
    try:
        ani.save(os.path.join(save_vis_dir, "adjustment.gif"), writer=writer)
    finally:
        # Close exactly the figure created here, even if saving fails. The old
        # plt.close(); plt.clf() pair made pyplot create or clear an unrelated
        # implicit figure after the close.
        plt.close(fig)


def visualize_video(save_vis_dir, limbs, total_locations, concept_name, sample_idx, adjustment_scale_list):
    """
    Render a keypoint sequence as an animated 3D skeleton and save it as
    ``<save_vis_dir>/<concept_name>/<sample_idx>/adjustment.gif``.

    Data axes are remapped for display: plotted ``x = -data[..., 0]``,
    ``y = data[..., 2]``, ``z = -data[..., 1]``. All three plot axes are fixed
    to the range [-100, 100].

    :param save_vis_dir: root output directory
    :param limbs: iterable of ``(joint_a, joint_b)`` index pairs to connect
    :param total_locations: array indexed ``[frame][joint][xyz]``
    :param concept_name: sub-directory name under ``save_vis_dir``
    :param sample_idx: sample identifier, used (via ``str``) as a directory name
    :param adjustment_scale_list: per-frame value shown as the legend label
    :return: None

    Encoding goes through ``FFMpegWriter``, so an ``ffmpeg`` executable must
    be available to Matplotlib.
    """
    out_dir = os.path.join(save_vis_dir, concept_name, str(sample_idx))
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(out_dir, exist_ok=True)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    num_frames = total_locations.shape[0]

    def init():
        ax.set_xlim(-100, 100)
        ax.set_ylim(-100, 100)
        ax.set_zlim(-100, 100)
        # ax.set_box_aspect((1, 1, 1))
        return ax

    def plot_frame(frame_data, adjustment_scale):
        x, y, z = -frame_data[..., 0], frame_data[..., 2], -frame_data[..., 1]
        ax.scatter(x, y, z, c='b', s=20, label=adjustment_scale)
        for edge in limbs:
            ax.plot([x[edge[0]], x[edge[1]]], [y[edge[0]], y[edge[1]]], [z[edge[0]], z[edge[1]]], c='b')
        ax.legend()

    def update(frame):
        frame_data = total_locations[frame]
        adjustment_scale = adjustment_scale_list[frame]
        # clear() also resets the axis limits, so they are re-applied each frame.
        ax.clear()
        init()
        plot_frame(frame_data, adjustment_scale)

    # create animation
    ani = animation.FuncAnimation(fig, update, frames=num_frames, interval=50)
    # bitrate is a writer argument; it was previously misspelled ("birate")
    # inside the metadata dict, where it had no effect on encoding.
    writer = animation.FFMpegWriter(fps=15, metadata=dict(artist='Me'), bitrate=1800)
    try:
        ani.save(os.path.join(out_dir, "adjustment.gif"), writer=writer)
    finally:
        # Close exactly the figure created here, even if saving fails. The old
        # plt.close(); plt.clf() pair made pyplot create or clear an unrelated
        # implicit figure after the close.
        plt.close(fig)


def visualize_video2(save_vis_dir, total_locations, concept_name, sample_idx, adjustment_scale_list):
    """
    Render a keypoint sequence as an animated 3D skeleton with automatic axis
    limits and save it as
    ``<save_vis_dir>/<concept_name>/<sample_idx>/adjustment.gif``.

    The skeleton edges come from ``humanact12_limbs``. Unlike the other
    video helpers in this file, the display mapping here is
    ``x = -data[..., 0]``, ``y = data[..., 1]``, ``z = -data[..., 2]`` and no
    fixed axis limits are applied.

    :param save_vis_dir: root output directory
    :param total_locations: array indexed ``[frame][joint][xyz]``
    :param concept_name: sub-directory name under ``save_vis_dir``
    :param sample_idx: sample identifier, used (via ``str``) as a directory name
    :param adjustment_scale_list: per-frame value shown as the legend label
    :return: None

    Encoding goes through ``FFMpegWriter``, so an ``ffmpeg`` executable must
    be available to Matplotlib.
    """
    out_dir = os.path.join(save_vis_dir, concept_name, str(sample_idx))
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(out_dir, exist_ok=True)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.set_box_aspect((1, 1, 1))
    num_frames = total_locations.shape[0]

    def plot_frame(frame_data, adjustment_scale):
        x, y, z = -frame_data[..., 0], frame_data[..., 1], -frame_data[..., 2]
        ax.scatter(x, y, z, c='b', s=50, label=adjustment_scale)
        for edge in humanact12_limbs:
            ax.plot([x[edge[0]], x[edge[1]]], [y[edge[0]], y[edge[1]]], [z[edge[0]], z[edge[1]]], c='b')
        ax.legend()

    def update(frame):
        frame_data = total_locations[frame]
        adjustment_scale = adjustment_scale_list[frame]
        ax.clear()
        plot_frame(frame_data, adjustment_scale)

    # create animation
    ani = animation.FuncAnimation(fig, update, frames=num_frames, interval=50)
    # bitrate is a writer argument; it was previously misspelled ("birate")
    # inside the metadata dict, where it had no effect on encoding.
    writer = animation.FFMpegWriter(fps=15, metadata=dict(artist='Me'), bitrate=1800)
    try:
        ani.save(os.path.join(out_dir, "adjustment.gif"), writer=writer)
    finally:
        # Close exactly the figure created here, even if saving fails. The old
        # plt.close(); plt.clf() pair made pyplot create or clear an unrelated
        # implicit figure after the close.
        plt.close(fig)


def visualize_anchor(save_vis_dir, limbs, total_locations, start_idx=0, start_anchor_dix=(), end_anchor_idx=()):
    """
    Save one 3D skeleton PNG per frame to
    ``<save_vis_dir>/<start_idx + i>.png``.

    Frames are drawn in red when ``start_idx`` is in ``start_anchor_dix``,
    green when it is in ``end_anchor_idx``, and blue otherwise. Data axes are
    remapped for display: plotted ``x = -data[..., 0]``, ``y = data[..., 2]``,
    ``z = -data[..., 1]``.

    :param save_vis_dir: output directory (created if missing)
    :param limbs: iterable of ``(joint_a, joint_b)`` index pairs to connect
    :param total_locations: array indexed ``[frame, joint, xyz]``
    :param start_idx: index of the first frame; offsets the file names
    :param start_anchor_dix: container of indices treated as start anchors
    :param end_anchor_idx: container of indices treated as end anchors
    :return: None
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_vis_dir, exist_ok=True)

    # The colour depends only on start_idx, so it is chosen once.
    # NOTE(review): the membership test uses start_idx, not start_idx + i, so
    # every frame of one call shares a colour - confirm this is intended when
    # total_locations holds more than one frame.
    if start_idx in start_anchor_dix:
        color, label = 'r', 'Start Anchor'
    elif start_idx in end_anchor_idx:
        color, label = 'g', 'End Anchor'
    else:
        color, label = 'b', 'Anchor'

    for i in range(total_locations.shape[0]):
        x, y, z = -total_locations[i, :, 0], total_locations[i, :, 2], -total_locations[i, :, 1]

        fig = plt.figure()
        try:
            ax = fig.add_subplot(111, projection='3d')

            ax.scatter(x, y, z, c=color, s=50, label=label)
            for edge in limbs:
                ax.plot([x[edge[0]], x[edge[1]]], [y[edge[0]], y[edge[1]]], [z[edge[0]], z[edge[1]]], c=color)

            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_zlabel('Z')

            # ax.set_box_aspect((1, 1, 1))
            ax.set_xlim(0, 100)
            ax.set_ylim(200, 300)
            ax.set_zlim(-125, 65)

            fig.savefig(os.path.join(save_vis_dir, str(start_idx+i) + '.png'), dpi=300)
        finally:
            # Close exactly this figure, even if drawing or saving fails. The
            # old plt.close(); plt.clf() pair made pyplot create or clear an
            # unrelated implicit figure after the close.
            plt.close(fig)


def visualize(save_vis_dir, limbs, total_locations, concept_name, sample_idx, start_idx, start_anchor_dix=(), end_anchor_idx=()):
    """
    Save one 3D skeleton PNG per frame to
    ``<save_vis_dir>/<concept_name>/<sample_idx>/<start_idx + i>.png``.

    Frames are drawn in red when ``start_idx`` is in ``start_anchor_dix``,
    green when it is in ``end_anchor_idx``, and blue otherwise. Data axes are
    remapped for display: plotted ``x = -data[..., 0]``, ``y = data[..., 2]``,
    ``z = -data[..., 1]``. All three plot axes are fixed to [-100, 100].

    :param save_vis_dir: root output directory
    :param limbs: iterable of ``(joint_a, joint_b)`` index pairs to connect
    :param total_locations: array indexed ``[frame, joint, xyz]``
    :param concept_name: sub-directory name under ``save_vis_dir``
    :param sample_idx: sample identifier, used (via ``str``) as a directory name
    :param start_idx: index of the first frame; offsets the file names
    :param start_anchor_dix: container of indices treated as start anchors
    :param end_anchor_idx: container of indices treated as end anchors
    :return: None
    """
    out_dir = os.path.join(save_vis_dir, concept_name, str(sample_idx))
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(out_dir, exist_ok=True)

    # The colour depends only on start_idx, so it is chosen once.
    # NOTE(review): the membership test uses start_idx, not start_idx + i, so
    # every frame of one call shares a colour - confirm this is intended when
    # total_locations holds more than one frame.
    if start_idx in start_anchor_dix:
        color, label = 'r', 'Start Anchor'
    elif start_idx in end_anchor_idx:
        color, label = 'g', 'End Anchor'
    else:
        color, label = 'b', 'Anchor'

    for i in range(total_locations.shape[0]):
        x, y, z = -total_locations[i, :, 0], total_locations[i, :, 2], -total_locations[i, :, 1]

        fig = plt.figure()
        try:
            ax = fig.add_subplot(111, projection='3d')

            ax.scatter(x, y, z, c=color, s=20, label=label)
            for edge in limbs:
                ax.plot([x[edge[0]], x[edge[1]]], [y[edge[0]], y[edge[1]]], [z[edge[0]], z[edge[1]]], c=color)

            ax.set_xlabel('X')
            ax.set_ylabel('Y')
            ax.set_zlabel('Z')

            # ax.set_box_aspect((1, 1, 1))

            ax.set_xlim(-100, 100)
            ax.set_ylim(-100, 100)
            ax.set_zlim(-100, 100)

            fig.savefig(os.path.join(out_dir, str(start_idx+i) + '.png'), dpi=300)
        finally:
            # Close exactly this figure, even if drawing or saving fails. The
            # old plt.close(); plt.clf() pair made pyplot create or clear an
            # unrelated implicit figure after the close.
            plt.close(fig)
        
        
def visualize_a2m(save_vis_dir, pose, sample_idx):
    """
    Draw a single pose as five coloured polylines (one per chain in
    ``humanact12_kinematic_chain``) and save it to
    ``<save_vis_dir>/<sample_idx>.png``.

    The pose is plotted as ``(pose[:, 0], pose[:, 1], -pose[:, 2])`` inside a
    [-1, 1] cube, viewed from ``elev=110, azim=90``, with tick labels hidden.

    :param save_vis_dir: output directory (created if missing)
    :param pose: array indexable as ``pose[list_of_joint_indices, axis]``
        (joint indices up to 23 are used)
    :param sample_idx: sample identifier, used (via ``str``) as the file name
    :return: None
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_vis_dir, exist_ok=True)

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, projection='3d')
        # ax.set_xlabel('x')
        # ax.set_ylabel('y')
        # ax.set_zlabel('z')
        ax.set_ylim(-1, 1)
        ax.set_xlim(-1, 1)
        ax.set_zlim(-1, 1)

        colors = ['red', 'magenta', 'black', 'green', 'blue']
        # The axes are freshly created, so there is nothing to clear. The old
        # `ax.lines = []` / `ax.collections = []` assignments raise
        # AttributeError on Matplotlib >= 3.5, where these are read-only.
        ax.view_init(elev=110, azim=90)
        for chain, color in zip(humanact12_kinematic_chain, colors):
            ax.plot3D(pose[chain, 0], pose[chain, 1], -pose[chain, 2], linewidth=4.0, color=color)
        # plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_zticklabels([])

        fig.savefig(os.path.join(save_vis_dir, str(sample_idx)+".png"), dpi=300)
    finally:
        # Close exactly this figure, even if drawing or saving fails. The old
        # plt.close(); plt.clf() pair made pyplot create or clear an unrelated
        # implicit figure after the close.
        plt.close(fig)


def calculate_MPJPE(predictions, ground_truth):
    """
    Mean per-joint position error between two pose sequences.

    The Euclidean distance is taken along the last axis, averaged over the
    first axis (frames), and those per-joint averages are averaged again.

    :param predictions: array with exactly three dimensions
    :param ground_truth: array compatible with ``predictions`` under
        subtraction
    :return: scalar error
    """
    # The unpacking doubles as a check that predictions has exactly 3 dims.
    _frames, _joints, _dims = predictions.shape

    distances = np.linalg.norm(predictions - ground_truth, axis=2)
    per_joint = np.mean(distances, axis=0)
    return np.mean(per_joint)


def collate_helper(batch):
    """
    Collate a list of samples into a dict of stacked tensors.

    Each sample is indexed as ``sample[0]`` for the data tensor and
    ``sample[1][0]`` for the label tensor.

    :param batch: sequence of samples
    :return: ``{"x": stacked data, "y": stacked labels}``, stacked on dim 0
    """
    inputs = [sample[0] for sample in batch]
    targets = [sample[1][0] for sample in batch]

    return {
        "x": torch.stack(inputs, dim=0),
        "y": torch.stack(targets, dim=0),
    }
