from data_utils.prompt_datasets import PromptDataset
from transformers import GenerationConfig
import os
import nltk
#nltk.download("punkt")

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import json
from utils import print_rank, save_rank, all_gather

from rouge_metric import compute_metrics

# Cap PyTorch's intra-op CPU thread pool for this process (set once at import time).
_TORCH_NUM_THREADS = 4
torch.set_num_threads(_TORCH_NUM_THREADS)


def prepare_dataset_main(args, tokenizer):
    """Build the datasets used by the generation entry point.

    Returns a dict with a single ``"test"`` entry holding a ``PromptDataset``.
    That dataset is constructed from the ``"train"`` split under ``args.data_dir``,
    with ``args.dev_num`` passed through unchanged.

    NOTE(review): the key says "test" but the split requested is "train" --
    presumably intentional (generating data from training prompts); confirm.
    NOTE(review): ``args.dev_num`` is presumably a sample cap; confirm in PromptDataset.
    """
    prompt_set = PromptDataset(args, tokenizer, "train", args.data_dir, args.dev_num)
    return {"test": prompt_set}


def _restore_sampler_order(gathered, world_size, total):
    """Put rank-concatenated rows back into dataset order and drop sampler padding.

    With ``shuffle=False``, ``DistributedSampler`` pads its index list by repeating
    the first entries until it divides evenly, then hands rank ``r`` the indices
    ``r, r + W, r + 2W, ...`` (``W = world_size``). After a rank-ordered
    concatenation along dim 0, row ``r * per_rank + i`` holds dataset index
    ``i * W + r``. Interleaving the ranks restores ``0, 1, 2, ...`` and moves the
    padding repeats to the tail, where ``[:total]`` removes them.

    Args:
        gathered: tensor of shape ``(W * per_rank, ...)`` with rank 0's rows first.
        world_size: number of data-parallel ranks ``W``.
        total: number of real samples to keep (``len(dataset)``).

    Returns:
        Tensor with the same trailing dims as ``gathered`` and at most ``total``
        rows, in dataset index order.
    """
    per_rank = gathered.size(0) // world_size
    tail = gathered.shape[1:]
    interleaved = gathered.reshape(world_size, per_rank, *tail).transpose(0, 1)
    return interleaved.reshape(-1, *tail)[:total]


