import os
import argparse
import random
import torch
import torch.nn as nn
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.manifold import TSNE
from pyclustering.cluster.kmeans import kmeans_visualizer
from pyclustering.cluster.kmeans import kmeans as kmeans_pyclustering
from pyclustering.cluster.center_initializer import kmeans_plusplus_initializer
from pyclustering.utils.metric import distance_metric, type_metric

from tensorboardX import SummaryWriter
from utils.util import visualize, visualize_a2m, visualize_video, get_lr_scheduler, collate_helper
from matplotlib import pyplot as plt
from models.model import AnchorClassifier
from models.stgcn import STGCN


# Command-line options. They are parsed at import time and ``args`` is read as a
# module global by the functions below.
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default='../data', help='data directory')
# main() only knows how to resolve action names for these two datasets.
parser.add_argument('--dataset_name', type=str, default='humanact12', choices=['humanact12', 'uestc'])
parser.add_argument('--num_joints', type=int, default=24, help='number of joints')
parser.add_argument('-n', '--n_clusters', type=int, default=8, help='number of clusters')
parser.add_argument('--vis_dir', type=str, default='vis_cluster_results', help='visualization directory')
# Required: main() iterates over it. For humanact12 each entry is an integer class
# id (key of action_name_match); for uestc it is the action name itself.
parser.add_argument('-c_list', '--concept_list', type=str, nargs='+', required=True, help='action list')
parser.add_argument('-at', '--anchor_type', type=int, default=0, help='anchor type, 0: anchor without start and end, 1: anchor with start and end, 2: anchor with start, 3: anchor with end')
parser.add_argument('--seed', type=int, default=10, help='random')
parser.add_argument('--tsne', action='store_true', help='t-SNE')
parser.add_argument('--adjustment', action='store_true', help='adjustment')
parser.add_argument('--vis_all_samples', action='store_true', help='visualize all samples of each cluster')
parser.add_argument('--classifier', action='store_true', help='classifier')
parser.add_argument('--norm_orientation', action='store_true', help='whether to norm orientation')
# Only these two metrics are implemented by the clustering step.
parser.add_argument('--distance_type', default='euclidean', choices=['euclidean', 'mpjpe'], help='distance for kmeans clustering')
parser.add_argument('--random_state', type=int, default=0, help='random state')

# classifier
parser.add_argument('--lr', type=float, default=0.00001, help='learning rate')
parser.add_argument('--epoch', type=int, default=5000, help='number of epochs')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
args = parser.parse_args()

# HumanAct12 class index -> action name. The name is also used as the per-action
# sub-directory under the data and visualisation directories.
action_name_match = dict(enumerate([
    "warm_up",
    "walk",
    "run",
    "jump",
    "drink",
    "lift_dumbbell",
    "sit",
    "eat",
    "turn_steer_wheel",
    "phone",
    "boxing",
    "throw",
]))

# Learning-rate policy settings, attached to the parsed CLI namespace so that
# consumers of ``args`` can read them.
policy = dict(name="Poly", power=0.95)
args.policy = policy

# Skeleton edges as (joint_index, joint_index) pairs for each dataset. Each list is
# generated by walking chains of connected joint indices and pairing consecutive
# joints, in the order given below.
humanact12_limbs = [
    edge
    for chain in (
        (0, 1, 4, 7, 10),
        (0, 2, 5, 8, 11),
        (0, 3, 6, 9, 12, 15),
        (9, 13, 16, 18, 20, 22),
        (9, 14, 17, 19, 21, 23),
    )
    for edge in zip(chain, chain[1:])
]
uestc_limbs = [
    edge
    for chain in (
        (0, 1),
        (0, 9, 10, 11, 16),
        (0, 12, 13, 14, 15),
        (1, 2, 3, 4),
        (1, 5, 6, 7),
        (1, 8, 17),
    )
    for edge in zip(chain, chain[1:])
]


