import os
import torch
import json 
import argparse
from tqdm import tqdm 
from vllm import SamplingParams, LLM
from transformers import AutoModelForCausalLM, AutoTokenizer

def generate_sample_with_transformers(
    pretrained_model_name_or_path,
    tokenizer_name_or_path,
    soft_prompt_dir,
    public_soft_token_count,
    temp,
    target_count,
    device,
):
    """Sample synthetic texts from a learned soft prompt with ``generate``.

    The public soft-token embeddings stored in
    ``<soft_prompt_dir>/public_soft_token_embeddings.pth`` are fed to the
    causal LM as ``inputs_embeds`` and one completion is sampled per
    iteration. Each completion is printed, and the full set is written as
    JSON lines under ``<soft_prompt_dir>/synthetic_generate``.

    Args:
        pretrained_model_name_or_path: Model id or path passed to
            ``AutoModelForCausalLM.from_pretrained``.
        tokenizer_name_or_path: Tokenizer id or path passed to
            ``AutoTokenizer.from_pretrained``.
        soft_prompt_dir: Directory holding the soft-token state dict; also
            the parent of the output directory.
        public_soft_token_count: Number of rows in the soft-token embedding
            table (must be >= 1).
        temp: Sampling temperature.
        target_count: Number of samples to generate.
        device: Torch device the model and soft prompt are placed on.

    Returns:
        A list of ``{"idx": int, "generated_text": str}`` dicts, one per
        generated sample.

    Raises:
        ValueError: If ``public_soft_token_count`` is smaller than 1.
    """
    # Fail fast, before paying for the model load.
    if public_soft_token_count < 1:
        raise ValueError(
            f"public_soft_token_count must be >= 1, got {public_soft_token_count}"
        )

    pre_trained_llm = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)
    pre_trained_llm.to(device)
    pre_trained_llm.eval()

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
    # Compare against None explicitly: a pad id of 0 is valid but falsy.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    hidden_size = pre_trained_llm.config.hidden_size
    public_soft_token_embeddings_path = os.path.join(soft_prompt_dir, 'public_soft_token_embeddings.pth')
    # The file is consumed as a plain state dict, so restrict unpickling to
    # tensors instead of executing arbitrary pickled objects.
    public_soft_token_embeddings_state_dict = torch.load(
        public_soft_token_embeddings_path, map_location=device, weights_only=True
    )
    public_soft_tokens_embeddings = torch.nn.Embedding(public_soft_token_count, hidden_size)
    public_soft_tokens_embeddings.load_state_dict(public_soft_token_embeddings_state_dict)
    public_soft_tokens_embeddings.to(device).eval()

    with torch.no_grad():
        soft_token_id = torch.arange(public_soft_token_count, device=device)
        # [soft_token_count, hidden_size] -> [1, soft_token_count, hidden_size]
        public_soft_token_embeds = public_soft_tokens_embeddings(soft_token_id).view(
            1, public_soft_token_count, hidden_size
        )
        # Match the model's embedding dtype so the prompt can be fed directly.
        public_soft_token_embeds = public_soft_token_embeds.to(
            dtype=pre_trained_llm.get_input_embeddings().weight.dtype
        )
    print(f"soft token embeds shape = {public_soft_token_embeds.shape}")

    soft_attention_mask = torch.ones(
        1, public_soft_token_count, dtype=torch.long, device=device
    )

    # eos and pad are often the same id; drop duplicates and missing ids.
    stop_token_ids = [
        token_id
        for token_id in dict.fromkeys([tokenizer.eos_token_id, tokenizer.pad_token_id])
        if token_id is not None
    ]

    generated_samples = []
    with torch.no_grad():
        for current in tqdm(range(target_count), desc="Synthetic Samples..."):
            generated_outputs = pre_trained_llm.generate(
                inputs_embeds=public_soft_token_embeds,
                attention_mask=soft_attention_mask,
                max_new_tokens=2048,
                do_sample=True,
                temperature=temp,
                # Generation from inputs_embeds relies on the KV cache to keep
                # the soft-prompt context on later decoding steps; recent
                # transformers releases appear to force this on regardless
                # (TODO confirm for the pinned version).
                use_cache=True,
                eos_token_id=stop_token_ids or None,
                pad_token_id=tokenizer.pad_token_id,
                repetition_penalty=1.0
            )

            generated_text = tokenizer.decode(generated_outputs[0], skip_special_tokens=True) # set False to output <|endoftext|>

            # Echo every sample so progress can be inspected while running.
            print("****Generated Text:****")
            print(generated_text)
            print("=="*30)

            generated_samples.append(dict(idx=current, generated_text=generated_text))

    # Persist the samples; previously they were discarded when the function returned.
    dir_path = os.path.join(soft_prompt_dir, "synthetic_generate")
    os.makedirs(dir_path, exist_ok=True)
    target_path = os.path.join(
        dir_path,
        f'transformers_generated_{len(generated_samples)}_samples_temp{temp}.jsonl',
    )
    with open(target_path, 'w', encoding='utf-8') as h:
        h.writelines(json.dumps(sample) + "\n" for sample in generated_samples)
    print(f"Saved to {target_path}")

    return generated_samples