def run_model(args, tokenizer, model, dataset: PromptDataset, epoch, device):
    """Score and sample continuations for every prompt in ``dataset`` across all ranks.

    For each batch this function does two things:

    1. It computes the per-sample mean token cross-entropy of
       ``no_model_batch["rest_ids"]`` given the prompt, using teacher forcing.
    2. It generates a response with ``model.generate``.

    The per-rank results are gathered from all data-parallel ranks and returned
    in dataset order, truncated to ``len(dataset)``. This is a collective call:
    every rank must call it.

    Args:
        args: run configuration. The fields read are ``eval_batch_size``,
            ``num_workers``, ``data_names``, ``model_type``, ``max_length`` and
            ``max_prompt_length``, plus the sampling settings passed to
            ``GenerationConfig``.
        tokenizer: provides ``eos_token_id``, ``pad_token_id`` and ``decode``.
        model: causal LM, switched to eval mode here.
        dataset: prompt dataset exposing ``collate`` and ``move_to_device``.
        epoch: unused; kept for interface compatibility.
        device: device the batches are moved to.

    Returns:
        Tuple ``(lm_losses, query_ids, response_ids)``, each with ``len(dataset)``
        rows in dataset order:

        - ``lm_losses`` is 1-D.
        - ``query_ids`` rows are left-padded to ``args.max_prompt_length``.
        - ``response_ids`` rows are right-padded or cropped to
          ``args.max_length - args.max_prompt_length``.
    """
    dp_world_size = dist.get_world_size()
    dp_rank = dist.get_rank()
    dp_group = None

    # shuffle=False fixes the per-rank index pattern (rank r gets r, r+W, ...),
    # which _restore_sampler_order relies on to reorder the gathered rows.
    sampler = DistributedSampler(dataset, shuffle=False, drop_last=False, rank=dp_rank, num_replicas=dp_world_size)
    dataloader = DataLoader(
        dataset, sampler=sampler, batch_size=args.eval_batch_size, num_workers=args.num_workers, collate_fn=dataset.collate)
    model.eval()

    local_query_ids = []
    local_response_ids = []
    local_lm_losses = []

    generation_config = GenerationConfig(
        do_sample=args.do_sample,
        top_p=args.top_p,
        top_k=args.top_k,
        temperature=args.temperature,
        no_repeat_ngram_size=args.no_repeat_ngram_size,
        repetition_penalty=args.repetition_penalty,
        max_length=args.max_length,
        min_length=None,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        return_dict_in_generate=True,
        output_scores=True
    )
    # Stateless; build once instead of once per batch.
    loss_func = nn.CrossEntropyLoss(reduction="none")

    with torch.no_grad():
        for it, (model_batch, no_model_batch) in enumerate(tqdm(dataloader, desc=f"Evaluating {args.data_names} ", disable=(dp_rank != 0))):
            if it == 0:
                print_rank("############### Example ###############")
                print_rank(tokenizer.decode(model_batch["input_ids"][0], skip_special_tokens=True))
                print_rank("############### End ###############")

            dataset.move_to_device(model_batch, no_model_batch, device)

            # --- Teacher-forced loss of the reference continuation (rest_ids) ---
            prompt_len = model_batch["input_ids"].size(1)
            all_ids = torch.cat([model_batch["input_ids"], no_model_batch["rest_ids"]], dim=-1)
            input_ids = all_ids[:, :-1]
            attention_mask = (input_ids != tokenizer.pad_token_id).long()
            label_ids = all_ids[:, 1:]
            label_ids = torch.masked_fill(label_ids, label_ids == tokenizer.pad_token_id, -100)
            # Label positions before prompt_len - 1 are prompt tokens; exclude them from the loss.
            label_ids[:, :prompt_len - 1] = -100
            if args.model_type in ["gpt2"]:
                # Position ids count only non-pad tokens, so left padding does not offset them.
                position_ids = (torch.cumsum(attention_mask, dim=-1) - 1) * attention_mask
                out = model(input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask, return_dict=True)
            else:
                out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
            logits = out.logits
            loss_mask = (label_ids != -100).float()
            token_loss = loss_func(logits.view(-1, logits.size(-1)), label_ids.view(-1)).view(label_ids.size())
            # NOTE(review): a sample whose rest_ids are all padding has an empty mask and yields NaN.
            local_lm_losses.append(torch.sum(token_loss * loss_mask, -1) / torch.sum(loss_mask, -1))

            # --- Sampled continuation ---
            query_ids = model_batch["input_ids"]
            max_new_tokens = args.max_length - query_ids.size(1)
            gen_out = model.generate(
                **model_batch,
                generation_config=generation_config,
                max_new_tokens=max_new_tokens
            )
            response_ids = gen_out.sequences[:, query_ids.size(1):]  # remove prompt (may include start token)

            # Fixed widths so every rank contributes same-shaped tensors to the gather.
            query_ids = F.pad(query_ids, (args.max_prompt_length - query_ids.size(1), 0, 0, 0), value=tokenizer.pad_token_id)
            # A negative pad amount crops, so longer responses are truncated to this width.
            response_ids = F.pad(response_ids, (0, args.max_length - args.max_prompt_length - response_ids.size(1), 0, 0), value=tokenizer.pad_token_id)

            local_query_ids.append(query_ids)
            local_response_ids.append(response_ids)

    num_samples = len(dataset)

    def gather_ordered(local):
        # Assumes utils.all_gather(op="cat") concatenates the ranks along dim 0 in
        # rank order, as torch.distributed.all_gather does; confirm in utils.
        merged = all_gather(local, dim=0, world_size=dp_world_size, group=dp_group, op="cat")
        return _restore_sampler_order(merged, dp_world_size, num_samples)

    # Same collective sequence on every rank: losses, then queries, then responses.
    all_lm_losses = gather_ordered(torch.cat(local_lm_losses))
    all_query_ids = gather_ordered(torch.cat(local_query_ids))
    all_response_ids = gather_ordered(torch.cat(local_response_ids))

    return (
        all_lm_losses,
        all_query_ids,
        all_response_ids)


