import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# import laser.factor as factor
from csv import writer
import random
from copy import deepcopy

from fast_pytorch_kmeans import KMeans

from collections import namedtuple

# Pair produced by k-means quantization: cluster ``centroids`` and per-element ``labels``.
Codebook = namedtuple('Codebook', 'centroids labels')

# Helper functions for abs weight pruning
def sorted_mat(matrix):
    """Return the magnitudes of every entry of ``matrix`` as an ascending Python list.

    The list elements are whatever iterating the flattened ``abs(matrix)`` yields
    (e.g. NumPy scalars for arrays, 0-d tensors for torch tensors).
    """
    return sorted(abs(matrix).flatten())


def prune(matrix, mat_sort, to_prune):
    """Zero out the smallest-magnitude entries of ``matrix`` (in place) and return it.

    :param matrix: array/tensor to prune; it is modified in place.
    :param mat_sort: ascending list of the magnitudes of ``matrix`` (as built by
        ``sorted_mat``).
    :param to_prune: pruning level on a 0-10 scale. The threshold is the value at
        position ``int(to_prune * 0.1 * len(mat_sort))`` of ``mat_sort``, so roughly
        ``to_prune * 10`` percent of the entries are zeroed (every entry whose
        magnitude is <= the threshold, ties included). 0 disables pruning.
    :return: ``matrix``.
    """
    if to_prune != 0 and len(mat_sort) > 0:
        # Clamp so that to_prune == 10 ("prune everything") selects the largest
        # magnitude instead of indexing one past the end of mat_sort.
        cut = min(int(to_prune * 0.1 * len(mat_sort)), len(mat_sort) - 1)
        alpha = mat_sort[cut]
        matrix[abs(matrix) <= alpha] = 0
    return matrix


def rank(matrix):
    """Return the numerical rank of ``matrix`` divided by its smaller dimension.

    A value of 1.0 means the matrix is full rank; smaller values mean it is
    closer to low rank.
    """
    arr = np.array(matrix)
    numerical_rank = np.linalg.matrix_rank(arr)
    smaller_dim = min(arr.shape)
    return numerical_rank / smaller_dim


# Fraction of entries whose magnitude is below a threshold, i.e. the share of weights magnitude pruning would remove
def sparsity(matrix, alpha):
    """Return the fraction of entries of ``matrix`` whose magnitude is strictly below ``alpha``.

    ``matrix`` is expected to behave like a NumPy array (``.size`` is the element count).
    """
    n_small = np.count_nonzero(abs(matrix) < alpha)
    return n_small / matrix.size


def viz_rank_change(rank_list, name):
    """Plot ``rank_list`` as a line chart and save it to the file ``name``.

    The figure is always closed afterwards so repeated calls do not accumulate
    open matplotlib figures (and the memory they hold).
    """
    fig, ax = plt.subplots()
    try:
        ax.plot(rank_list)
        fig.savefig(name)
    finally:
        plt.close(fig)


# # Helper functions for rank reduction
# def do_low_rank(weight, k, num_clusters, group_via='rows', debug=True, niter=2):
#     assert weight.ndim == 2

#     max_rank = min(weight.shape[0], weight.shape[1])
#     desired_rank = int(max_rank * k)

#     if debug:
#         print(f"Shape is {weight.shape} and shape is {weight.dtype} => desired rank {desired_rank}")

#     # results = torch.svd_lowrank(weight,
#     #                             q=desired_rank,
#     #                             niter=niter)
#     # weight_approx = results[0] @ torch.diag(results[1]) @ results[2].T
        
#     results = torch.svd(weight)
#     feats = list(range(int(k*max_rank)))
#     u = results[0][:,feats]
#     s = torch.diag(results[1][feats])
#     vt = results[2].T[feats,:]
#     weight_approx = torch.matmul(torch.matmul(u,s),vt)
#     print("Desired rank: ", int(max_rank*k))

#     if debug:
#         print(f"New matrix has shape {weight_approx.shape}")

#     assert weight_approx.shape[0] == weight.shape[0] and weight_approx.shape[1] == weight.shape[1]
#     weight_approx = torch.nn.Parameter(weight_approx)

