import torch
import pandas as pd
import numpy as np
import os
import gc
import argparse
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BitsAndBytesConfig

def create_model_and_tokenizer(model_name, device):
    """Load the 4-bit quantized Gemma classifier and its tokenizer.

    The checkpoint is loaded with NF4 4-bit quantization (double quantization
    enabled, float32 compute dtype) and a single-output classification head.

    Args:
        model_name: Caller-facing label for the model. Currently informational
            only: the checkpoint loaded is always ``google/gemma-2b-it``.
        device: Device the whole model is placed on (for example
            ``torch.device('cuda:0')``). ``None`` falls back to GPU index 0.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model_name_hf = 'google/gemma-2b-it'
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float32
    )

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name_hf,
        num_labels=1,
        use_safetensors=True,
        quantization_config=bnb_config,
        trust_remote_code=True,
        # Honour the caller's device instead of always pinning to GPU 0.
        device_map={"": 0 if device is None else device},
        low_cpu_mem_usage=True
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name_hf)
    if tokenizer.pad_token is None:
        # No dedicated pad token: reuse EOS and pad on the right.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"

    # Always mirror the tokenizer's pad id into the model config so that code
    # reading ``model.config.pad_token_id`` sees the id actually used for
    # padding (and never ``None``).
    model.config.pad_token_id = tokenizer.pad_token_id

    return model, tokenizer

def embed_subdf(device, sub_df, model, tokenizer, batch_size):
    """Embed the ``chat_text`` column of ``sub_df`` in batches.

    For every text, the vector at the position of its last non-padding token
    is taken from the first element returned by ``model.model(...)``
    (presumably the backbone's last hidden state, ``(batch, seq, hidden)`` --
    TODO confirm against the model class).

    Args:
        device: Device the tokenized batch tensors are moved to.
        sub_df: DataFrame with a ``chat_text`` column.
        model: Model exposing ``.model`` (the backbone) and ``.config``.
        tokenizer: Tokenizer called with ``padding=True, truncation=True`` and
            expected to return an ``attention_mask``.
        batch_size: Number of texts per forward pass.

    Returns:
        A 2-D NumPy array with one row per input text. For empty input an
        array with zero rows is returned.
    """
    texts = sub_df['chat_text'].tolist()
    if not texts:
        # np.vstack([]) would raise; hand back an empty matrix instead.
        hidden_size = getattr(model.config, "hidden_size", 0)
        return np.empty((0, hidden_size), dtype=np.float32)

    all_embeddings = []

    for start in tqdm(range(0, len(texts), batch_size), desc="Batch embedding"):
        batch_texts = texts[start:start + batch_size]

        batch_tokens = tokenizer(batch_texts, return_tensors='pt', padding=True, truncation=True)
        batch_tokens = {k: v.to(device) for k, v in batch_tokens.items()}

        # Locate the last real token from the attention mask rather than by
        # searching for the pad id: this works for left and right padding,
        # does not depend on config.pad_token_id being set, and is not fooled
        # by the pad id occurring inside the text itself.
        attention_mask = batch_tokens['attention_mask']
        positions = torch.arange(attention_mask.shape[-1], device=attention_mask.device)
        last_token_idx = (positions * attention_mask).argmax(-1).cpu().numpy()

        with torch.no_grad():
            output = model.model(**batch_tokens)[0]
            batch_embeddings = output[range(len(batch_texts)), last_token_idx].detach().cpu().numpy()
            all_embeddings.append(batch_embeddings)

        # Drop per-batch references and release cached GPU memory before the
        # next batch is allocated.
        del batch_tokens
        del output
        del batch_embeddings
        torch.cuda.empty_cache()
        gc.collect()

    return np.vstack(all_embeddings)


def embed_and_save_split(device, df, model, tokenizer, X, N, output_path, final_output_path, dataset, batch_size=64, split_size=2000):
    """Embed ``df`` in parts, save each part, then save the combined tensor.

    The frame is processed in consecutive slices of ``split_size`` rows. Each
    slice is embedded with ``embed_subdf`` and written to
    ``<output_path>/sub/<dataset>_part<i>.npy``. Afterwards all parts are
    reloaded, stacked, reshaped to ``(X, N, hidden)`` and written to
    ``<final_output_path>/<dataset>_embedded_<X>.npy``.

    Args:
        device: Device forwarded to ``embed_subdf``.
        df: DataFrame to embed; must contain exactly ``X * N`` rows.
        model: Model forwarded to ``embed_subdf``.
        tokenizer: Tokenizer forwarded to ``embed_subdf``.
        X: Size of the first axis of the final array.
        N: Size of the second axis of the final array.
        output_path: Directory under which the ``sub`` part files are stored.
        final_output_path: Directory for the final reshaped array.
        dataset: Dataset name used in the output file names.
        batch_size: Batch size forwarded to ``embed_subdf``.
        split_size: Number of rows per saved part.

    Raises:
        ValueError: If ``df`` does not contain exactly ``X * N`` rows (or that
            product is not positive).
    """
    # Fail before the expensive embedding pass rather than at the final
    # reshape if the data cannot be arranged as (X, N, hidden).
    expected_rows = X * N
    if expected_rows <= 0 or len(df) != expected_rows:
        raise ValueError(
            f"Expected X * N = {X} * {N} = {expected_rows} rows, got {len(df)}"
        )

    os.makedirs(output_path, exist_ok=True)
    os.makedirs(final_output_path, exist_ok=True)
    sub_dir = os.path.join(output_path, "sub")
    os.makedirs(sub_dir, exist_ok=True)

    all_part_paths = []
    num_parts = (len(df) + split_size - 1) // split_size
    for part_idx in range(num_parts):
        sub_df = df.iloc[part_idx * split_size : min((part_idx + 1) * split_size, len(df))]
        print(f"Embedding part {part_idx+1}/{num_parts}, size={len(sub_df)}")

        sub_embeddings = embed_subdf(device, sub_df, model, tokenizer, batch_size=batch_size)

        part_save_path = os.path.join(sub_dir, f"{dataset}_part{part_idx}.npy")
        np.save(part_save_path, sub_embeddings)
        all_part_paths.append(part_save_path)

        print(f"✅ Saved part {part_idx} to {part_save_path}, shape={sub_embeddings.shape}")

    print("Concatenating all parts...")
    stacked = np.vstack([np.load(p) for p in all_part_paths])
    # Infer the embedding width from the data instead of hard-coding it; the
    # row count was validated above, so -1 resolves to the per-row width.
    all_embeddings_reshape = stacked.reshape(X, N, -1)
    final_save_path_reshape = os.path.join(final_output_path, f"{dataset}_embedded_{X}.npy")
    np.save(final_save_path_reshape, all_embeddings_reshape)
    print(f"Final saved: {final_save_path_reshape}, shape={all_embeddings_reshape.shape}")

if __name__ == '__main__':
    # --- Command-line interface -------------------------------------------
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-d',
        type=str,
        default='msmarco',
        help="Name of the dataset. Options: ['msmarco', 'nectar']",
    )
    cli.add_argument('-X', type=int, default=5000, help='Number of context')
    cli.add_argument('-N', type=int, default=100, help='Number of action')
    options = cli.parse_args()

    # --- Input data: ./data/original/<dataset>_X<X>_N<N>.tsv --------------
    dataset_stem = f"{options.d}_X{options.X}_N{options.N}"
    source_tsv = f"./data/original/{dataset_stem}.tsv"
    frame = pd.read_csv(source_tsv, sep="\t", engine="python")

    # --- Model setup -------------------------------------------------------
    gpu = torch.device('cuda:0')
    embedder, tok = create_model_and_tokenizer('Gemma', gpu)
    embedder.eval()

    # --- Embed everything and write the part files plus the final array ---
    embed_and_save_split(
        gpu,
        frame,
        embedder,
        tok,
        options.X,
        options.N,
        output_path="./data/embed",
        final_output_path="./output",
        dataset=options.d,
        batch_size=32,
        split_size=10000,
    )

    # --- Drop the model references and release cached GPU memory ----------
    del embedder
    del tok
    torch.cuda.empty_cache()