import os, sys
import warnings

import torch.distributed

warnings.filterwarnings("ignore")

import argparse
import math, random
import torch
import time

from config import PARAMS_CONFIG
from data import get_train_val_test_data
from models_u import TransformerSeq
from trainer_eval import get_expert_choices
import datetime
from utils import (
    get_params,
    set_up_env,
    get_optimizer_and_scheduler,
    load_checkpoint,
    save_checkpoint,
    create_exp_dir,
    freeze_gate_weight,
    Logger,
    set_freq_optimal_search,
)

import matplotlib.pyplot as plt
import seaborn as sns



def launch(
    env_params,
    model_params,
    adapt_span_params,
    optim_params,
    data_params,
    trainer_params,
):
    """Build the model, load its checkpoint and plot gate/centroid alignment.

    Steps, in order: set up the environment, load the data splits, construct a
    ``TransformerSeq`` wrapped in ``DistributedDataParallel`` or
    ``DataParallel``, build the optimizer/scheduler, and call
    ``load_checkpoint`` on ``trainer_params["checkpoint_path"]``.

    When ``trainer_params["full_eval_mode"]`` is truthy, ``get_expert_choices``
    is run on the test data and, for every layer, the returned inputs are
    clustered with ``kmeans``; a heatmap comparing the layer's gate weight rows
    against the (greedily aligned) centroids is written next to the checkpoint
    as ``<checkpoint stem>_<distance>[_normalize]_<layer_id>.png``.  When the
    flag is falsy nothing further happens.

    Args:
        env_params: dict read for ``"device"``, ``"distributed"``, ``"rank"``
            and ``"local_rank"``; also passed to ``set_up_env`` and the data
            loader.
        model_params: keyword arguments for ``TransformerSeq``; additionally
            read for ``"smoe_dropout"``, ``"block_size"`` and ``"hidden_size"``.
        adapt_span_params: forwarded to ``TransformerSeq``.
        optim_params: forwarded to ``get_optimizer_and_scheduler``.
        data_params: read for ``"vocab_size"``; forwarded to the data loader.
        trainer_params: read for ``"batch_size"``, ``"checkpoint_path"`` and
            ``"full_eval_mode"``.

    Returns:
        None.
    """
    # ENVIRONMENT (device, distributed, etc.)
    set_up_env(env_params)
    device = env_params["device"]
    distributed = env_params["distributed"]

    # Only one process prints the configuration in a distributed run.
    if not distributed or env_params["rank"] == 0:
        print("data_params:\t", data_params)
        print("model_params:\t", model_params)
        print("optim_params:\t", optim_params)
        print("trainer_params:\t", trainer_params)
        print("adapt_span_params:\t", adapt_span_params)

    # DATA
    train_data, val_data, test_data = get_train_val_test_data(
        data_params=data_params,
        env_params=env_params,
        batch_size=trainer_params["batch_size"],
        device=device,
    )

    # MODEL
    model = TransformerSeq(
        vocab_size=data_params["vocab_size"],
        **model_params,
        adapt_span_params=adapt_span_params,
        # gate_hook=gate_hook,
    )
    print(model)
    if distributed:
        local_rank = env_params["local_rank"]
        model = model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
        )
    else:
        model = torch.nn.DataParallel(model)
        model = model.to(device)

    # OPTIMIZER AND SCHEDULER
    optimizer, scheduler = get_optimizer_and_scheduler(
        model=model, optim_params=optim_params
    )

    # create logger
    logger = Logger()
    # fold_name = trainer_params["checkpoint_path"].split("/")[-1].split(".")[0]
    # folder_path = "/".join(trainer_params["checkpoint_path"].split("/")[:-1])
    # logging = create_exp_dir(f"{folder_path}/experiments/{fold_name}")
    # # log paramters
    # logging(f"Training Parameters:\n {trainer_params}")
    # logging(f"Models Parameters:\n {model_params}")
    # # logging time
    # current_time = datetime.datetime.now()
    # logging(str(current_time))
    # # log model
    # logging(str(model))
    # logging(f"Total of Parameters: {sum(p.numel() for p in model.parameters())}")
    # logging(
    #     f"Total of Trainable Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
    # )
    # resume training from last checkpoint if exists
    # (the value returned by load_checkpoint is not needed in this script)
    load_checkpoint(
        trainer_params["checkpoint_path"],
        model,
        optimizer,
        scheduler,
        logger,
        distributed,
    )
    # fix gate
    if model_params["smoe_dropout"]:
        freeze_gate_weight(model)

    # Everything below is the evaluation/analysis path.
    if not trainer_params["full_eval_mode"]:
        return

    # evaluate the model on test data
    with torch.no_grad():
        Y, auxiliaries = get_expert_choices(
            model,
            test_data,
            model_params["block_size"],
            model_params["hidden_size"],
        )
        # if distributed:
        #     # collect results into rank0
        #     stats = torch.tensor([loss_val, loss_test]).to(device)
        #     torch.distributed.reduce(stats, 0)
        #     if env_params["rank"] == 0:
        #         loss_val = stats[0] / env_params["world_size"]
        #         loss_test = stats[1] / env_params["world_size"]
        #     else:
        #         return

        # # print('Test BPC: {:.4f}'.format(loss_test / math.log(2)))
        # if ("enwik8" in data_params["data_path"]) or (
        #     "text8" in data_params["data_path"]
        # ):
        #     logging("Val: {:.3f} BPC".format(loss_val / math.log(2)))
        #     logging("Test: {:.3f} BPC".format(loss_test / math.log(2)))
        # else:
        #     logging("Val: {:.3f} PPL".format(math.exp(loss_val)))
        #     logging("Test: {:.3f} PPL".format(math.exp(loss_test)))

    # TODO(NGOC): customize the k-means parameters here.
    # distance can be "euclidean" or "cosine"
    # note that cosine distance requires normalization
    normalize = True
    distance = "cosine"

    # Derive the figure name from the checkpoint path without its extension.
    # A plain str.replace(".pt", ...) would also rewrite ".pt" inside directory
    # names or longer extensions (e.g. ".pth"), and would leave the path
    # unchanged -- i.e. pointing at the checkpoint itself -- when ".pt" is absent.
    stem, _ = os.path.splitext(trainer_params["checkpoint_path"])
    suffix = f"_{distance}" + ("_normalize" if normalize else "")

    for layer_id, (layer, (all_scores, input_copy)) in enumerate(
        zip(model.module.layers, auxiliaries)
    ):
        weight = layer.smoe.gate.gate.weight

        # One centroid per gate weight row.
        centroids = kmeans(
            input_copy,
            weight.shape[0],
            normalize=normalize,
            max_iter=100,
            distance=distance,
        )
        if normalize:
            weight = weight / torch.linalg.norm(weight, dim=-1, keepdim=True)

        # align: reorder centroid columns so each gate row faces its best match
        if distance == "cosine":
            # dot product of unit vectors = cosine similarity (larger is closer)
            distances = torch.matmul(weight, centroids.T)
            distances = greedy_permute(distances, is_max=True)
        else:
            distances = torch.cdist(weight, centroids)
            distances = greedy_permute(distances, is_max=False)

        fig = plt.figure(figsize=(20, 15))
        try:
            sns.heatmap(distances.detach().cpu().numpy(), annot=True)
            plt.gca().set(xlabel="centroids", ylabel="features")
            plt.savefig(f"{stem}{suffix}_{layer_id}.png")
        finally:
            # pyplot keeps every figure alive until it is closed explicitly;
            # without this, one large figure per layer accumulates in memory.
            plt.close(fig)


