"""
General purpose methods to cluster any dataset
"""

import os
import time
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
from typing import Callable
from tqdm import tqdm
from .utils import process_batch

def kmeans_clustering(dataset: Dataset, num_clusters: int, 
                      device: str = 'cpu', whiten = True) -> torch.Tensor:
    """
    Cluster the raw inputs of ``dataset`` with k-means.

    The k-means routine follows https://github.com/subhadarship/kmeans_pytorch.
    Each example's first element (``dataset[i][0]``) is flattened into one row
    of a feature matrix. The matrix is optionally whitened, then clustered.

    :param dataset: indexable dataset whose items are tuples; element 0 must be a tensor
    :param num_clusters: number of clusters; if it equals ``len(dataset)`` every
        example becomes its own cluster and no clustering is run
    :param device: device k-means runs on
    :param whiten: whether to apply ``whitening`` to the feature matrix first
    :return: cluster id per example (as returned by ``kmeans``)
    """
    n_examples = len(dataset)
    if num_clusters == n_examples:
        # Trivial partition: one cluster per example.
        return torch.arange(n_examples)

    rows = [dataset[idx][0].flatten() for idx in range(n_examples)]
    features = torch.stack(rows)
    if whiten:
        features = whitening(features)

    assignments, _centers = kmeans(
        X=features, num_clusters=num_clusters, device=device
    )
    return assignments
  
def get_ids(centroids, testset, device='cpu', batch_size=512, verbose=False):
    """
    Assign every item of ``testset`` to its nearest centroid.

    Items are batched with a non-shuffled DataLoader. Distances use
    ``torch.cdist`` (Euclidean), and each item gets the index of the closest
    row of ``centroids``.

    :param centroids: 2-D tensor, one centroid per row
    :param testset: tensor or dataset whose items are feature vectors
        comparable with the centroid rows
    :param device: device the distance computation runs on
    :param batch_size: number of test items per distance computation
    :param verbose: print the centroids and their shape (debug output)
    :return: int64 CPU tensor of length ``len(testset)`` with centroid indices
    """
    # Move the centroids once instead of on every batch.
    centroids = centroids.to(device)
    if verbose:
        print(centroids)
        print(centroids.shape)

    cluster_ids = torch.zeros(len(testset), dtype=torch.long)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)
    with torch.no_grad():
        for i, x in enumerate(test_loader):
            dis = torch.cdist(x.to(device), centroids)
            # The slice must use the same batch_size the loader was built with.
            cluster_ids[i * batch_size:(i + 1) * batch_size] = torch.argmin(dis, dim=1).cpu()
    return cluster_ids
  

def get_representations(dataset, model, batch_size=512, device='cpu'):
    """
    Compute ``model.get_data_representation`` for every example of ``dataset``.

    Batches come from a non-shuffled DataLoader. The QNLI datasets use the
    HuggingFace ``default_data_collator``. Each batch is prepared with
    ``process_batch`` and fed to the model under ``torch.no_grad()``.

    :param dataset: dataset to embed; a ``name`` attribute is optional
    :param model: object exposing ``get_data_representation(**features)``
        (per the original comment, the penultimate-layer output)
    :param batch_size: DataLoader batch size
    :param device: device passed to ``process_batch``
    :return: per-batch representations concatenated along dim 0, in dataset order
    """
    # Datasets without a ``name`` attribute (e.g. plain torch datasets) fall
    # back to the default collate instead of raising AttributeError.
    if getattr(dataset, 'name', None) in ('qnli', 'qnli_noisy'):
        from transformers.data.data_collator import default_data_collator
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=default_data_collator)
    else:
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    with torch.no_grad():
        # process_batch returns (features, labels); only the features are needed.
        data_representations = [
            model.get_data_representation(**process_batch(batch, device)[0])  # penultimate layer
            for batch in dataloader
        ]
    return torch.concatenate(data_representations)
   