def _export_model_with_soft_tokens(
    pretrained_model_name_or_path,
    tokenizer,
    new_tokens,
    soft_prompt_dir,
    temp_dir,
):
    """Write a copy of the model whose vocabulary contains the soft tokens.

    ``tokenizer`` must already contain ``new_tokens``. The embedding rows of
    those tokens are overwritten with the learned soft-token embeddings, then
    the model and tokenizer are saved to ``temp_dir``.
    """
    pretrained_model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path, trust_remote_code=True
    )

    soft_token_embeddings_path = os.path.join(soft_prompt_dir, 'public_soft_token_embeddings.pth')
    # The model is on CPU here, so merge on CPU as well. weights_only limits
    # unpickling to tensors, which is all a state dict needs.
    soft_token_embeddings_state_dict = torch.load(
        soft_token_embeddings_path, map_location="cpu", weights_only=True
    )
    soft_tokens_embeddings = torch.nn.Embedding(len(new_tokens), pretrained_model.config.hidden_size)
    soft_tokens_embeddings.load_state_dict(soft_token_embeddings_state_dict)

    pretrained_model.resize_token_embeddings(len(tokenizer))
    input_embeddings = pretrained_model.get_input_embeddings().weight
    # Address rows by the ids the tokenizer assigned instead of assuming the
    # new tokens occupy the tail of the embedding matrix.
    soft_token_ids = torch.tensor(
        tokenizer.convert_tokens_to_ids(new_tokens), device=input_embeddings.device
    )
    with torch.no_grad():
        input_embeddings[soft_token_ids] = soft_tokens_embeddings.weight.to(
            device=input_embeddings.device, dtype=input_embeddings.dtype
        )

    print("save vllm temp model...")
    pretrained_model.save_pretrained(temp_dir)
    tokenizer.save_pretrained(temp_dir)
    print("save vllm temp model done!")