def _normalize_rows(x):
    """Scale each row of ``x`` to unit L2 norm; all-zero rows stay zero."""
    norms = torch.linalg.norm(x, dim=-1, keepdim=True)
    # Clamp to the smallest positive normal value so a zero row yields zeros
    # instead of NaN; rows with a regular norm are divided exactly as before.
    return x / torch.clamp(norms, min=torch.finfo(x.dtype).tiny)


def kmeans(inputs, num_clusters, normalize=False, max_iter=100, distance="euclidean"):
    """Cluster the rows of ``inputs`` with k-means and return the centroids.

    Centroids are initialised with the first ``num_clusters`` rows of
    ``inputs`` and refined by alternating nearest-centroid assignment and
    per-cluster averaging.  Iteration stops once the centroids no longer move
    (``torch.allclose`` with ``atol=1e-5``) or after ``max_iter`` rounds, in
    which case ``"Iteration limit reached!"`` is printed.

    A cluster that receives no points in a round keeps its previous centroid.

    Args:
        inputs: floating-point tensor whose rows are the points to cluster
            (used as a 2-D matrix: ``centroids.T`` and ``torch.cdist``).
        num_clusters: number of centroids to fit.  If ``inputs`` has fewer
            rows than this, only that many centroids are returned.
        normalize: if true, rows of ``inputs`` and the centroids are scaled to
            unit L2 norm.
        max_iter: maximum number of refinement rounds.
        distance: ``"euclidean"`` or ``"cosine"``; cosine distance is computed
            as ``1 - dot product`` and therefore requires ``normalize=True``.

    Returns:
        Tensor holding one centroid per row.

    Raises:
        AssertionError: if ``distance`` is not a supported metric, or if
            ``distance == "cosine"`` while ``normalize`` is false.
    """
    assert distance in ["euclidean", "cosine"], "Invalid distance metric!"
    assert distance == "euclidean" or normalize, "Normalization is required for cosine distance metric!"

    if normalize:
        inputs = _normalize_rows(inputs)
    centroids = inputs[:num_clusters, ...].clone()

    for _ in range(max_iter):
        if distance == "cosine":
            distances = 1 - torch.matmul(inputs, centroids.T)
        else:
            distances = torch.cdist(inputs, centroids)

        # Start from the previous centroids: a cluster with no assigned points
        # must keep its old position rather than whatever uninitialised memory
        # torch.empty_like would hand back.
        new_centroids = centroids.clone()
        cluster_ids = torch.argmin(distances, dim=-1)
        for cluster_id in range(num_clusters):
            cluster_mask = cluster_ids == cluster_id
            if cluster_mask.any():
                new_centroids[cluster_id] = inputs[cluster_mask].mean(dim=0)

        if normalize:
            new_centroids = _normalize_rows(new_centroids)

        # early termination if centroids do not change
        if torch.allclose(centroids, new_centroids, atol=1e-5):
            break

        centroids = new_centroids

    else:
        print("Iteration limit reached!")

    return centroids


