import argparse
import torch
import json
import glob
import time
import numpy as np
import os
import matplotlib.pyplot as plt
import random 

# Filename suffix (before ".json") of the per-checkpoint files for each supported metric.
_METRIC_FILE_SUFFIXES = {"attn_svd": "_attn_svd", "loss": "_loss"}


def _list_metric_checkpoints(metric, metric_dir):
    """Return `metric_dir`'s per-checkpoint files for `metric`, ordered by training step.

    File names are expected to look like ``...checkpoint-<step><suffix>.json``; the
    integer ``<step>`` is used as the sort key.

    Raises:
        ValueError: if `metric` is not one of the supported metrics.
    """
    suffix = _METRIC_FILE_SUFFIXES.get(metric)
    if suffix is None:
        raise ValueError(f"Incorrect value of metric {metric} passed!")
    paths = glob.glob(os.path.join(metric_dir, f"*{suffix}.json"))
    return sorted(paths, key=lambda p: int(p.split("checkpoint-")[-1].split(suffix)[0]))


def _checkpoint_values(data, ids, metric, k):
    """Turn one checkpoint's ``{id: value}`` JSON payload into a tensor aligned with `ids`.

    - ``attn_svd`` with `k`: sum of the `k` largest values of each id's list, giving
      shape ``(len(ids),)``. For ``k == 1`` this is simply the largest value.
    - ``attn_svd`` without `k`: each id's full list, right-padded with zeros to the
      longest list in this checkpoint, giving shape ``(len(ids), longest)``.
    - ``loss``: each id's stored value, converted as-is by ``torch.tensor``.

    Raises:
        KeyError: if an id from `ids` is missing in `data`.
        ValueError: for an unsupported `metric`.
    """
    if metric == "attn_svd":
        if k is not None:
            return torch.tensor([sum(sorted(data[i], reverse=True)[:k]) for i in ids])
        rows = [list(data[i]) for i in ids]
        longest = max((len(row) for row in rows), default=0)
        return torch.tensor([row + [0] * (longest - len(row)) for row in rows])
    if metric == "loss":
        return torch.tensor([data[i] for i in ids])
    raise ValueError(f"Incorrect value of metric {metric} passed!")


def main(args):
    """Score how much each datapoint's metric fluctuates across training checkpoints.

    Steps:
      1. Load every ``*_<metric>.json`` file in ``args.metric_dir`` in training-step order.
         The id order of the first successfully loaded file fixes the row order for all
         later checkpoints. Files that cannot be read, or that lack one of those ids,
         are reported and skipped.
      2. Stack the per-checkpoint values into ``metrics`` with checkpoints on the LAST
         axis. The shape is ``(num_ids, num_checkpoints)`` for scalar values, or
         ``(num_ids, padded_len, num_checkpoints)`` for ``attn_svd`` with ``args.k=None``.
      3. Score every id with ``save_attn_svd_variance`` and write ``{id: score}`` to
         ``<args.stability_save_dir_path>/attn_svd_files/attn_svd_trajectory_variance_stability.json``.

    Args:
        args: namespace with ``metric`` ("attn_svd" or "loss"), ``metric_dir``, ``k``
            (top-k for attn_svd, or None to keep the full value lists) and
            ``stability_save_dir_path``.

    Raises:
        ValueError: for an unsupported ``args.metric``.
        RuntimeError: if no checkpoint file could be loaded.
    """
    metric_checkpoints = _list_metric_checkpoints(args.metric, args.metric_dir)
    k = args.k
    if args.metric == "attn_svd" and k is not None:
        print(f"sampling top-{k} values from the attn_svd checkpoints")

    metrics = []
    global_ids = None  # fixed by the first checkpoint that loads successfully
    for ckpt in metric_checkpoints:
        print(f"*** {ckpt} ** Loading metrics...")
        try:
            with open(ckpt, encoding="utf-8") as f:
                data = json.load(f)
            ids = list(data.keys()) if global_ids is None else global_ids
            values = _checkpoint_values(data, ids, args.metric, k)
        except (OSError, ValueError, KeyError, TypeError, AttributeError, RuntimeError) as err:
            # Best effort: a corrupt or incomplete checkpoint is skipped, not fatal.
            print(f"*** {ckpt} ** Could not load metrics.")
            print(f"    reason: {type(err).__name__}: {err}")
            continue
        global_ids = ids
        metrics.append(values)

    if not metrics:
        raise RuntimeError(f"No {args.metric} metrics could be loaded from {args.metric_dir}")

    # Full-spectrum (k=None) rows may have different lengths per checkpoint;
    # zero-pad all of them to the widest one so they can be stacked.
    if metrics[0].dim() == 2:
        width = max(m.shape[1] for m in metrics)
        metrics = [torch.nn.functional.pad(m, (0, width - m.shape[1])) for m in metrics]
    # (checkpoints, ids[, width]) -> (ids[, width], checkpoints): keeps checkpoints on the
    # last axis. For the 2-D case this is the same as `.t()`.
    metrics = torch.stack(metrics).movedim(0, -1)

    # Uncomment the following to plot the trajectory of a single random datapoint:
    # i = random.randint(0, len(global_ids) - 1)
    # get_single_point_plot(metrics[i], global_ids[i])

    # get_variance(metrics) is an alternative per-datapoint score.
    attn_svd_variance_vals = save_attn_svd_variance(metrics)
    attn_svd_var_dict = dict(zip(global_ids, attn_svd_variance_vals))

    # TODO: make the output file name configurable (it is used for every metric).
    file_path = os.path.join(args.stability_save_dir_path, "attn_svd_files/attn_svd_trajectory_variance_stability.json")
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(attn_svd_var_dict, f, indent=4)

    # NaNs are zeroed only after the scores above are written (NaN trajectories still
    # appear as NaN there). The zeroed tensor is what the disabled k-means step below
    # would consume.
    metrics[torch.isnan(metrics)] = 0
    # NOTE(review): hard-coded dataset size this script was written for; confirm before reuse.
    assert metrics.shape[0] == 80000

    # faiss_kmeans_selection(metrics, 80000, global_ids, args)

