import os
import json
import random
from copy import copy
import torch
from torch.nn.utils.rnn import pad_sequence
import pytorch_lightning as pl

from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader, Dataset
import os
import numpy as np

from .config import config
from .data import SimilaritySpaceDataset
from .model import SimilaritySpaceModel

import os
import torch
import numpy as np
from tqdm import tqdm
import lmdb
import pickle

def get_all_llama_outputs(split):
    """Precompute and cache mean-pooled encoder embeddings for one dataset split.

    For every row of ``SimilaritySpaceDataset(split)``, the anchor sequence and
    each augmented sequence are embedded with ``compute_embedding`` and the
    pickled result is written to the ``embeddings`` sub-database of the LMDB
    environment at the configured ``("data_gen", "llama_precompute_path")``.
    Entries are keyed by ``generate_key``. Keys already present in the
    database are skipped, so re-running resumes an interrupted precompute.

    Parameters:
    - split: dataset split identifier, passed straight to SimilaritySpaceDataset.
    """
    # Clear out the gpu completely
    torch.cuda.empty_cache()

    model = SimilaritySpaceModel().eval().to(config.get("training", "device"))

    # Ensure the LLAMA model is frozen
    for param in model.encoder_model.parameters():
        param.requires_grad = False

    # Define the path for the LMDB database
    lmdb_path = config.get("data_gen", "llama_precompute_path")
    os.makedirs(lmdb_path, exist_ok=True)

    env = lmdb.open(
        lmdb_path,
        map_size=10**12,  # maximum DB size in bytes (~1 TB); must be an int
        max_dbs=3,
        readonly=False,
        lock=True,
    )
    try:
        # Single named sub-database shared by anchor and augmented embeddings.
        embedding_db = env.open_db(b'embeddings')

        # Keys already handled in this call; lets repeated sequences skip
        # opening another write transaction.
        logged_keys_this_run = set()

        dataset = SimilaritySpaceDataset(split)

        for i in tqdm(range(len(dataset)), desc=f"Processing {split} split"):
            current_row = dataset[i]

            anchor_key = current_row["anchor_key"].tolist()
            augmented_keys = current_row["augmented_keys"].tolist()

            # Anchor embedding
            _store_embedding_if_missing(
                env,
                embedding_db,
                generate_key(anchor_key[0]),
                model,
                current_row["anchor_ids"],
                current_row["anchor_attention_mask"],
                logged_keys_this_run,
            )

            # Augmented embeddings (one per augmented sequence in the row)
            augmented_ids_list = current_row["augmented_ids"]
            augmented_attention_mask_list = current_row["augmented_attention_mask"]
            for j in range(len(augmented_ids_list)):
                _store_embedding_if_missing(
                    env,
                    embedding_db,
                    generate_key(augmented_keys[j]),
                    model,
                    augmented_ids_list[j],
                    augmented_attention_mask_list[j],
                    logged_keys_this_run,
                )
    finally:
        # Always release the LMDB environment, even if a row fails.
        env.close()
        # empty_cache() only returns memory no live tensor uses, so drop the
        # model reference first or its weights stay resident on the GPU.
        del model
        torch.cuda.empty_cache()


def _store_embedding_if_missing(env, db, key_bytes, model, input_ids, attention_mask, seen_keys):
    """Embed one sequence and store it under ``key_bytes`` unless already stored.

    ``seen_keys`` is updated in place so later calls with the same key return
    immediately without opening an LMDB transaction.
    """
    if key_bytes in seen_keys:
        return
    with env.begin(write=True, db=db) as txn:
        # Only run the forward pass when the DB does not already hold this key
        # (it may have been written by an earlier, interrupted run).
        if txn.get(key_bytes) is None:
            embedding = compute_embedding(model, input_ids, attention_mask)
            txn.put(key_bytes, pickle.dumps(embedding))
    seen_keys.add(key_bytes)

def compute_embedding(model, input_ids, attention_mask):
    """Embed a single token sequence with the model's encoder.

    A leading batch dimension of size 1 is added to both inputs (so they are
    presumably unbatched), and both are moved to the configured training
    device. ``model.encoder_model`` then runs without gradient tracking, and
    the last entry of its hidden states is mean-pooled over the attention
    mask.

    Returns:
    - The pooled embedding, moved to the CPU and converted to a NumPy array.
    """
    batched_ids, batched_mask = (
        tensor.unsqueeze(0).to(config.get("training", "device"))
        for tensor in (input_ids, attention_mask)
    )

    with torch.no_grad():
        encoder_out = model.encoder_model(
            input_ids=batched_ids,
            attention_mask=batched_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        final_layer = encoder_out.hidden_states[-1]
        # Padding positions are excluded by the mask inside mean_pooling.
        pooled = mean_pooling(final_layer, batched_mask)
    return pooled.cpu().numpy()

def mean_pooling(last_hidden_state, attention_mask):
    """Average hidden states along dim 1, weighted by the attention mask.

    Positions whose mask value is 0 add nothing to either the sum or the
    count. The count is clamped to at least 1e-9, so an all-zero mask gives
    zeros rather than NaNs. Any leading dimension of size 1 is squeezed away
    from the result.
    """
    # Broadcast the mask over the hidden dimension, as float for the product.
    weights = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    totals = (last_hidden_state * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return (totals / counts).squeeze(0)

import hashlib

def retrieve_embedding(env, db_name, key):
    """Look up and unpickle a stored value from a named LMDB sub-database.

    Parameters:
    - env: an open LMDB environment.
    - db_name: sub-database name as a str; it is ASCII-encoded before opening.
    - key: lookup key as passed to ``txn.get`` (keys written by this module
      come from ``generate_key``).

    Returns:
    - The unpickled value, or None when ``key`` is absent.
    """
    # NOTE(review): pickle.loads executes arbitrary code on crafted input;
    # only read databases produced by a trusted run.
    db_handle = env.open_db(db_name.encode('ascii'))
    with env.begin(write=False, db=db_handle) as txn:
        raw = txn.get(key)
        if raw is None:
            return None  # Or handle the missing case as needed
        return pickle.loads(raw)

def generate_key(key_int: int) -> bytes:
    """
    Derive a fixed-length LMDB key from an integer sample key.

    The value is converted with ``str()``, UTF-8 encoded, and hashed with
    SHA-256. The 64-character hex digest is returned as ASCII bytes, so equal
    inputs always map to the same key.

    Parameters:
    - key_int: integer identifier of the sequence (anchor or augmented) to key.
      Any value is accepted, but note that it is stringified first, so e.g.
      ``1`` and ``"1"`` produce the same key.

    Returns:
    - key: 64-byte ASCII hex digest, suitable for use as an LMDB key.
    """
    key = hashlib.sha256(str(key_int).encode('utf-8')).hexdigest().encode('ascii')
    return key