def generate_data_main(args, tokenizer, model, dataset: PromptDataset, split, epoch, device):
    """Generate responses for ``dataset`` and write them under ``args.save`` on rank 0.

    Calls ``run_model`` (which runs collectives across ranks) and decodes the
    prompts and responses. Rank 0 then writes:

    - ``preds.txt``: prompt and response separated by two tabs, with newlines
      shown as ``<n>``.
    - ``preds.pt``: ``[[(prompt, prompt + response, lm_loss), ...]]``.
    - ``answers.jsonl``: one object per sample with ``prompt``, ``LLM_output``
      and ``lm_loss``.

    Rank 0 also logs a count sanity check and the result of
    ``compute_metrics(responses, dataset.answers)`` to ``log.txt``.

    Must be called on every rank: ``run_model`` and ``dist.barrier()`` are collective.

    NOTE(review): ``split`` is not used in the lines shown; ``epoch`` is only forwarded.
    NOTE(review): ``all_responses`` and ``gen_res`` exist only on rank 0.
    NOTE(review): ``mean_lm_loss`` and ``mean_gen_length`` are not used in the lines
    shown -- presumably consumed by later logging; confirm.
    """
    
    lm_loss, query_ids, response_ids = run_model(args, tokenizer, model, dataset, epoch, device)
    mean_lm_loss = lm_loss.mean()
    query_strs = tokenizer.batch_decode(query_ids, skip_special_tokens=True)
    response_strs = tokenizer.batch_decode(response_ids, skip_special_tokens=True)

    # Write outputs only on rank 0 to avoid concurrent overwrites
    if dist.get_rank() == 0:
        os.makedirs(args.save, exist_ok=True)
        with open(os.path.join(args.save, "preds.txt"), "w") as f:
            for q, r in zip(query_strs, response_strs):
                f.write(q.replace("\n", "<n>") + "\t\t" + r.replace("\n", "<n>") + "\n")

        # Stored as prompt and full text (prompt + response); zip stops at the shortest input.
        all_preds = [[]]
        for l, q, r in zip(lm_loss, query_strs, response_strs):
            all_preds[0].append((q, q + r, l))
        torch.save(all_preds, os.path.join(args.save, "preds.pt"))

        all_responses = []
        with open(os.path.join(args.save, "answers.jsonl"), "w") as f:    
            for p in all_preds[0]:
                # NOTE(review): this rebinds ``lm_loss`` (the full tensor above) to a per-sample value.
                q, r,lm_loss = p
                # Recover the response by stripping the prompt prefix from prompt + response.
                r = r[len(q):]
                # NOTE(review): strings were decoded with skip_special_tokens=True, so this marker is
                # presumably already gone if it is registered as a special token; confirm.
                idx = r.find("<|endoftext|>")
                if idx >= 0:
                    r = r[:idx]
                # NOTE(review): ``<n>`` was only substituted in preds.txt, so this replace only affects
                # literal "<n>" text in the decoded strings.
                f.write(json.dumps({
                    "prompt":q.replace("<n>", "\n").strip(), "LLM_output": r.replace("<n>", "\n").strip(),"lm_loss": lm_loss.item() if isinstance(lm_loss, torch.Tensor) else lm_loss
                }) + "\n")
                all_responses.append(r.replace("<n>", "\n").strip())
        # Log sanity checks
        save_rank(f"write_count={len(all_preds[0])} dataset_len={len(dataset)}", os.path.join(args.save, "log.txt"))
        if len(all_preds[0]) != len(dataset):
            save_rank("WARNING: generated count != dataset length", os.path.join(args.save, "log.txt"))
    # Ensure rank 0 finished writing before others exit
    dist.barrier()
    # Compute metrics only on rank 0
    if dist.get_rank() == 0:
        # Responses are paired positionally with dataset.answers, so they must be in dataset order.
        gen_res = compute_metrics(all_responses, dataset.answers)
        save_rank(f"metrics: {gen_res}", os.path.join(args.save, "log.txt"))

    # Average tokenized length of the decoded responses (computed on every rank).
    mean_gen_length = np.mean([len(tokenizer.encode(s)) for s in response_strs])