def repr_kmeans_clustering(dataset: Dataset, num_clusters: int, model: nn.Module, 
                            batch_size:int = 512, device: str = 'cpu', testset: Dataset = None) -> torch.Tensor:
    """
    Cluster ``dataset`` with k-means on the model's data representations.

    When ``testset`` is given, its representations are computed as well, and
    each test example is assigned to its nearest centroid via ``get_ids``.

    :param dataset: training dataset to cluster
    :param num_clusters: number of clusters; if it equals ``len(dataset)``,
        every training example is its own cluster
    :param model: model providing the representations (moved to ``device``)
    :param batch_size: batch size for computing representations
    :param device: device for the model and k-means
    :param testset: optional dataset whose examples are assigned to the learned clusters
    :return: training cluster ids, or ``(train_ids, test_ids)`` when ``testset`` is given
    """
    one_cluster_per_example = num_clusters == len(dataset)

    # Nothing to learn and nothing to assign: skip the model entirely.
    if one_cluster_per_example and testset is None:
        return torch.arange(len(dataset))

    model = model.to(device)
    train_repr = get_representations(dataset, model, batch_size, device)
    test_repr = None
    if testset is not None:
        test_repr = get_representations(testset, model, batch_size, device)

    if one_cluster_per_example:
        # Only reachable with a testset. Every training representation acts as
        # its own centroid.
        return torch.arange(len(dataset)), get_ids(train_repr, test_repr, device)

    print(f'data_representations shape: {train_repr.shape}')
    train_ids, centroids = kmeans(
        X=train_repr, num_clusters=num_clusters, device=device
    )
    if testset is None:
        return train_ids
    return train_ids, get_ids(centroids, test_repr, device)

def equal_clustering(dataset: Dataset, group_size: int) -> torch.Tensor:
    """
    Split the dataset, in index order, into consecutive groups of equal size.

    Example ``i`` gets cluster id ``i // group_size``. When ``len(dataset)`` is
    not a multiple of ``group_size``, the last group holds the remainder and is
    therefore smaller.

    :param dataset: any object supporting ``len``
    :param group_size: positive number of examples per group
    :return: int64 tensor of length ``len(dataset)`` with the cluster ids
    :raises ValueError: if ``group_size`` is smaller than 1
    """
    if group_size < 1:
        raise ValueError(f"group_size must be a positive integer, got {group_size}")
    # Integer division reproduces "full groups first, remainder in a final group".
    return torch.arange(len(dataset)) // group_size

def get_data_grads(dataset: Dataset, model: nn.Module, 
                   loss_fn: Callable, batch_size:int = 512, 
                   device: str = 'cpu', model_dir: str = None, 
                   use_grad_cache: bool = True,verbose: bool = True) -> torch.Tensor:
    """
    Collect loss gradients w.r.t. the input of the model's last ``nn.Linear`` layer.

    A full backward hook is registered on the last ``nn.Linear`` yielded by
    ``model.modules()``. For every batch, ``loss_fn(model(**inputs), labels)``
    is backpropagated, and the hook stores ``grad_input[0]`` (detached, on CPU).
    The hook is always removed afterwards, so repeated calls or later training
    do not keep appending to stale lists.

    NOTE(review): the rows are gradients of the *batch* objective; if
    ``loss_fn`` averages over the batch, they are per-example gradients scaled
    by 1/batch — confirm this is intended.

    :param dataset: dataset to iterate (not shuffled); a ``name`` attribute is optional
    :param model: model containing at least one ``nn.Linear``
    :param loss_fn: callable ``(model_output, labels) -> scalar loss``
    :param batch_size: DataLoader batch size
    :param device: device for the model and batches
    :param model_dir: directory for the ``grads.pt`` cache (cache disabled if None)
    :param use_grad_cache: load from / save to ``<model_dir>/grads.pt``
    :param verbose: print cache/timing messages and show a progress bar
    :return: gradients of all batches concatenated along dim 0
    :raises ValueError: if ``model`` contains no ``nn.Linear`` layer
    """
    grad_path = None
    if use_grad_cache and model_dir is not None:
        grad_path = os.path.join(model_dir, 'grads.pt')
        if os.path.exists(grad_path):
            if verbose:
                print("Model grads exist, loading from", grad_path)
            return torch.load(grad_path)

    start_time = time.time()
    model = model.to(device)

    # The last nn.Linear in module-registration order (assumed to be the output head).
    linear_layers = [m for m in model.modules() if isinstance(m, nn.Linear)]
    if not linear_layers:
        raise ValueError("get_data_grads requires a model containing at least one nn.Linear layer")
    layer = linear_layers[-1]

    feature_grads = []

    def _extract_layer_grads(module, in_grad, out_grad):
        # Record the gradient w.r.t. the layer's input. Returning None leaves
        # the gradients flowing through the graph unchanged.
        feature_grads.append(in_grad[0].detach().cpu())

    if getattr(dataset, 'name', None) in ('qnli', 'qnli_noisy'):
        from transformers.data.data_collator import default_data_collator
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=default_data_collator)
    else:
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    hook = layer.register_full_backward_hook(_extract_layer_grads)
    try:
        for data in tqdm(dataloader, disable=not verbose):
            inputs, labels = process_batch(data, device=device)
            if "x" in inputs:
                # Make the raw input require grad so the hook fires even when
                # nothing before the last Linear has trainable parameters.
                inputs["x"].requires_grad_()
            model.zero_grad()
            objective = loss_fn(model(**inputs), labels)
            objective.backward()
    finally:
        # Without this the hook stays attached to the model forever.
        hook.remove()

    outputs = torch.cat(feature_grads)
    if verbose:
        print("Time taken to compute gradients:", time.time() - start_time)

    if grad_path is not None:
        torch.save(outputs, grad_path)

    return outputs