def greedy_permute(distances, is_max=True):
    """Reorder the columns of ``distances`` by greedy row-by-row matching.

    For ``i`` in ``range(distances.shape[1])``, row ``i`` picks its best column
    among those not yet taken -- the largest entry when ``is_max`` is true, the
    smallest otherwise -- and that column becomes column ``i`` of the result,
    so the chosen values end up on the diagonal.  ``distances`` itself is not
    modified.

    Args:
        distances: 2-D tensor of scores; needs at least as many rows as columns.
        is_max: pick the maximum (True) or the minimum (False) per row.

    Returns:
        A tensor of the same shape with the columns permuted.
    """
    n_cols = distances.shape[1]
    select = torch.argmax if is_max else torch.argmin
    # Value written over a taken column so that later rows never choose it.
    blocked = -float("inf") if is_max else float("inf")

    scratch = distances.clone()
    order = torch.empty(n_cols, dtype=torch.long, device=distances.device)
    for row in range(n_cols):
        col = select(scratch[row])
        order[row] = col
        scratch[:, col] = blocked

    # Gather the original (unmasked) columns in the chosen order.
    return distances[:, order]

if __name__ == "__main__":
    # Resolve the parameter groups described by PARAMS_CONFIG and pass them to
    # launch() as keyword arguments.
    params = get_params(params_config=PARAMS_CONFIG)
    launch(**params)