#     return weight_approx
    
def get_possible_k(w_shape_in):
    """Return every divisor of ``w_shape_in`` as a NumPy array in descending order.

    These are the group counts k that split a dimension of size ``w_shape_in``
    into equally sized groups.
    """
    candidates = np.arange(w_shape_in, 0, -1)
    evenly_divides = w_shape_in % candidates == 0
    return candidates[evenly_divides]

def check_equilivance_of_k_1(rows=16384, cols=4096, k=0.9):
    """Sanity check: single-cluster ``do_low_rank`` should match ``do_low_rank_working``.

    Builds a random ``rows x cols`` matrix, approximates it at rank fraction ``k``
    with both implementations and prints the largest absolute element-wise
    difference between the two results. The defaults reproduce the original
    (expensive) 16384 x 4096 check; pass smaller sizes for a quick run.
    """
    arr = torch.rand(rows, cols)

    # do_low_rank returns (approximation, optimal_indices); only the approximation is compared.
    weight_approx_mine, _ = do_low_rank(arr, k, num_clusters=1, group_via='rows', debug=True)
    weight_approx_mine_shiva = do_low_rank_working(arr, k, debug=True)
    with torch.no_grad():
        # abs() so that differences of either sign are reported.
        max_diff = torch.max(torch.abs(weight_approx_mine - weight_approx_mine_shiva))
    print("difff:", max_diff)
    print("===================================================================================")
    print("===================================================================================")
    print("===================================================================================\n\n\n\n\n")

def _truncated_svd(matrix, rank):
    """Return the best rank-``rank`` approximation of the 2-D tensor ``matrix``.

    Uses the thin SVD and keeps the ``rank`` largest singular triplets
    (a ``rank`` above ``min(matrix.shape)`` is clipped by the slicing).
    """
    u, s, vh = torch.linalg.svd(matrix, full_matrices=False)
    # u[:, :r] * s[:r] scales each kept left singular vector, i.e. U_r @ diag(S_r).
    return (u[:, :rank] * s[:rank]) @ vh[:rank, :]


def _print_row_error_stats(weight, weight_approx):
    """Debug aid: print the Frobenius error and the worst-approximated row of ``weight``."""
    diff = weight_approx - weight
    print(f"Aprrox error = {torch.linalg.norm(diff, ord='fro')}")
    row_errs = torch.linalg.norm(diff, dim=1)
    # argmax returns the first maximal index, like a strict "<" scan over the rows.
    max_idx = int(torch.argmax(row_errs))
    print(f"Maximum row error is: {row_errs[max_idx]}")
    print("in row: " + str(max_idx))
    print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n\n\n\n\n")