def whitening(inputs: torch.Tensor, epsilon: float = 1e-6):
    """
    Whiten ``inputs`` (rows are samples, columns are dimensions).

    The whitening matrix is built from the SVD ``U S Vh`` of the Gram matrix
    ``inputs.T @ inputs`` as ``U diag(1 / sqrt(S + epsilon)) Vh``. Note that
    the Gram matrix is neither mean-centered nor divided by the sample count.

    :param inputs: 2-D tensor of shape (num_samples, num_dims)
    :param epsilon: added to the singular values for numerical stability
    :return: ``inputs`` multiplied by the whitening matrix, same shape as ``inputs``
    """
    gram = inputs.T @ inputs
    # torch.linalg.svd returns Vh (= V^T) directly; torch.svd is deprecated.
    left, singular_values, right_h = torch.linalg.svd(gram)
    inv_sqrt = 1 / torch.sqrt(singular_values + epsilon)
    transform = left @ torch.diag(inv_sqrt) @ right_h
    return inputs @ transform

def grad_kmeans_clustering(dataset: Dataset, num_clusters: int, model: nn.Module,
                           loss_fn: Callable, batch_size:int = 512, device: str = 'cpu', model_dir: str = None,
                           whiten: bool = True) -> torch.Tensor:
    """
    Cluster ``dataset`` with k-means on per-batch loss gradients.

    Gradients are taken w.r.t. the input of the model's last ``nn.Linear`` layer,
    as collected by ``get_data_grads`` (cached in ``model_dir`` when given).
    They are optionally whitened, then clustered with ``kmeans``.

    :param dataset: dataset to cluster
    :param num_clusters: number of clusters; if it equals ``len(dataset)``,
        every example is its own cluster and no gradients are computed
    :param model: model whose last ``nn.Linear`` input gradients are used
    :param loss_fn: callable ``(model_output, labels) -> scalar loss``
    :param batch_size: batch size for gradient computation
    :param device: device for gradient computation and k-means
    :param model_dir: optional directory for the gradient cache
    :param whiten: apply ``whitening`` before k-means (default True, the
        previous behavior)
    :return: cluster id per example
    """
    if num_clusters == len(dataset):
        return torch.arange(len(dataset))

    gradients = get_data_grads(dataset, model, loss_fn, batch_size, device, model_dir)
    gradients = gradients.to(device)
    if whiten:
        gradients = whitening(gradients)

    cluster_ids_x, _ = kmeans(
        X=gradients, num_clusters=num_clusters, device=device
    )
    return cluster_ids_x