def generate_sample_with_vllm(
    pretrained_model_name_or_path,
    tokenizer_name_or_path,
    soft_prompt_dir,
    public_soft_token_count,
    temp,
    target_count,
    batch_size,
    tensor_parallel_size,
    device='cuda:0'
):
    """Sample synthetic texts from a learned soft prompt with vLLM.

    The soft tokens are registered as special tokens ``<soft_0> ... <soft_N-1>``
    and their embeddings are baked into a copy of the model saved under
    ``<soft_prompt_dir>/vllm_temp_model``. If that directory already contains
    a ``config.json`` it is reused as is. vLLM is then prompted with the
    concatenated soft tokens, and the results are written as JSON lines to
    ``<soft_prompt_dir>/vllm_generated_<n>_samples_temp<temp>.jsonl``.

    Args:
        pretrained_model_name_or_path: Model id or path for the base model.
        tokenizer_name_or_path: Tokenizer id or path.
        soft_prompt_dir: Directory holding ``public_soft_token_embeddings.pth``;
            also receives the temporary model and the output file.
        public_soft_token_count: Number of soft tokens (must be >= 1).
        temp: Sampling temperature.
        target_count: Number of samples to generate; values <= 0 produce an
            empty output file.
        batch_size: Maximum number of prompts per ``llm.generate`` call
            (must be >= 1).
        tensor_parallel_size: Forwarded to ``vllm.LLM``.
        device: Unused; kept for backward compatibility. The embeddings are
            merged on CPU so this process does not touch CUDA before vLLM
            is created.

    Returns:
        A list of ``{"idx": int, "synthetic_text": str}`` dicts.

    Raises:
        ValueError: If ``public_soft_token_count`` or ``batch_size`` is
            smaller than 1.
    """
    # Fail fast, before any model is loaded.
    if public_soft_token_count < 1:
        raise ValueError(
            f"public_soft_token_count must be >= 1, got {public_soft_token_count}"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True)
    # Compare against None explicitly: a pad id of 0 is valid but falsy.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    new_tokens = [f"<soft_{i}>" for i in range(public_soft_token_count)]
    tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})

    temp_dir = os.path.join(soft_prompt_dir, "vllm_temp_model")
    if os.path.exists(os.path.join(temp_dir, "config.json")):
        # Reusing the export also avoids loading the full HF model again.
        print(f"{temp_dir} already exists and contains model files. Skipping save.")
    else:
        # The helper's local model goes out of scope on return, releasing its
        # memory before vLLM loads its own copy.
        _export_model_with_soft_tokens(
            pretrained_model_name_or_path,
            tokenizer,
            new_tokens,
            soft_prompt_dir,
            temp_dir,
        )

    llm = LLM(
        model=temp_dir,
        tokenizer=temp_dir,
        dtype='bfloat16',
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=0.95,
        trust_remote_code=True
    )

    # eos and pad are often the same id; drop duplicates and missing ids.
    stop_token_ids = [
        token_id
        for token_id in dict.fromkeys([tokenizer.eos_token_id, tokenizer.pad_token_id])
        if token_id is not None
    ]
    sampling_params = SamplingParams(
        max_tokens=2048,
        temperature=temp,
        top_p=1,
        repetition_penalty=1.0,
        stop_token_ids=stop_token_ids,
    )

    prompt = "".join(new_tokens)

    all_synthetic_texts = []
    for start in tqdm(range(0, target_count, batch_size), desc='Synthetic...'):
        # Shrink the final batch so exactly target_count samples are produced.
        current_batch_size = min(batch_size, target_count - start)
        outputs = llm.generate([prompt] * current_batch_size, sampling_params)
        for output in outputs:
            instruct = output.outputs[0].text.strip()
            all_synthetic_texts.append(
                dict(idx=len(all_synthetic_texts), synthetic_text=instruct)
            )

    target_path = os.path.join(soft_prompt_dir, f'vllm_generated_{len(all_synthetic_texts)}_samples_temp{temp}.jsonl')
    with open(target_path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(item) + "\n" for item in all_synthetic_texts)
    print(f"Saved to {target_path}")

    return all_synthetic_texts

if __name__ == "__main__":
    # Command-line entry point: parse the options and dispatch to the
    # generation routine of the selected inference engine.
    parser = argparse.ArgumentParser(description="Sythetic samples.")
    parser.add_argument("--pretrained_model_name_or_path", type=str, default="Qwen/Qwen2.5-Coder-7B-Instruct",
                        help="Base causal LM id or path.")
    parser.add_argument("--tokenizer_name_or_path", type=str, default="Qwen/Qwen2.5-Coder-7B-Instruct",
                        help="Tokenizer id or path.")
    parser.add_argument("--soft_prompt_dir", type=str, default="0514-4-4lambda-qwencoder-7b-instruct-pub256-pri256-tokens-10epoch",
                        help="Directory containing public_soft_token_embeddings.pth.")
    parser.add_argument("--inference_engine", type=str, default="vllm", choices=["transformers", "vllm"],
                        help="Backend used for generation.")
    parser.add_argument("--public_soft_token_count", type=int, default=256,
                        help="Number of public soft tokens.")
    parser.add_argument("--temp", type=float, default=0.8,
                        help="Sampling temperature.")
    parser.add_argument("--target_count", type=int, default=200,
                        help="Number of samples to generate.")
    # Previously hard-coded to 200 at the call site; the default keeps that value.
    parser.add_argument("--batch_size", type=int, default=200,
                        help="Prompts per vLLM generate call (vllm engine only).")
    parser.add_argument("--tensor_parallel_size", type=int, default=4,
                        help="vLLM tensor parallel size (vllm engine only).")
    parser.add_argument("--device", type=str, default="cuda:0",
                        help="Torch device.")

    args = parser.parse_args()

    # Arguments shared by both engines.
    common_kwargs = dict(
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        tokenizer_name_or_path=args.tokenizer_name_or_path,
        soft_prompt_dir=args.soft_prompt_dir,
        public_soft_token_count=args.public_soft_token_count,
        temp=args.temp,
        target_count=args.target_count,
        device=args.device,
    )

    if args.inference_engine == "transformers":
        generate_sample_with_transformers(**common_kwargs)
    else:
        generate_sample_with_vllm(
            batch_size=args.batch_size,
            tensor_parallel_size=args.tensor_parallel_size,
            **common_kwargs,
        )