def do_low_rank(weight, k, num_clusters, group_via='rows', debug=True, shuffle=False,
                *, specific_idxs=False, random_weight=False):
    """Low-rank approximate ``weight`` group by group with a truncated SVD.

    The rows (columns when ``group_via == 'cols'``) are split into exactly
    ``num_clusters`` contiguous groups whose sizes differ by at most one, and each
    group is replaced by its own truncated-SVD approximation keeping
    ``int(min(group.shape) * k)`` singular values.

    :param weight: 2-D tensor to approximate.
    :param k: fraction of each group's maximum rank to keep.
    :param num_clusters: number of groups; must be between 1 and the number of
        grouped rows (columns when ``group_via == 'cols'``).
    :param group_via: ``'cols'`` groups columns; any other value groups rows.
    :param debug: print shapes and approximation-error diagnostics.
    :param shuffle: randomly permute the rows (via ``np.random``) before grouping;
        the result is returned in the original row order.
    :param specific_idxs: debugging switch; reduce a random subset of groups at rank
        fraction ``k`` and the others at ``k / 2`` (via ``random``).
    :param random_weight: debugging switch; return uniform noise spanning
        ``[min(weight), max(weight))`` instead of the approximation.
    :return: ``(approximation, None)`` where the approximation is a
        ``torch.nn.Parameter`` with ``weight``'s shape; the second element is a
        placeholder that is always ``None``.
    :raises ValueError: if ``num_clusters`` is outside ``[1, number of grouped rows]``.
    """
    assert weight.ndim == 2

    if group_via == 'cols':
        weight = weight.T

    n_rows = weight.shape[0]
    n_cols = weight.shape[1]
    num_clusters = int(num_clusters)
    if not 1 <= num_clusters <= n_rows:
        raise ValueError(
            f"num_clusters must be in [1, {n_rows}] for a matrix with {n_rows} "
            f"grouped rows, got {num_clusters}")
    optimal_indices = None

    if shuffle:
        print("Shuffling")
        random_indices = np.random.permutation(n_rows)
        weight = weight[random_indices]
        restore_indices = np.argsort(random_indices)

    # tensor_split yields exactly num_clusters groups (sizes differ by at most one), so
    # no leftover rows escape the reduction when n_rows % num_clusters != 0.
    grouped_rows = torch.tensor_split(weight, num_clusters)

    max_rank = min(grouped_rows[0].shape[0], n_cols)
    desired_rank = int(max_rank * k)

    if debug:
        print(f"Weights matrix shape is {weight.shape}")
        print(f"split to {len(grouped_rows)} groups")
        print(f"First/last group shape {grouped_rows[0].shape, grouped_rows[-1].shape}")
        print(f"=> desired rank is {desired_rank}")

    if specific_idxs:
        num_reduce = random.choice(np.arange(num_clusters, dtype=int)) + 1
        reduce_idxs = random.sample(np.arange(num_clusters, dtype=int).tolist(), num_reduce)
    else:
        reduce_idxs = list(range(num_clusters))

    print("Indices being reduced: ", reduce_idxs)

    reduced = set(reduce_idxs)
    pieces = []
    for idx, group in enumerate(grouped_rows):
        group_max_rank = min(group.shape[0], n_cols)
        if idx in reduced:
            group_rank = int(group_max_rank * k)
        else:
            group_rank = min(group_max_rank, int(group_max_rank * k / 2))
        pieces.append(_truncated_svd(group, group_rank))
    weight_approx = torch.cat(pieces)

    if random_weight:
        min_val = torch.min(weight)
        max_val = torch.max(weight)
        weight_approx = torch.rand_like(weight) * (max_val - min_val) + min_val

    # Undo the shuffle on both tensors so rows line up with the caller's order.
    if shuffle:
        print("Reordering Shuffle")
        weight_approx = weight_approx[restore_indices]
        weight = weight[restore_indices]
    print("Weight Shape: ", weight.shape)
    if group_via == 'cols':
        weight = weight.T
        weight_approx = weight_approx.T
    if debug:
        _print_row_error_stats(weight, weight_approx)
    assert weight_approx.shape == weight.shape
    return torch.nn.Parameter(weight_approx), optimal_indices

def k_means_quantize(fp32_tensor: torch.Tensor, n_clusters=4, codebook=None):
    """
    Quantize a tensor by k-means clustering its scalar values.

    :param fp32_tensor: [torch.Tensor] tensor to quantize; every element is treated
        as a 1-D point for clustering. Unused when ``codebook`` is given.
    :param n_clusters: [int] number of k-means clusters (distinct output values), default=4
    :param codebook: [Codebook] optional precomputed (centroids, labels); when given,
        no clustering is run and it is used as-is.
    :return: (codebook, quantized_tensor)
        codebook: [Codebook] ``centroids`` - 1-D float tensor of cluster centres;
            ``labels`` - long tensor of cluster indices from ``KMeans.fit_predict``
            (presumably one per element of the flattened input - confirm).
        quantized_tensor: ``centroids[labels]``, with the shape of ``labels``;
            callers reshape it back to the input shape.
    """
    if codebook is None:
        kmeans = KMeans(n_clusters=n_clusters, mode='euclidean', verbose=0)
        # reshape (not view) so non-contiguous inputs such as transposed weights work too.
        points = fp32_tensor.reshape(-1, 1)
        labels = kmeans.fit_predict(points).to(torch.long)
        centroids = kmeans.centroids.to(torch.float).view(-1)
        codebook = Codebook(centroids, labels)

    quantized_tensor = codebook.centroids[codebook.labels]
    return codebook, quantized_tensor

