from pathlib import Path

import pandas as pd
import torch
from evaluate_finetuned import FinetunedAll
from generate import generate_batch
from scripts.prepare_alpaca_no_prompt import generate_no_prompt
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm


class Data(Dataset):
    """Dataset over one contiguous shard of a two-column text CSV.

    The CSV at ``csv_pth`` is loaded with pandas and divided into
    ``num_shards`` contiguous row ranges, of which only range ``shard_id`` is
    kept. For every kept row the ``txt1`` and ``txt2`` columns are joined with
    ``delimiter``; the joined text is passed through ``generate_no_prompt`` to
    build the prompt that gets tokenized on access.

    Attributes set by the constructor: ``data`` (the sliced DataFrame),
    ``delimiter``, ``tokenizer``, ``texts`` (joined strings) and ``prompts``
    (one prompt per text).
    """

    def __init__(self, shard_id, num_shards, csv_pth, delimiter, tokenizer):
        self.tokenizer = tokenizer
        self.delimiter = delimiter

        frame = pd.read_csv(csv_pth)
        total_rows = len(frame)
        # Multiply before floor-dividing: the end of shard k equals the start
        # of shard k + 1, so consecutive shards neither overlap nor skip rows.
        first_row = shard_id * total_rows // num_shards
        last_row = (shard_id + 1) * total_rows // num_shards
        self.data = frame.iloc[first_row:last_row]

        joined = self.data["txt1"] + delimiter + self.data["txt2"]
        self.texts = joined.tolist()
        self.prompts = list(map(generate_no_prompt, self.texts))

    def __len__(self):
        """Return the number of prompts (rows) in this shard."""
        return len(self.prompts)

    def __getitem__(self, index):
        """Return ``(text, tokens)`` for position ``index`` within the shard.

        ``text`` is the joined input string; ``tokens`` is whatever the
        tokenizer's ``encode`` returns for the matching prompt when called
        with ``bos=True, eos=False, device="cpu"``.
        """
        encoded = self.tokenizer.encode(
            self.prompts[index], bos=True, eos=False, device="cpu"
        )
        return self.texts[index], encoded


def custom_collate_fn(batch):
    """Split a batch of samples into parallel lists without stacking.

    Each sample is indexed at position 0 (the text) and position 1 (the
    tokens). The two resulting lists are returned as a ``(texts, tokens)``
    tuple; token entries are left as-is rather than being padded or stacked.
    """

    def column(position):
        # One pass over the batch per column, collecting a single field.
        return [sample[position] for sample in batch]

    return column(0), column(1)


def _resolve_output_path(csv_pth, tag):
    """Return a not-yet-existing ``.csv`` path next to ``csv_pth`` carrying ``tag``.

    Only a trailing ``.csv`` extension is stripped from the input file name, so
    an input without that extension can never resolve to itself and be
    overwritten. If the target already exists, ``_v2``, ``_v3``, ... is
    appended until a free name is found.
    """
    source = Path(csv_pth)
    base = source.stem if source.suffix == ".csv" else source.name
    candidate = source.with_name(f"{base}{tag}.csv")
    version = 2
    while candidate.is_file():
        candidate = source.with_name(f"{base}{tag}_v{version}.csv")
        version += 1
    return candidate


