import argparse
import csv
import os
import time
import random

import numpy as np
import torch
from tqdm import tqdm
from dataset import get_loaders 
from utils import count_params, get_model, set_seed


def set_seed(random_seed=1234):
    """Seed the PyTorch, NumPy and Python RNGs with the same value.

    Also sets ``torch.backends.cudnn.deterministic = True`` and
    ``torch.backends.cudnn.benchmark = False``.

    NOTE(review): this module also imports ``set_seed`` from ``utils``; this
    local definition comes later in the file and therefore shadows it.

    Args:
        random_seed: Integer seed handed to every generator.
    """
    # PyTorch: CPU generator, current CUDA device, then all CUDA devices.
    torch_seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in torch_seeders:
        seed_fn(random_seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # NumPy's global generator and the stdlib ``random`` module.
    np.random.seed(random_seed)
    random.seed(random_seed)



@torch.no_grad()
def compute_ppl(model, data_loader, device):
    """Compute the perplexity (PPL) of a language model over a data loader.

    Each batch is moved to ``device`` and fed to ``model``; the logits at
    position ``t`` are scored against the input token at position ``t + 1``.
    The perplexity is ``exp`` of the mean per-token negative log-likelihood
    over every scored position of every batch.

    Args:
        model: Callable whose result exposes a ``.logits`` tensor that is
            indexed as three-dimensional (presumably batch x sequence x
            vocabulary -- confirm against the model in use).
        data_loader: Iterable of two-dimensional integer tensors of token ids
            (indexed as ``batch[:, 1:]``).
        device: Device the batches are moved to before the forward pass.

    Returns:
        The perplexity as a NumPy float.

    Raises:
        ValueError: If no token was scored (empty loader, or only sequences
            of length < 2).
    """
    # Built once; summing lets us keep a running total instead of holding
    # every per-token loss tensor on the device until the end.
    loss_fct = torch.nn.CrossEntropyLoss(reduction="sum")
    total_nll = 0.0
    total_tokens = 0

    for batch in tqdm(data_loader, desc="Evaluating PPL"):
        batch = batch.to(device)
        logits = model(batch).logits

        # Predict token t+1 from position t: drop the last logit and the
        # first label.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = batch[:, 1:].contiguous()

        # Upcast to float32 so half/bfloat16 models do not lose precision in
        # the log-softmax and in the reduction.
        batch_nll = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)).float(),
            shift_labels.view(-1),
        )
        total_nll += batch_nll.item()
        total_tokens += shift_labels.numel()

    if total_tokens == 0:
        raise ValueError("compute_ppl: data_loader produced no tokens to score")

    ppl = np.exp(total_nll / total_tokens)
    return ppl


