import datasets
from dateutil import parser
from sklearn.externals.array_api_compat.numpy import require
from transformers import AutoTokenizer, AutoModel, data
import torch
from sklearn.cluster import MiniBatchKMeans
# from sklearn.preprocessing import StandardScaler
import argparse
import pandas as pd
import numpy as np
from tqdm import tqdm
from multiprocessing import Pool
from functools import partial
import os
from SampleEfficiencyMatrix import SampleEfficiencyMatrix

os.environ['OPENBLAS_NUM_THREADS'] = '64'

def get_embeddings(texts, tokenizer, model, device=None):
    """Return mean-pooled last-layer embeddings for ``texts`` as a NumPy array.

    Args:
        texts: Input passed straight to ``tokenizer`` (a string or a list of
            strings).
        tokenizer: Callable invoked with ``padding=True, truncation=True,
            return_tensors="pt"``; its result must support ``.to(device)``,
            ``**`` unpacking and an ``"attention_mask"`` entry.
        model: Callable invoked with the tokenizer output plus
            ``output_hidden_states=True``; its result must expose
            ``hidden_states``. The model is moved to ``device`` and put in
            eval mode as a side effect.
        device: Device to run on. ``None`` (the default) selects ``"cuda"``
            when it is available and falls back to ``"cpu"`` otherwise.

    Returns:
        ``numpy.ndarray`` with one pooled embedding per input row.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model.to(device)
    model.eval()
    inputs = tokenizer(
        texts, padding=True, truncation=True, return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]  # last hidden states
        attention_mask = inputs["attention_mask"].unsqueeze(-1)
        # Mean pooling over non-padding tokens. The count is clamped so a row
        # whose mask is entirely zero yields a zero vector instead of NaN.
        token_counts = attention_mask.sum(1).clamp(min=1)
        sentence_embedding = (hidden_states * attention_mask).sum(1) / token_counts

    return sentence_embedding.cpu().numpy()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, required=True)
    parser.add_argument("--embedding_path", type=str, required=True)
    parser.add_argument("--cluster_path", type=str)
    parser.add_argument("--weight_path", type=str, required=True)
    parser.add_argument("--model_path", type=str, default="")
    parser.add_argument("--num_proc", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--kmeans", type=bool, default=True)

    args = parser.parse_args()

    model_path = args.model_path
    data_path = args.data_path
    embedding_path = args.embedding_path
    cluster_path = args.cluster_path

    dataset = pd.read_parquet(data_path)

    print(dataset)

    if not os.path.exists(embedding_path):
        problems = [dataset['prompt'][i][0]['content'] for i in range(len(dataset))]
        print(f"problems prepared: {len(problems)}")
        def local_embeddings(args):
            pid, local_problems, model_path = args
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModel.from_pretrained(model_path).to("cuda")
            local_embeddings = []
            for i in tqdm(range(len(local_problems)), desc=f"Process {pid}"):
                embedding = get_embeddings(local_problems[i], tokenizer, model)
                local_embeddings.append(embedding)
            return pid, np.array(local_embeddings)

        proc_batch_items = []
        proc_batch_size = (len(problems) + args.num_proc - 1) // args.num_proc
        for pid in range(args.num_proc):
            starg = pid * proc_batch_size
            end = min((pid + 1) * proc_batch_size, len(problems))
            proc_batch_items.append(
                (pid, problems[starg:end], model_path)
            )

        assert len(proc_batch_items)==args.num_proc
        with Pool(args.num_proc) as pool:   
            results = list(tqdm(pool.imap(local_embeddings, proc_batch_items), total=args.num_proc))
        
        results.sort(key=lambda x: x[0])
        embeddings = []
        for pid, emb in results:
            emb = emb.squeeze()
            print(f"pid: {pid}, emb: {emb.shape}")
            embeddings.extend(emb)
        embeddings = np.array(embeddings)
        print(f"embeddings done: {len(embeddings)}")
        np.save(embedding_path, embeddings)
        print(f"embeddings saved: {args.embedding_path}")
    else:
        embeddings = np.load(args.embedding_path)

    num_batches = (len(embeddings) + args.batch_size - 1) // args.batch_size

    if args.kmeans:
        kmeans = MiniBatchKMeans(
            n_clusters=int(np.sqrt(num_batches)),
            batch_size=args.batch_size,
            max_iter=100,
            random_state=42
        )
        print("kmeans fitting...")
        kmeans.fit(embeddings)
        labels = kmeans.labels_
        # TODO sort dataset clusters by labels
        dataset["labels"] = labels
        dataset = dataset.sort_values(by="labels")
        dataset.reset_index(drop=True)
        dataset.to_parquet(cluster_path)

        embeddings = embeddings[dataset.index]
        print("sorted dataset by kmeans labels")

    from sklearn.metrics.pairwise import cosine_similarity
    # dataset = dataset.sample(frac=1, random_state=42)
    weights = np.zeros((num_batches, num_batches))
    np.fill_diagonal(weights, 1)
    embeddings_tensor = torch.tensor(embeddings, device='cuda')

    for i in tqdm(range(num_batches), desc="Calculating similarities (GPU)"):
        start_i = i * args.batch_size
        end_i = min((i + 1) * args.batch_size, len(embeddings))
        batch_i = embeddings_tensor[start_i:end_i]
        
        for j in range(i + 1, num_batches):
            start_j = j * args.batch_size
            end_j = min((j + 1) * args.batch_size, len(embeddings))
            batch_j = embeddings_tensor[start_j:end_j]
            
            # Normalize batches (L2 norm)
            batch_i_norm = torch.nn.functional.normalize(batch_i, p=2, dim=1)
            batch_j_norm = torch.nn.functional.normalize(batch_j, p=2, dim=1)
            
            # Compute cosine similarity matrix
            sim_matrix = torch.mm(batch_i_norm, batch_j_norm.T)
            
            # Calculate mean similarity and move to CPU
            weights[i, j] = torch.mean(sim_matrix).cpu().item()
            weights[j, i] = weights[i, j]  # Symmetric matrix

    np.save(args.weight_path, weights)
  