def faiss_kmeans_selection(features, n, ids, args, n_clusters=1000, niter=100):
    """Cluster per-datapoint feature rows with faiss k-means and save each id's cluster.

    Args:
        features: 2-D torch tensor, one row per datapoint; row i belongs to ``ids[i]``.
        n: unused; kept so existing call sites keep working.
        ids: datapoint identifiers aligned with the rows of `features`.
        args: namespace providing ``metric_dir``, where the result JSON is written.
        n_clusters: number of k-means centroids (default 1000).
        niter: k-means iterations (default 100). S2L used niter=20, COINCIDE niter=100.

    Writes ``<args.metric_dir>/attn_svd_top-5_<n_clusters>_clusters.json`` mapping
    each id to the index of its nearest centroid.

    Raises:
        ValueError: if the number of ids does not match the number of feature rows.
    """
    import faiss

    # faiss operates on C-contiguous float32 arrays; `features` may be a transposed
    # (non-contiguous) view and/or a non-float32 dtype.
    x = np.ascontiguousarray(features.numpy(), dtype=np.float32)
    if len(ids) != x.shape[0]:
        raise ValueError(f"got {len(ids)} ids for {x.shape[0]} feature rows")

    start_time = time.time()
    kmeans = faiss.Kmeans(x.shape[1], n_clusters, niter=niter, verbose=True)
    kmeans.train(x)
    print(f"k-means training took {time.time() - start_time:.1f}s")

    # Nearest centroid (k=1 search) of every row = its cluster label.
    _, nearest = kmeans.index.search(x, 1)
    cluster_data = {datapoint_id: int(row[0]) for datapoint_id, row in zip(ids, nearest)}

    # TODO: make the save path configurable; "top-5" still reflects the original setup.
    cluster_save_path = os.path.join(args.metric_dir, f"attn_svd_top-5_{n_clusters}_clusters.json")
    with open(cluster_save_path, "w", encoding="utf-8") as f:
        json.dump(cluster_data, f, indent=4)

#################################### VISUALIZING FUNCTIONS ######################################

def get_variance(metrics):
    """Return the population variance (``np.var``) of every row of `metrics`.

    Args:
        metrics: iterable of rows that support ``.numpy()`` (e.g. a torch tensor
            with one trajectory per row).

    Returns:
        list[float]: one variance per row, in row order.
    """
    return [np.var(row.numpy()).item() for row in metrics]