@torch.no_grad()
def main(
    shard_id, num_shards, checkpoint_path, csv_pth, delimiter, batch_size, seed=1234
):
    """Generate responses for one shard of a CSV and save them next to it.

    Loads ``FinetunedAll`` from ``checkpoint_path``, builds a ``Data`` shard
    from ``csv_pth``, runs ``generate_batch`` on every batch, keeps the text
    following the first ``"### Response:"`` marker of each decoded output, and
    writes an ``input``/``output`` CSV whose name encodes the checkpoint's
    parent directory, the checkpoint iteration, the seed and the shard.

    Args:
        shard_id: Index of the shard to process.
        num_shards: Total number of shards the CSV is divided into.
        checkpoint_path: Checkpoint file; its stem must contain at least one
            ``-`` (e.g. ``iter-000056-ckpt``), the second ``-``-separated
            field is used in the output name.
        csv_pth: Input CSV path (must have ``txt1`` and ``txt2`` columns).
        delimiter: String placed between ``txt1`` and ``txt2``.
        batch_size: Number of prompts per generation batch.
        seed: Seed forwarded to ``FinetunedAll`` and recorded in the file name.

    Raises:
        IndexError: If the checkpoint file stem contains no ``-``.
        AssertionError: If ``generate_batch`` returns a different number of
            sequences than it was given.
    """
    # Parse the checkpoint name up front so a malformed name fails before any
    # expensive model loading or generation, not after all the work is done.
    ckpt = Path(checkpoint_path)
    data_name = ckpt.parent.stem
    checkpoint = ckpt.stem.split("-")[1]
    tag = f"_all_{data_name}_{checkpoint}_s{seed}_{shard_id}-{num_shards}"

    model = FinetunedAll(checkpoint_path=checkpoint_path, seed=seed)
    model.model.eval()

    dataset = Data(shard_id, num_shards, csv_pth, delimiter, model.tokenizer)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=4,
        shuffle=False,
        collate_fn=custom_collate_fn,
    )

    outputs = []
    inputs = []
    num_skipped = 0
    for texts, tokens_in in tqdm(loader, total=len(loader)):
        # Samples are tokenized on CPU by the dataset; move them to the model.
        tokens_in = [i.to(model.model.device) for i in tokens_in]
        tokens_out = generate_batch(
            model.model,
            tokens_in,
            max_new_tokens=500,
            max_seq_length=500,
            top_k=200,
            eos_id=model.tokenizer.eos_id,
            temperature=0.8,
        )
        assert len(tokens_out) == len(texts), f"{len(tokens_out)} != {len(texts)}"
        for token_out, input_text in zip(tokens_out, texts):
            # Best effort: an IndexError from decoding or from a missing
            # "### Response:" marker drops this sample instead of aborting.
            try:
                decoded_token = model.tokenizer.decode(token_out)
                response = decoded_token.split("### Response:")[1].strip()
            except IndexError:
                num_skipped += 1
                continue
            outputs.append(response)
            inputs.append(input_text)

    if num_skipped:
        print(f"Skipped {num_skipped} generations without a usable response")

    df = pd.DataFrame({"input": inputs, "output": outputs})

    out_pth = _resolve_output_path(csv_pth, tag)
    df.to_csv(out_pth, index=False)
    print(f"Saved to {out_pth}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate model outputs for one shard of an input CSV."
    )
    # nargs="?" is required for a positional's default to ever apply; without
    # it argparse treats both positionals as mandatory and ignores `default`.
    parser.add_argument(
        "num_shards",
        type=int,
        nargs="?",
        default=1,
        help="total number of shards the CSV is split into (default: 1)",
    )
    parser.add_argument(
        "shard_id",
        type=int,
        nargs="?",
        default=0,
        help="zero-based index of the shard to process (default: 0)",
    )
    parser.add_argument(
        "--checkpoint_path",
        default="out/full/my-edits-3-no-prompt-lr3e-5/iter-000056-ckpt.pth",
        type=Path,
    )
    parser.add_argument("--csv_pth", default="out/inputs/WebVid8M-63k.csv", type=str)
    parser.add_argument("--delimiter", type=str, default="\n&&\n")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--seed", type=int, default=1234)
    args = parser.parse_args()

    # Reject shard settings that would divide by zero or select no/wrong rows
    # before any model is loaded.
    if not 0 <= args.shard_id < args.num_shards:
        parser.error(
            f"shard_id must satisfy 0 <= shard_id < num_shards "
            f"(got shard_id={args.shard_id}, num_shards={args.num_shards})"
        )

    main(
        args.shard_id,
        args.num_shards,
        args.checkpoint_path,
        args.csv_pth,
        args.delimiter,
        args.batch_size,
        args.seed,
    )