def eval_ppl(
    output_dir,
    model,
    tokenizer,
    datasets=("wikitext2", "ptb", "c4"),
    max_seq_len=128,
    batch_size=4,
    device="cuda",
    add_bos_to_every=False,
):
    """Compute PPL on the given datasets and save the results as a CSV file.

    For each dataset name the test loader is obtained from ``get_loaders``,
    scored with ``compute_ppl`` and the result printed. A two-row CSV (header
    plus values) is then written to ``output_dir`` as ``ppl_bos.csv`` when
    ``add_bos_to_every`` is true, otherwise ``ppl.csv``. It contains one
    ``ppl_<dataset>`` column per dataset, followed by the parameter count and
    the allocated GPU memory in MiB.

    Args:
        output_dir: Directory for the CSV file; created if missing.
        model: Model forwarded to ``compute_ppl`` and ``count_params``.
        tokenizer: Tokenizer forwarded to ``get_loaders``.
        datasets: Iterable of dataset names understood by ``get_loaders``.
        max_seq_len: Forwarded to ``get_loaders``.
        batch_size: Forwarded to ``get_loaders``.
        device: Device used for evaluation; also the device whose allocated
            memory is reported when it is a CUDA device.
        add_bos_to_every: Forwarded to ``get_loaders``; also selects the CSV
            file name.

    Returns:
        Dict mapping each dataset name to its perplexity.
    """
    os.makedirs(output_dir, exist_ok=True)
    csv_name = "ppl_bos.csv" if add_bos_to_every else "ppl.csv"
    csv_path = os.path.join(output_dir, csv_name)
    csv_header = []
    csv_values = []

    metrics = {}
    for dataset_name in datasets:
        t0 = time.perf_counter()
        _, test_loader = get_loaders(dataset_name, tokenizer, max_seq_len, batch_size, add_bos_to_every)
        ppl = compute_ppl(model, test_loader, device)
        metrics[dataset_name] = ppl

        print(f"PPL-{dataset_name}: {ppl:.2f} | add_bos_to_every={add_bos_to_every} | time={time.perf_counter() - t0:.1f}s")
        csv_header.append(f"ppl_{dataset_name}")
        csv_values.append(ppl)

    # Query the device that was actually used rather than the current CUDA
    # device; on non-CUDA devices there is nothing to report.
    if torch.device(device).type == "cuda":
        mem = torch.cuda.memory_allocated(device) / 1024 / 1024
    else:
        mem = 0.0
    nparams = count_params(model)
    print(f"Model params: {nparams}, GPU memory occupied: {mem:.1f} MiB")

    with open(csv_path, "w", encoding="utf8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(csv_header + ["params", "mem"])
        writer.writerow(csv_values + [nparams, mem])

    return metrics



def generate_txt(
    output_dir,
    model,
    tokenizer,
    input_prompt="The Leaning Tower of Pisa is known for",
    num_output=5,
    top_k=50,
    top_p=0.95,
    temperature=1.0,
    max_seq_len=128,
    device="cuda",
):
    """Sample continuations of a prompt, print them and save them to a file.

    The prompt is tokenized once, then ``model.generate`` is called
    ``num_output`` times with sampling enabled. Every decoded sequence is
    printed and written to ``<output_dir>/gen_text.txt``, which is
    overwritten on each call and starts with the prompt itself.

    Args:
        output_dir: Directory for ``gen_text.txt``; created if missing.
        model: Object exposing ``generate(...)`` whose result has a
            ``sequences`` attribute.
        tokenizer: Callable returning a mapping with ``"input_ids"`` and
            exposing ``decode``.
        input_prompt: Text used as the generation prefix.
        num_output: Number of samples to draw.
        top_k: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        max_seq_len: Number of tokens requested on top of the prompt length.
        device: Device the prompt token ids are moved to.
    """
    # Create the directory here so the function does not depend on a caller
    # having created it beforehand.
    os.makedirs(output_dir, exist_ok=True)
    txt_path = os.path.join(output_dir, "gen_text.txt")

    inputs = tokenizer(input_prompt, return_tensors="pt")["input_ids"].to(device)
    input_len = inputs[0].size(0)
    # min_length == max_length forces a fixed output length (to avoid <EOS>
    # sampling ending the text early).
    total_len = input_len + max_seq_len

    with open(txt_path, "w", encoding="utf8") as f:
        f.write("=== input ===\n")
        f.write(f"{input_prompt}\n")

        for i in range(num_output):
            with torch.no_grad():
                generation_output = model.generate(
                    input_ids=inputs,
                    do_sample=True,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    max_length=total_len,
                    min_length=total_len,
                    return_dict_in_generate=True,
                )
            s = generation_output.sequences[0]
            output_len = len(s)
            output = tokenizer.decode(s)

            header = f"=== output {i} | leng gen {output_len-input_len} + input {input_len}\n"
            print(header)
            print(output)

            f.write(header)
            f.write(f"{output}\n")
            # Flush per sample so finished samples reach disk even if a later
            # generation step is interrupted.
            f.flush()



if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments) pairs,
    # registered in this order.
    cli_options = [
        ("--base_model", dict(type=str, default="meta-llama/Llama-2-7b-hf")),
        ("--tokenizer", dict(type=str, default=None)),
        ("--device", dict(type=str, default="cuda")),
        ("--input_prompt", dict(type=str, default="The Leaning Tower of Pisa is known for")),
        ("--num_output", dict(type=int, default=5)),
        ("--top_p", dict(type=float, default=0.95)),
        ("--top_k", dict(type=int, default=50)),
        ("--temperature", dict(type=float, default=1.0)),
        ("--max_seq_len", dict(type=int, default=128)),
        ("--output_dir", dict(type=str, default="results")),
        ("--seed", dict(type=int, default=1234)),
        ("--fix_decapoda_config", dict(action="store_true")),
        ("--use_bfloat", dict(action="store_true")),
    ]
    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    # Seed the RNGs first, then load the model and tokenizer.
    set_seed(args.seed)
    model, tokenizer = get_model(
        base_model=args.base_model,
        tokenizer=args.tokenizer,
        device=args.device,
        fix_decapoda_config=args.fix_decapoda_config,
        use_bfloat=args.use_bfloat,
    )

    # Keyword arguments that both stages below receive unchanged.
    shared = dict(
        output_dir=args.output_dir,
        model=model,
        tokenizer=tokenizer,
        max_seq_len=args.max_seq_len,
        device=args.device,
    )

    # Stage 1: perplexity on the three fixed datasets, written to CSV.
    eval_ppl(
        datasets=["wikitext2", "ptb", "c4"],
        batch_size=4,
        add_bos_to_every=False,
        **shared,
    )

    # Stage 2: sample continuations of the prompt with the CLI settings.
    generate_txt(
        input_prompt=args.input_prompt,
        num_output=args.num_output,
        top_k=args.top_k,
        top_p=args.top_p,
        temperature=args.temperature,
        **shared,
    )