def kmeans(
        X,
        num_clusters,
        tol=1e-3,
        max_iters=60,
        device='cpu',
):
    """
    perform kmeans (Lloyd iterations, initialized with randomly chosen samples)
    :param X: (torch.tensor) matrix, one sample per row
    :param num_clusters: (int) number of clusters; must not exceed the number of
        rows of X (initial centers are sampled without replacement)
    :param tol: (float) stop once the squared sum of per-center shifts is below
        this threshold [default: 1e-3]
    :param max_iters: (int) maximum number of iterations [default: 60]
    :param device: (torch.device) device [default: cpu]
    :return: (torch.tensor, torch.tensor) cluster ids, cluster centers (both on CPU).
        The ids are the assignment computed before the last center update.
    """
    print(f'running k-means on {device}..')

    # convert to float and transfer to device
    X = X.float().to(device)

    print("X shape:", X.shape)
    # print norm of X
    print("Norm of X:", torch.norm(X, dim=1).mean())

    # initialize with distinct random samples
    num_samples = len(X)
    indices = np.random.choice(num_samples, num_clusters, replace=False)
    initial_state = X[indices]
    print("Initial State Shape:", initial_state.shape)
    print("Initial State Device:", initial_state.get_device())

    # Only release cached CUDA memory when actually running on a GPU.
    on_cuda = torch.device(device).type == 'cuda'

    iteration = 0
    while True:
        if on_cuda:
            torch.cuda.empty_cache()
        choice_cluster = min_pairwise_distance_idxs(X, initial_state, device=device)
        # One transfer per iteration instead of one per cluster.
        assignments = choice_cluster.to(device)

        initial_state_pre = initial_state.clone()

        for index in range(num_clusters):
            selected = X[assignments == index]
            # Empty clusters keep their previous center.
            if selected.shape[0] != 0:
                initial_state[index] = selected.mean(dim=0)

        center_shift = torch.sum(
            torch.sqrt(
                torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
            ))

        print(f'iteration {iteration}, center shift: {center_shift}')

        # increment iteration
        iteration = iteration + 1

        if center_shift ** 2 < tol or iteration >= max_iters:
            break

    return choice_cluster.cpu(), initial_state.cpu()

def pairwise_distance(data1, data2, device=torch.device('cpu')):
    """
    Squared Euclidean distance between every row of ``data1`` and every row of ``data2``.

    Assumes 2-D inputs: ``data1`` of shape (N, M) and ``data2`` of shape (K, M).
    The full (N, K, M) difference tensor is materialized.

    :param data1: (N, M) tensor
    :param data2: (K, M) tensor
    :param device: device both inputs are moved to before computing
    :return: (N, K) tensor with ``dis[i, j] = ||data1[i] - data2[j]||^2``.
        It is always 2-D, even when N or K is 1 (the old ``squeeze()`` collapsed
        those cases).
    """
    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)

    A = data1.unsqueeze(dim=1)  # N*1*M
    B = data2.unsqueeze(dim=0)  # 1*K*M

    dis = (A - B) ** 2.0  # N*K*M via broadcasting
    return dis.sum(dim=-1)  # N*K

