import json
import torch
import argparse
import pickle
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from tqdm import tqdm
from vllm import LLM, SamplingParams
# from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel

def recursive_to_list(obj):
    """Convert tensors nested inside lists/tuples into plain Python values.

    Lists and tuples are walked recursively and always come back as lists.
    A ``torch.Tensor`` is moved to the CPU and converted with ``tolist()``.
    Any other object (str, int, dict, None, ...) is returned unchanged.
    """
    # Containers: rebuild as a list, converting every element in turn.
    if isinstance(obj, (list, tuple)):
        converted = []
        for element in obj:
            converted.append(recursive_to_list(element))
        return converted

    # Anything that is neither a container nor a tensor passes straight through.
    if not isinstance(obj, torch.Tensor):
        return obj

    host_tensor = obj.cpu()
    return host_tensor.tolist()
def generate_answer(model, input_strs, output_file):
    """Generate one completion per prompt with vLLM and save them as JSON.

    Args:
        model: Model name or path passed to ``vllm.LLM``.
        input_strs: Sequence of prompt strings to complete.
        output_file: Output path prefix. The results are written to
            ``f"{output_file}.json"`` as a list of
            ``{"input": ..., "output": ...}`` records, one per prompt.

    The vLLM engine is torn down at the end so that GPU memory can be reused
    by whatever runs next in the same process.
    """
    import gc

    # temperature=0 / top_p=1: deterministic decoding, at most 512 new tokens.
    sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=512)
    # Shard the model across every visible GPU.
    llm = LLM(model=model, tensor_parallel_size=torch.cuda.device_count(), trust_remote_code=True)

    request_outputs = llm.generate(input_strs, sampling_params)
    # Keep only the text of the first candidate returned for each prompt.
    data_json = [
        {"input": input_text, "output": request_output.outputs[0].text}
        for input_text, request_output in zip(input_strs, request_outputs)
    ]

    output_json = f"{output_file}.json"

    # Write the data to the JSON file
    with open(output_json, "w", encoding="utf-8") as f:
        json.dump(data_json, f, ensure_ascii=False, indent=4)

    # Best-effort teardown: the results are already on disk, so a missing
    # engine attribute (not every vLLM version exposes ``model_executor``)
    # must not turn a successful run into a crash.
    try:
        del llm.llm_engine.model_executor
    except AttributeError:
        pass
    del llm
    # Collect the now-unreferenced engine objects first so that
    # empty_cache() can actually hand their CUDA memory back.
    gc.collect()
    torch.cuda.empty_cache()
    print('Using VLLM, we do not need to set batch size!')


def save_hidden_states_for_dataset(input_texts, model_path, output_file, device='cuda'):
    """Run each text through a causal LM and pickle per-layer hidden-state summaries.

    For every input text, a single forward pass is made with
    ``output_hidden_states=True``. Two summaries are kept for each tensor in
    ``outputs.hidden_states``: the vector of the last token, and the mean
    vector over all tokens.

    Args:
        input_texts: List of input strings; processed one at a time.
        model_path: Name or path given to ``AutoTokenizer`` /
            ``AutoModelForCausalLM``.
        output_file: Output path prefix; the pickle is written to
            ``f"{output_file}.pkl"``.
        device: Device the model is loaded on and the inputs are moved to.

    The pickle holds a dict with two keys, ``"last_results"`` and
    ``"mean_results"``. Each maps the sample index to a dict containing the
    ``"input"`` text and, respectively, ``"hidden_states_last_token"`` or
    ``"hidden_states_mean"`` as nested Python lists.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="right", use_fast=False)
    # Load directly onto the requested device instead of a hard-coded
    # "cuda:0" followed by a second move.
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device, trust_remote_code=True)
    model.eval()

    num_samples = len(input_texts)
    last_results = {}
    mean_results = {}

    with torch.no_grad():
        for idx, input_text in tqdm(enumerate(input_texts), total=num_samples, desc="Processing samples"):
            # Inputs must live on the same device as the model.
            inputs = tokenizer(input_text, return_tensors="pt").to(device)

            # Forward pass that also returns the hidden states.
            outputs = model(**inputs, output_hidden_states=True, return_dict=True)
            hidden_states = outputs.hidden_states

            # Keep only the last token's hidden state (one vector per layer).
            last_token_hidden = [layer[0, -1, :].cpu() for layer in hidden_states]

            # Average the hidden states over all tokens (one vector per layer).
            mean_token_hidden = [layer[0, :, :].mean(dim=0).cpu() for layer in hidden_states]

            last_results[idx] = {
                "input": input_text,
                "hidden_states_last_token": recursive_to_list(last_token_hidden)
            }

            mean_results[idx] = {
                "input": input_text,
                "hidden_states_mean": recursive_to_list(mean_token_hidden)
            }

    pickle_path = f"{output_file}.pkl"
    with open(pickle_path, 'wb') as f:
        pickle.dump({
            "last_results": last_results,
            "mean_results": mean_results
        }, f)

    # Report the count and path that were actually used, not a module global.
    print(f"Saved last and mean hidden states for {num_samples} samples to {pickle_path}")

if __name__ == "__main__":
    # Script entry point: build prompts from an Alpaca-style JSON dataset,
    # generate answers with vLLM, then dump hidden states for the same prompts.
    parser = argparse.ArgumentParser(description="Save compressed hidden states from model outputs.")
    parser.add_argument("--dataset_path", type=str, default="alpaca_dataset.json", help="Path to the Alpaca dataset JSON file.")
    # The previous default "meta-llama/Llama-2-7bf" appears to be a typo for
    # the "-hf" checkpoint id.
    parser.add_argument("--model_path", type=str, default="meta-llama/Llama-2-7b-hf", help="Path or name of the pretrained model.")
    # The value is used as a prefix: generate_answer appends ".json" and
    # save_hidden_states_for_dataset appends ".pkl".
    parser.add_argument("--output_file", type=str, default="hidden_states_compressed.pkl", help="Output path prefix; '.json' (generated answers) and '.pkl' (hidden states) are appended to it.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model (e.g. 'cuda' or 'cpu').")

    args = parser.parse_args()

    with open(args.dataset_path, 'r', encoding='utf-8') as f:
        dataset = json.load(f)

    # Prompt is "<instruction> <input>" when the sample has a non-empty
    # "input" field, otherwise just the instruction.
    input_texts = []
    for sample in dataset:
        instruction = sample.get("instruction", "")
        extra_input = sample.get("input", "")
        if extra_input:
            input_texts.append(f'{instruction} {extra_input}')
        else:
            input_texts.append(f'{instruction}')

    generate_answer(args.model_path, input_texts, args.output_file)

    save_hidden_states_for_dataset(
        input_texts=input_texts,
        model_path=args.model_path,
        output_file=args.output_file,
        device=args.device
    )