def mpjpe_distance(x, y):
    """Return the mean per-joint position error (MPJPE) between two poses.

    Each pose may be given flattened as ``(num_joints * 3,)`` or as a
    ``(num_joints, 3)`` array. The two arguments are normalised independently,
    so the layouts may be mixed. Array-likes such as lists are accepted.

    Args:
        x: first pose.
        y: second pose, with the same number of joints as ``x``.

    Returns:
        The mean over joints of the Euclidean distance between corresponding joints.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # Check each argument's rank separately. Previously only x was inspected, so a
    # 2-D x combined with a flattened y failed to broadcast.
    if x.ndim == 1:
        x = x.reshape(-1, 3)
    if y.ndim == 1:
        y = y.reshape(-1, 3)
    per_joint_error = np.sqrt(np.sum((x - y) ** 2, axis=1))
    return np.mean(per_joint_error)


def _resolve_action_name(action):
    """Map a CLI ``action`` identifier to its data-directory name for ``args.dataset_name``.

    For humanact12 the identifier is an integer key of ``action_name_match``;
    for uestc it is used verbatim.

    Raises:
        ValueError: if ``args.dataset_name`` is neither 'humanact12' nor 'uestc'.
            Previously this fell through and failed later with a NameError.
    """
    if args.dataset_name == 'humanact12':
        return action_name_match[int(action)]
    if args.dataset_name == 'uestc':
        return action
    raise ValueError("unsupported dataset_name: " + str(args.dataset_name))


def _kmeans_mpjpe(data):
    """Cluster the rows of ``data`` with K-Means under the MPJPE distance.

    sklearn's ``KMeans`` only supports Euclidean distance; its old
    ``precompute_distances`` flag never accepted a distance matrix. This helper
    therefore uses pyclustering's K-Means with ``mpjpe_distance`` as a
    user-defined metric, seeded with k-means++ centres.

    Args:
        data: 2-D array of flattened poses, one sample per row.

    Returns:
        ``(labels, total_wce)``: an integer cluster id per row, and pyclustering's
        total within-cluster error, which is reported in place of sklearn's inertia.
    """
    initial_centers = kmeans_plusplus_initializer(
        data, args.n_clusters, random_state=args.random_state).initialize()
    metric = distance_metric(type_metric.USER_DEFINED, func=mpjpe_distance)
    # Use the pure-Python implementation so the Python metric callback is used directly.
    instance = kmeans_pyclustering(data, initial_centers, ccore=False, metric=metric)
    instance.process()
    labels = np.full(data.shape[0], -1, dtype=np.int64)
    for cluster_id, members in enumerate(instance.get_clusters()):
        labels[members] = cluster_id
    return labels, instance.get_total_wce()


def _cluster_labels(action, anchor_data_reshape, out_dir):
    """Return cluster labels for ``anchor_data_reshape``, reusing a cached result if present.

    Labels are cached in ``out_dir`` as ``<distance_type>_labels.npy``. When the
    clustering is (re)computed, the silhouette score (always Euclidean) and the
    clustering objective are printed.

    Raises:
        ValueError: if ``args.distance_type`` is not 'euclidean' or 'mpjpe'.
    """
    labels_path = os.path.join(out_dir, args.distance_type + "_labels.npy")
    if os.path.exists(labels_path):
        print("labels loaded...")
        return np.load(labels_path)

    print("labels do not exist, reclustering...")
    if args.distance_type == 'euclidean':
        kmeans = KMeans(n_clusters=args.n_clusters, random_state=args.random_state).fit(anchor_data_reshape)
        labels, inertia = kmeans.labels_, kmeans.inertia_
    elif args.distance_type == 'mpjpe':
        labels, inertia = _kmeans_mpjpe(anchor_data_reshape)
    else:
        raise ValueError("unsupported distance_type: " + str(args.distance_type))
    np.save(labels_path, labels)
    score = silhouette_score(anchor_data_reshape, labels, metric='euclidean')
    print(action, args.n_clusters, "silhouette score:", score)
    print(action, args.n_clusters, "inertia:", inertia)
    return labels


def _train_classifier(anchor_data, labels, out_dir):
    """Train an STGCN to predict each sample's cluster label and report test accuracy.

    Each cluster's indices are shuffled (seeded by ``args.seed``) and split 80/20
    into train/test. The chosen indices are saved to ``out_dir``. The model state
    with the best test accuracy is saved under a per-hyperparameter sub-directory
    of ``out_dir``.
    """
    # Per-cluster 80/20 split, so every cluster is represented in both sets.
    random.seed(args.seed)
    train_parts, test_parts = [], []
    for cluster_id in range(args.n_clusters):
        idx = np.where(labels == cluster_id)[0]
        random.shuffle(idx)
        n_train = int(len(idx) * 0.8)
        train_parts.append(idx[:n_train])
        test_parts.append(idx[n_train:])
    train_idx = np.concatenate(train_parts, axis=0)
    test_idx = np.concatenate(test_parts, axis=0)
    train_data = anchor_data[train_idx]
    test_data = anchor_data[test_idx]
    # Column vectors (n, 1) of cluster ids, the same layout the model received before.
    train_label = labels[train_idx].reshape(-1, 1)
    test_label = labels[test_idx].reshape(-1, 1)

    print("==> classifier train_data:", train_data.shape, "test_data:", test_data.shape)

    np.save(os.path.join(out_dir, "train_idx.npy"), train_idx)
    np.save(os.path.join(out_dir, "test_idx.npy"), test_idx)

    # Move to GPU and add a trailing singleton axis to the pose tensors.
    train_data = torch.from_numpy(train_data).float().cuda().unsqueeze(-1)
    train_label = torch.from_numpy(train_label).long().cuda()
    test_data = torch.from_numpy(test_data).float().cuda().unsqueeze(-1)
    test_label = torch.from_numpy(test_label).long().cuda()
    print("train_data:", train_data.shape, "test_data:", test_data.shape, "train_idx:", len(train_idx), "test_idx:", len(test_idx))
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_helper)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_helper)
    model = STGCN(in_channels=3,
                  num_class=args.n_clusters,
                  graph_args={"layout": args.dataset_name, "strategy": "spatial"},
                  edge_importance_weighting=True).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=800, gamma=0.2)

    run_dir = os.path.join(out_dir, "batch_size_" + str(args.batch_size) + "_epoch_" + str(args.epoch) + "_lr_" + str(args.lr))
    os.makedirs(run_dir, exist_ok=True)
    checkpoint_path = os.path.join(run_dir, "classifier.pth")

    test_accuracy_record = []
    max_accuracy = 0
    for epoch in range(args.epoch):
        model.train()
        train_dict_loss = {loss: 0 for loss in model.losses}
        for batch in train_loader:
            batch = {key: val.cuda() for key, val in batch.items()}
            optimizer.zero_grad()
            batch = model(batch)
            train_mixed_loss, train_losses = model.compute_loss(batch)
            for key in train_dict_loss:
                train_dict_loss[key] += train_losses[key]
            train_mixed_loss.backward()
            optimizer.step()
            # NOTE(review): StepLR is stepped once per batch, so step_size=800 counts
            # optimizer steps rather than epochs. Confirm this is intended.
            scheduler.step()

        model.eval()
        with torch.no_grad():
            test_dict_loss = {loss: 0 for loss in model.losses}
            for batch in test_loader:
                batch = {key: val.cuda() for key, val in batch.items()}
                batch = model(batch)
                _, test_losses = model.compute_loss(batch)
                for key in test_dict_loss:
                    test_dict_loss[key] += test_losses[key]
        # Average the per-batch sums.
        for key in train_dict_loss:
            train_dict_loss[key] /= len(train_loader)
        for key in test_dict_loss:
            test_dict_loss[key] /= len(test_loader)
        if epoch % 5 == 0:
            print("Epoch:", epoch, "train_loss:", train_dict_loss, "test_loss:", test_dict_loss)
        if test_dict_loss["accuracy"] > max_accuracy:
            max_accuracy = test_dict_loss["accuracy"]
            torch.save(model.state_dict(), checkpoint_path)
        test_accuracy_record.append(test_dict_loss["accuracy"])

    if not test_accuracy_record:
        # args.epoch == 0: nothing to summarise (max() on an empty list would raise).
        return
    print("test_max_accuracy: {:.4f}, epoch: {:d}, Last 10 average accuracy: {:.4f}".format(max(test_accuracy_record), np.argmax(test_accuracy_record), np.mean(test_accuracy_record[-10:])))
    # Best 10-epoch moving average. There are no full windows when fewer than 10 epochs ran.
    window = 10
    test_avg_accuracy = [np.mean(test_accuracy_record[i:i + window])
                         for i in range(len(test_accuracy_record) - window + 1)]
    if test_avg_accuracy:
        print("best_test_avg_accuracy: {:.4f}, epoch: {:d}".format(max(test_avg_accuracy), np.argmax(test_avg_accuracy)))


def _plot_tsne(anchor_data_reshape, labels, out_dir):
    """Save a 2-D t-SNE scatter plot of the samples, coloured by cluster label.

    The embedding is cached in ``out_dir`` as ``<distance_type>_tsne.npy``; the
    plot is written to ``out_dir/tsne.png``.
    """
    tsne_path = os.path.join(out_dir, args.distance_type + "_tsne.npy")
    if os.path.exists(tsne_path):
        anchor_data_tsne = np.load(tsne_path)
    else:
        tsne = TSNE(n_components=2, random_state=0)
        anchor_data_tsne = tsne.fit_transform(anchor_data_reshape)
        np.save(tsne_path, anchor_data_tsne)
    print("t-SNE done")

    plt.figure(figsize=(6, 6))
    plt.xticks([])
    plt.yticks([])
    plt.scatter(anchor_data_tsne[:, 0], anchor_data_tsne[:, 1], c=labels, cmap='rainbow', s=10)
    plt.savefig(os.path.join(out_dir, "tsne.png"))
    # close() alone releases the figure. A clf() afterwards would create and leak
    # a fresh empty figure.
    plt.close()
    print("t-SNE visualization done")


def _visualize_cluster_samples(anchor_data, labels):
    """Render up to 50 samples of each cluster with ``visualize_a2m``.

    Mutates ``anchor_data`` in place. For each visualised sample, joint 0's
    coordinates 0 and 2 are subtracted from every joint, then the whole sample
    is divided by 100.
    """
    # NOTE(review): this path does not include the action name, so with several
    # actions in --concept_list later actions write into the same cluster folders.
    # Confirm this is intended.
    out_root = os.path.join(args.vis_dir + "_" + str(args.n_clusters), "anchor_type_" + str(args.anchor_type), "init")
    for cluster_id in range(args.n_clusters):
        for j in np.where(labels == cluster_id)[0][:50]:
            anchor_data[j, :, [0, 2]] -= anchor_data[j, 0:1, [0, 2]]
            anchor_data[j] /= 100
            visualize_a2m(os.path.join(out_root, str(cluster_id)), anchor_data[j], j)
    print("cluster samples visualization done")


def main():
    """Cluster the anchor poses of every action in ``args.concept_list``.

    Also runs the optional classifier, t-SNE and sample-visualisation steps
    selected by the CLI flags. Outputs go to
    ``<vis_dir>_<n_clusters>/<action_name>/``.
    """
    for action in args.concept_list:
        action_name = _resolve_action_name(action)
        print("==> cluster action", action_name, "with", args.n_clusters, "clusters")
        out_dir = os.path.join(args.vis_dir + "_" + str(args.n_clusters), action_name)
        os.makedirs(out_dir, exist_ok=True)

        # Load anchor data and flatten each sample to one row.
        anchor_data = np.load(os.path.join(args.data_dir, args.dataset_name, action_name, 'rotated_anchor_data_random.npy'))
        anchor_data_reshape = anchor_data.reshape(anchor_data.shape[0], -1)

        labels = _cluster_labels(action, anchor_data_reshape, out_dir)
        print("labels:", labels.shape)

        if args.classifier:
            _train_classifier(anchor_data, labels, out_dir)
        if args.tsne:
            _plot_tsne(anchor_data_reshape, labels, out_dir)
        # Runs last because it modifies anchor_data in place.
        if args.vis_all_samples:
            _visualize_cluster_samples(anchor_data, labels)


# Run the clustering pipeline only when executed as a script; importing the module
# does not call main() (although the CLI arguments are still parsed at import).
if __name__ == '__main__':
    main()
