import argparse
import pickle as pkl
import sys
from pathlib import Path
from typing import List, Optional

import pandas as pd
import torch
from fastchat.model import load_model
from logzero import logger
from torch import cuda
from tqdm import tqdm

# Keep this before the import below: the parent directory is put on sys.path
# first, presumably because that is where the `transformer` package lives.
sys.path.append("..")
from transformer.model import decode_tokens, generate_response, get_latent_representations

# Main function to generate responses and save latent representations
def main(
    prompts: List[str],
    model_directory: str,
    save_directory: str = ".",
    num_additional_tokens: int = 100,
    layer_indices: Optional[List[int]] = None,
    temperature: float = 0
):
    """Generate a response for each prompt and pickle the results.

    For every prompt, ``generate_response`` is called with
    ``max_new_tokens=num_additional_tokens``; the decoded text is then passed
    to ``get_latent_representations`` to collect attention matrices and
    embeddings for the requested layers. Four pickle files are written to
    ``save_directory``: ``attention_matrices.pkl``, ``embeddings.pkl``,
    ``responses.pkl`` and ``response_lengths.pkl``. Each holds a list with one
    entry per prompt, in input order.

    Args:
        prompts (List[str]): List of input prompts to generate responses for.
        model_directory (str): Path to the directory containing the model checkpoint.
        save_directory (str, optional): Path to the directory where results will be saved.
            Created (with parents) if missing. Defaults to ".".
        num_additional_tokens (int, optional): Number of additional tokens to generate in responses. Defaults to 100.
        layer_indices (List[int], optional): Indices for which layers' latent representations to save.
            ``None`` (the default) means ``[-1]`` (last layer).
        temperature (float, optional): Temperature for token generation. Defaults to 0.

    Returns:
        None. Results are only written to disk.
    """
    # Resolve the default here rather than in the signature so that no list
    # object is shared between calls.
    if layer_indices is None:
        layer_indices = [-1]

    # Create the output directory (and any parents) up front so that an
    # unusable path fails before the expensive model load and generation
    # instead of after all the work has been done.
    output_dir = Path(save_directory)
    output_dir.mkdir(parents=True, exist_ok=True)

    # load_model returns both the model and its tokenizer in one call.
    logger.info("Loading model and tokenizer")
    model, tokenizer = load_model(
        model_directory,
        device="cuda",
        num_gpus=cuda.device_count(),
        load_8bit=False,
        cpu_offloading=True,
        max_gpu_memory='41Gib'
    )
    logger.info("Model and tokenizer loaded")

    # Per-prompt accumulators; index i in each list corresponds to prompts[i].
    responses = []
    embeddings = []
    attention_matrices = []
    response_token_lengths = []

    logger.info("Encoding prompts")
    for prompt in tqdm(prompts):
        # Generate response using the model
        token_ids = generate_response(
            tokenizer=tokenizer,
            model=model,
            text=prompt,
            temperature=temperature,
            max_new_tokens=num_additional_tokens,
        )
        output = decode_tokens(
            tokenizer=tokenizer,
            token_ids=token_ids
        )

        # NOTE(review): slicing by len(prompt) assumes the decoded text starts
        # with the prompt verbatim (no added special-token text or whitespace
        # normalisation) -- confirm against decode_tokens.
        response = output[len(prompt):]
        responses.append(response)

        # Response length in tokens = total generated sequence length minus the
        # number of tokens the tokenizer produces for the prompt alone.
        prompt_token_count = tokenizer(prompt, return_tensors='pt').input_ids.shape[1]
        response_token_lengths.append(len(token_ids[0]) - prompt_token_count)

        # Get latent representations (attention matrices and embeddings) for
        # the full decoded text, i.e. prompt and response together.
        _attention_matrices, _embeddings = get_latent_representations(
            tokenizer=tokenizer,
            model=model,
            text=output,
            layer_indices=layer_indices,
            _attention_weights=True,
            _token_embeddings=True,
        )

        attention_matrices.append(_attention_matrices)
        embeddings.append(_embeddings)

        # Drop the loop-local names and release cached, unused GPU blocks.
        # The objects themselves stay alive through the lists above, so this
        # cannot free anything they still reference.
        del _attention_matrices
        del _embeddings
        torch.cuda.empty_cache()

    # Save responses and latent representations in pickle files
    logger.info("Saving responses and latent representations")
    outputs = {
        "attention_matrices.pkl": attention_matrices,
        "embeddings.pkl": embeddings,
        "responses.pkl": responses,
        "response_lengths.pkl": response_token_lengths,
    }
    for file_name, payload in outputs.items():
        with open(output_dir / file_name, "wb") as f:
            pkl.dump(payload, f)


if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Generate responses using a language model.")
    # Both paths are mandatory: without them the script would only fail later,
    # inside pandas or the model loader, with a much less helpful error.
    parser.add_argument("--prompt_file_path", type=str, required=True,
                        help="Path to the tsv file containing prompts.")
    parser.add_argument("--model_directory", type=str, required=True,
                        help="Path to the directory containing the model checkpoint.")
    parser.add_argument("--save_directory", type=str, default=".",
                        help="Path to the directory where results will be saved.")
    parser.add_argument("--number_of_completion_tokens", type=int, default=100,
                        help="Number of completion tokens to generate.")
    parser.add_argument("--model_temperature", type=float, default=0.0, help="Temperature for token generation.")
    parser.add_argument("--layer_indices", type=str, default="-1", help="Indices for which layers latent representations to save. Defaults to last layer.")
    parser.add_argument("--input_column_name",  type=str, default=None, help="Name of the column containing prompts.")
    args = parser.parse_args()

    # --layer_indices arrives as a comma-separated string (e.g. "0,5,-1");
    # convert it to the list of ints that main() expects.
    args.layer_indices = [int(i) for i in args.layer_indices.split(",")]
    print(args)

    # Load the prompts from the tab-separated file
    logger.info("Loading prompts")

    if args.input_column_name is not None:
        # The file has a header row; take the prompts from the named column.
        data = pd.read_csv(args.prompt_file_path, sep="\t")
        prompts = data[args.input_column_name].tolist()
    else:
        # No column name given: treat the file as header-less and use the
        # first column as the prompts.
        data = pd.read_csv(args.prompt_file_path, sep="\t", header=None)
        prompts = data[0].tolist()

    # Run the main logic with the provided arguments
    main(
        prompts=prompts,
        model_directory=args.model_directory,
        save_directory=args.save_directory,
        num_additional_tokens=args.number_of_completion_tokens,
        temperature=args.model_temperature,
        layer_indices=args.layer_indices
    )