def save_attn_svd_variance(metrics):
    """Return the total variation of every row of `metrics`.

    Despite its name, this neither saves anything nor computes a statistical variance.
    For each row it sums the absolute differences between consecutive entries along
    the last axis: sum_t |x[t+1] - x[t]|.

    Args:
        metrics: iterable of rows that support ``.numpy()`` (e.g. a torch tensor
            with one trajectory per row).

    Returns:
        list[float]: one total-variation value per row, in row order.
    """
    totals = []
    for row in metrics:
        steps = np.diff(row.numpy())
        totals.append(np.abs(steps).sum().item())
    return totals

def compute_ema(data, alpha=0.1):
    """
    Compute the exponential moving average (EMA) of a sequence of numbers.

    ema[0] = data[0] and ema[i] = alpha * data[i] + (1 - alpha) * ema[i-1].

    Args:
        data (list or array-like): The input data. Any sized iterable works,
            including numpy arrays and torch tensors.
        alpha (float): The smoothing factor (between 0 and 1).
                       A higher value gives more weight to recent data.

    Returns:
        list: The EMA of the input data (empty for empty input).
    """
    # len() rather than truthiness: `not array` raises for multi-element numpy
    # arrays and tensors.
    if len(data) == 0:
        return []

    values = iter(data)
    # Initialize the EMA with the first data point.
    ema = [next(values)]
    for value in values:
        ema.append(alpha * value + (1 - alpha) * ema[-1])
    return ema

def autocorrelation(signal, normalize=True):
    """
    Return the autocorrelation of a 1D discrete signal for non-negative lags.

    With μ the mean of the signal and n its length, lag τ is
      R(τ) = Σ_i (x[i] - μ) (x[i+τ] - μ),   τ = 0, 1, ..., n-1.

    Args:
        signal (array-like): Input 1D signal.
        normalize (bool): If True, divide by R(0) so that R(0) == 1. This is skipped
            when R(0) == 0 (e.g. a constant signal), and the raw values are returned.

    Returns:
        np.ndarray: R(0), ..., R(n-1).
    """
    centered = np.asarray(signal)
    centered = centered - np.mean(centered)

    # 'full' mode covers lags -(n-1)..(n-1) (length 2n-1), so zero lag is the middle entry.
    both_sides = np.correlate(centered, centered, mode='full')
    one_side = both_sides[both_sides.size // 2:]

    if not normalize or one_side[0] == 0:
        return one_side
    return one_side / one_side[0]

def get_single_point_plot(vals, datapoint, save_dir="/data/temp_trajectories/checking"):
    """Plot one datapoint's trajectory and its step-to-step changes, then save it as a PNG.

    Three curves are drawn against the checkpoint index:
      * blue  - the raw values,
      * red   - the signed change from the previous checkpoint (0 at the first one),
      * green - the absolute value of that change.

    Args:
        vals: 1-D sequence of per-checkpoint values (e.g. one row of the metrics
            tensor). Each element must be convertible with ``float()``.
        datapoint: identifier used in the plot title and the output file name.
        save_dir: directory the figure is written to; created if it does not exist.

    Returns:
        str: path of the saved PNG file.
    """
    values = [float(v) for v in vals]
    # x-axis follows the trajectory length instead of a fixed checkpoint count.
    checkpoints = np.arange(len(values))
    changes = np.concatenate(([0.0], np.diff(values)))

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(checkpoints, values, marker='o', linestyle='-', color='b', label='value')
    ax.plot(checkpoints, changes, marker='o', linestyle='-', color='r', label='change')
    ax.plot(checkpoints, np.abs(changes), marker='o', linestyle='-', color='g', label='|change|')

    ax.set_xlabel('Checkpoints')
    ax.set_ylabel('Attention-SVD vals')
    ax.set_title(f'Plot of Attention SVD for datapoint: {datapoint}')
    ax.legend()
    ax.grid(True)

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"attn_svd_{datapoint}.png")
    fig.savefig(save_path)
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)
    return save_path

################################################################################################

if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Compute per-datapoint stability scores from per-checkpoint metric JSON files."
    )
    parser.add_argument("--metric", type=str, required=True, choices=["attn_svd", "loss"],
                        help="which per-checkpoint metric files to load")
    # Both directories are used unconditionally by main(), so fail early if missing.
    parser.add_argument("--metric_dir", type=str, required=True,
                        help="directory containing the *checkpoint-<step>_<metric>.json files")
    parser.add_argument("--stability_save_dir_path", type=str, required=True,
                        help="root directory for the stability-score JSON output")
    parser.add_argument("--k", type=int, default=5,
                        help="attn_svd only: sum the k largest values per datapoint and checkpoint")
    args = parser.parse_args()

    main(args)
