# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import torch
import sys
import os
import argparse

import shutil
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import copy

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(dir_path)
from reshape import *
from tqdm import trange
import torch.nn.functional as F

import torch.nn as nn



def get_data_in_chunks(X, batch_size):
    """Yield consecutive slices of ``X`` together with their start offset.

    Each yielded item is a ``(chunk, offset)`` pair where ``chunk`` is
    ``X[offset:offset + batch_size]``; the final chunk may be shorter than
    ``batch_size``.
    """
    total = len(X)
    for offset in range(0, total, batch_size):
        stop = offset + batch_size
        chunk = X[offset:stop]
        yield chunk, offset

def common_kmeans_batch(X, num_clusters, tolerance, batch_size, salience=None):
    """Cluster the rows of ``X`` with k-means, assigning points in batches.

    Args:
        X: 2-D tensor whose rows are the points to cluster.
        num_clusters: number of centroids to fit.
        tolerance: the loop stops once fewer than ``tolerance * len(X)``
            points changed cluster between two consecutive assignment passes.
        batch_size: number of rows of ``X`` handled per distance computation.
        salience: optional tensor with one row per row of ``X``. Its sum over
            the last dimension is used as a per-point weight: the
            ``num_clusters`` heaviest points seed the centroids and every
            centroid update is a weighted mean. When ``None`` the seeds are
            drawn at random (without replacement) and plain means are used.

    Returns:
        ``(centroids, cluster_assignments)`` where ``centroids`` has one row
        per cluster and ``cluster_assignments`` is an int32 tensor holding the
        cluster index of every row of ``X``.
    """
    if salience is not None:
        # The per-point weight never changes, so compute it once up front.
        point_weights = salience.sum(-1)
        _, indices = torch.topk(point_weights, num_clusters)
    else:
        point_weights = None
        indices = np.random.choice(len(X), num_clusters, replace=False)
    # Advanced indexing copies, so in-place centroid updates never touch X.
    centroids = X[indices]

    num_points = len(X)
    cluster_assignments = torch.zeros(num_points, dtype=torch.int32, device=X.device)
    max_iters = 100 * X.shape[-1]
    iters = 0

    while True:
        old_assignments = cluster_assignments.clone()

        # Assignment step, chunked to bound the size of the distance matrix.
        for start_idx in range(0, num_points, batch_size):
            data = X[start_idx:start_idx + batch_size]
            distances = torch.cdist(data, centroids)
            cluster_assignments[start_idx:start_idx + data.shape[0]] = torch.argmin(distances, dim=1)

        diff_num = torch.sum(cluster_assignments != old_assignments)
        if diff_num < tolerance * num_points:
            break

        iters += 1
        if iters > max_iters:
            print("=================== too many iters ===================")
            break

        # Update step; empty clusters keep their previous centroid.
        for i in range(num_clusters):
            members = cluster_assignments == i
            cluster_samples = X[members]
            if len(cluster_samples) == 0:
                continue

            if point_weights is not None:
                member_weights = point_weights[members]
                weight_total = member_weights.sum()
                if weight_total != 0:
                    weighted_sum = torch.sum(cluster_samples * member_weights.unsqueeze(-1), dim=0)
                    centroids[i] = weighted_sum / weight_total
                    continue
                # A zero total weight would give 0/0 = NaN and poison every
                # later distance computation; use the plain mean instead.

            centroids[i] = cluster_samples.mean(dim=0)

    return centroids, cluster_assignments



def codebook(base_module, Cmodule, sub_path, 
                    n_centroids, vector_len, train_new, salience=None):
    """Replace ``base_module.weight`` with its k-means codebook reconstruction.

    Args:
        base_module: module whose 2-D ``weight`` is quantized in place.
        Cmodule: name used as the file-name prefix for the saved codebook and
            printed once the module has been processed.
        sub_path: directory holding ``<Cmodule>_centroids.pth`` and
            ``<Cmodule>_assignments.pth``.
        n_centroids: number of codebook entries to fit when training.
        vector_len: passed to ``reshape_weightlike_cin`` /
            ``reshape_back_weight_cin``; presumably the length of each
            clustered sub-vector (confirm against the ``reshape`` module).
        train_new: if true, fit a new codebook on ``cuda:0``; otherwise load a
            previously saved one from ``sub_path``.
        salience: optional per-vector weights forwarded to
            ``common_kmeans_batch``.

    Returns:
        None. The module's weight is overwritten and the codebook files are
        (re)written to ``sub_path``.
    """
    origin_weight = base_module.weight.data.detach()
    # Unpacking also enforces that the weight is 2-D (out_features, in_features).
    _, in_features = base_module.weight.size()

    centroids_path = os.path.join(sub_path, Cmodule+'_centroids.pth')
    assignments_path = os.path.join(sub_path, Cmodule+'_assignments.pth')

    if train_new:
        reshape_weight = reshape_weightlike_cin(origin_weight, vector_len)
        total_block = reshape_weight.to(torch.device('cuda:0'))

        # common_kmeans_batch returns exactly (centroids, assignments).
        centroids, assignments = common_kmeans_batch(
            total_block, n_centroids,
            tolerance=0.3, batch_size=131072//16, salience=salience)

        # Drop the large intermediates before building the reconstruction.
        del reshape_weight
        del total_block
    else:
        centroids = torch.load(centroids_path, map_location=torch.device('cuda'))
        assignments = torch.load(assignments_path, map_location=torch.device('cuda'))

    # Normalize dtypes and persist the codebook (also rewrites loaded files so
    # that what is on disk matches what is used below).
    centroids = centroids.to(origin_weight.dtype)
    assignments = assignments.to(torch.int32)
    torch.save(centroids, centroids_path)
    torch.save(assignments, assignments_path)

    # Index with int64: int32 index tensors are rejected by older PyTorch.
    reorder_weight = centroids[assignments.long()]
    reorder_weight = reshape_back_weight_cin(reorder_weight, vector_len, in_features)
    reorder_weight = reorder_weight.to(origin_weight.dtype)

    print(Cmodule)

    base_module.weight.data = reorder_weight.data.detach()
    if origin_weight.dtype == torch.int8:
        # int8 weights also expose a ``CB`` tensor that must mirror the data
        # (presumably bitsandbytes Int8Params; confirm with the caller).
        base_module.weight.CB.data = reorder_weight.data.detach()

    return