def do_quantize(weight, num_clusters, debug=True, optional_rate=0.0, optional_k_clusters=1):
    """Replace ``weight`` by its k-means quantized version, optionally after rank reduction.

    :param weight: 2-D tensor to quantize.
    :param num_clusters: number of k-means centroids (distinct values in the result).
    :param debug: print shape and approximation-error diagnostics; also forwarded to
        the optional low-rank step.
    :param optional_rate: when > 0, first replace ``weight`` by
        ``do_low_rank(weight, 1.0 - 0.1 * optional_rate, optional_k_clusters)`` and
        quantize that approximation instead.
    :param optional_k_clusters: number of row groups used by the low-rank step.
    :return: ``torch.nn.Parameter`` with the same shape as ``weight``.
    """
    assert weight.ndim == 2

    if debug:
        print(f"Weights matrix shape is {weight.shape}")

    if optional_rate > 0.0:
        weight, _ = do_low_rank(weight, 1.0 - (0.1 * optional_rate), optional_k_clusters,
                                debug=debug)

    _, quantized_tensor = k_means_quantize(weight, num_clusters)
    quantized = torch.reshape(quantized_tensor, weight.shape)
    print(f"Quantized tensor shape {quantized.shape}")

    if debug:
        # NOTE: when optional_rate > 0, ``weight`` is already the low-rank approximation,
        # so these numbers measure the quantization step only.
        diff = quantized - weight
        print(f"Aprrox error = {torch.linalg.norm(diff, ord='fro')}")
        row_errs = torch.linalg.norm(diff, dim=1)
        # argmax returns the first maximal index, like a strict "<" scan over the rows.
        max_idx = int(torch.argmax(row_errs))
        print(f"Maximum row error is: {row_errs[max_idx]}")
        print("in row: " + str(max_idx))
        print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n\n\n\n\n")

    assert quantized.shape == weight.shape
    return torch.nn.Parameter(quantized)


def do_low_rank_working(weight, k, debug=True, save_csv=True):
    """Reference single-group truncated-SVD approximation of a 2-D ``weight``.

    Keeps the ``int(min(weight.shape) * k)`` largest singular triplets.

    :param weight: 2-D tensor to approximate.
    :param k: fraction of the maximum rank to keep.
    :param debug: print the Frobenius approximation error.
    :param save_csv: when True (the default, matching the previous behaviour) write
        the original weight to ``original_weight.csv`` and the approximation to
        ``clustered_weight.csv`` in the current working directory.
    :return: ``torch.nn.Parameter`` with the same shape as ``weight``.
    """
    assert weight.ndim == 2

    max_rank = min(weight.shape[0], weight.shape[1])
    desired_rank = int(max_rank * k)

    if debug:
        print(f"Shape is {weight.shape} and shape is {weight.dtype} => desired rank {desired_rank}")

    u, s, v = torch.svd(weight)
    # Keep the top ``desired_rank`` singular triplets: U_r @ diag(S_r) @ V_r^T.
    weight_approx = (u[:, :desired_rank] @ torch.diag(s[:desired_rank])) @ v[:, :desired_rank].T
    print("Desired rank: ", desired_rank)

    if debug:
        print(f"Aprrox error = {torch.linalg.norm(weight_approx - weight, ord='fro')}")
        print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n\n\n\n\n")
    assert weight_approx.shape[0] == weight.shape[0] and weight_approx.shape[1] == weight.shape[1]
    weight_approx = torch.nn.Parameter(weight_approx)

    if save_csv:
        # Convert to NumPy first: tensors may require grad or live on a GPU.
        pd.DataFrame(weight.detach().cpu().numpy()).to_csv("original_weight.csv")
        # The approximation (not the original weight) belongs in clustered_weight.csv.
        pd.DataFrame(weight_approx.detach().cpu().numpy()).to_csv("clustered_weight.csv")
    return weight_approx