def min_pairwise_distance_idxs(data1, data2, device='cuda'):
    """
    Index of the nearest row of ``data2`` for every row of ``data1`` (squared Euclidean distance).

    ``data1`` is processed in chunks, so only a (chunk, K, D) difference tensor
    exists at a time. On CUDA the chunk size is derived from total device
    memory minus reserved memory, minus a 2 GB buffer, with an extra safety
    factor of 4. On other devices it is 32. The chunk size is at least 1.

    :param data1: (N, D) tensor of query points
    :param data2: (K, D) tensor of reference points (e.g. cluster centers)
    :param device: device the distance computation runs on ('cuda', 'cuda:N',
        'cpu' or a torch.device)
    :return: int64 CPU tensor of shape (N,)
    """
    print("Data1 Shape:", data1.shape)
    print("Data2 Shape:", data2.shape)
    data2 = data2.to(device).unsqueeze(0)  # (1, K, D)

    if torch.device(device).type == 'cuda':
        total_memory = torch.cuda.get_device_properties(device).total_memory / (1024**2)
        available_memory = total_memory - torch.cuda.memory_reserved(device) / (1024**2)
        memory_buffer = 2048  # 2 GB buffer
        # MB needed per data1 row: one (K, D) difference tensor.
        memory_per_element = max(data2.numel(), 1) * data2.element_size() / (1024**2)
        batch_size = int((available_memory - memory_buffer) // memory_per_element) // 4  # safety factor
        print("Batch Size:", batch_size)
    else:
        batch_size = 32
    # Under memory pressure the estimate can be <= 0; always make progress.
    batch_size = max(batch_size, 1)

    num_rows = data1.shape[0]
    min_idxs = torch.zeros(num_rows, dtype=torch.long)
    # Plain slicing avoids DataLoader's per-row collation of an in-memory tensor.
    for start in range(0, num_rows, batch_size):
        x = data1[start:start + batch_size].to(device).unsqueeze(1)  # (B, 1, D)
        dis = ((x - data2) ** 2).sum(dim=-1)  # (B, K)
        min_idxs[start:start + batch_size] = torch.argmin(dis, dim=1).cpu()
    return min_idxs


def initialize(X, num_clusters):
    """
    Build batched initial cluster centers by random partitioning.

    For every batch entry, the samples are randomly permuted, split into
    ``num_clusters`` consecutive equal-size groups, and each group is averaged.

    :param X: (torch.tensor) batched data; the indexing assumes shape
        (batch, num_samples, dim)
    :param num_clusters: (int) number of clusters; the reshape requires
        num_samples to be divisible by it
    :return: (torch.tensor) centers of shape (batch, num_clusters, dim)
    """
    batch = X.shape[0]
    num_samples = X.shape[1]
    dim = X.shape[-1]

    # One independent permutation of the sample axis per batch entry.
    perms = torch.stack(
        [torch.randperm(num_samples, device=X.device) for _ in range(batch)]
    )
    shuffled = torch.gather(X, 1, perms.unsqueeze(-1).expand(-1, -1, dim))
    groups = shuffled.reshape(batch, num_clusters, -1, dim)
    return groups.mean(dim=-2)

def kmeans_equal(
        X,
        num_clusters,
        cluster_size,
        max_iters=100,
        update_centers=True,
        device='cpu',
        tol=1e-4):
    """
    perform kmeans with (at most) ``cluster_size`` points per cluster

    Each iteration ranks the clusters for every point by distance to the
    current centers. Clusters are then filled greedily in index order: cluster
    ``k`` takes the ``cluster_size`` still-unassigned points that rank it
    highest, with ties broken by distance to center ``k``. Points left over
    once every cluster is full keep their nearest center as label.

    :param X: (torch.tensor) data matrix of shape (num_samples, dim)
    :param num_clusters: (int) number of clusters; must not exceed num_samples
    :param cluster_size: (int) maximum number of points per cluster
    :param max_iters: maximum iterations allowed (controls speed) [default: 100]
    :param update_centers: if False, run a single assignment pass with the
        initial (randomly chosen) centers
    :param device: (str or torch.device) device [default: cpu]
    :param tol: (float) stop when the squared total center shift is below this [default: 0.0001]
    :return: (torch.tensor, torch.tensor) cluster ids (num_samples,), cluster
        centers (num_clusters, dim), both on CPU
    """
    print(f'running k-means equal on {device}..')

    # convert to float and transfer to device
    X = X.float().to(device)

    print("X shape:", X.shape)
    # print norm of X
    print("Norm of X:", torch.norm(X, dim=1).mean())

    # Forgy initialization: distinct random samples become the initial centers.
    num_samples = len(X)
    indices = np.random.choice(num_samples, num_clusters, replace=False)
    centers = X[indices]
    print("Initial State Shape:", centers.shape)
    print("Initial State Device:", centers.get_device())

    iteration = 0
    while True:
        dis = torch.cdist(X, centers)             # (num_samples, num_clusters)
        preferences = torch.argsort(dis, dim=1)   # preferences[n, r] = r-th closest cluster of n
        rank = torch.argsort(preferences, dim=1)  # rank[n, k] = position of cluster k in n's list
        labels = preferences[:, 0].clone()        # fallback for points no cluster takes
        assigned = torch.zeros(num_samples, dtype=torch.bool, device=X.device)
        centers_pre = centers.clone()

        for index in range(num_clusters):
            candidates = torch.nonzero(~assigned).squeeze(1)
            if candidates.numel() == 0:
                break
            # Lexicographic order by (rank of this cluster, distance to it),
            # done with two stable sorts: secondary key first, then primary.
            candidates = candidates[torch.argsort(dis[candidates, index], stable=True)]
            candidates = candidates[torch.argsort(rank[candidates, index], stable=True)]
            selected = candidates[:cluster_size]

            labels[selected] = index
            assigned[selected] = True
            # update cluster center
            if update_centers and selected.numel() > 0:
                centers[index] = X[selected].mean(dim=0)

        center_shift = torch.sum(
            torch.sqrt(
                torch.sum((centers - centers_pre) ** 2, dim=1)
            ))

        # increment iteration
        iteration += 1

        if center_shift ** 2 < tol:
            break
        if iteration >= max_iters:
            break

    return labels.cpu(), centers.cpu()