from pathlib import Path

import lightning as L
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

from evaluate_finetuned import FinetunedAll
from generate import generate_batch
from lit_llama import LLaMA, Tokenizer
from lit_llama.utils import EmptyInitOnDevice, lazy_load


class Data(Dataset):
    """One contiguous shard of a caption-pair CSV, served as few-shot prompts.

    The CSV must provide ``txt1`` and ``txt2`` columns. Every item is the pair
    of raw captions together with the tokenized prompt built from them.
    """

    def __init__(self, shard_id, num_shards, csv_pth, delimiter, tokenizer):
        """Read ``csv_pth`` and keep only the rows that belong to ``shard_id``.

        Args:
            shard_id: Zero-based index of the shard to keep.
            num_shards: Number of shards the CSV rows are divided into.
            csv_pth: Location of the CSV file, forwarded to ``pd.read_csv``.
            delimiter: Kept on the instance; this class does not use it itself.
            tokenizer: Object whose ``encode`` method is called in
                ``__getitem__``.
        """
        frame = pd.read_csv(csv_pth)
        total_rows = len(frame)
        # Integer bounds of the shard; consecutive shards tile the whole file.
        first_row = shard_id * total_rows // num_shards
        end_row = (shard_id + 1) * total_rows // num_shards
        self.data = frame.iloc[first_row:end_row]

        self.delimiter = delimiter
        self.tokenizer = tokenizer

        # Worked examples placed in front of every prompt, one per line.
        few_shot_prefix = (
            "Clouds in the sky&&Airplane in the sky-> Add an airplane\n"
            "Aerial view of forest&&Aerial view autumn forest-> Change season to autumn\n"
            "Clouds timelapse&&Sky timelapse-> remove clouds and reveal only sky\n"
            "Aerial view of a sailboat anchored in the mediterranean sea.&&Aerial view of two sailboat anchored in the mediterranean sea.-> Add one sailboat\n"
        )

        first_captions = self.data["txt1"]
        second_captions = self.data["txt2"]
        self.txt1 = first_captions.tolist()  # type: ignore
        self.txt2 = second_captions.tolist()  # type: ignore

        # Element-wise string concatenation over the two pandas columns.
        prompts = few_shot_prefix + first_captions + "&&" + second_captions + "->"
        self.texts = prompts.tolist()  # type: ignore

    def __len__(self):
        """Return how many prompts this shard holds."""
        return len(self.texts)

    def __getitem__(self, index):
        """Return ``((txt1, txt2), tokens)`` for the prompt at ``index``."""
        caption_pair = (self.txt1[index], self.txt2[index])
        prompt = self.texts[index]
        encoded = self.tokenizer.encode(prompt, bos=True, eos=False, device="cpu")
        return caption_pair, encoded


def custom_collate_fn(batch):
    """Split a batch of ``(captions, tokens)`` samples into two parallel lists.

    Nothing is stacked or padded: the token entries are returned one per
    sample, exactly as the dataset produced them.
    """
    caption_pairs = []
    token_seqs = []
    for sample in batch:
        caption_pairs.append(sample[0])
        token_seqs.append(sample[1])
    return caption_pairs, token_seqs


def _build_output_path(csv_pth, checkpoint_path, seed, shard_id, num_shards):
    """Return an unused CSV path next to ``csv_pth`` for this run's results."""
    ckpt = Path(checkpoint_path)
    data_name = ckpt.parent.stem
    # Checkpoint stems look like "<prefix>-<tag>"; fall back to the whole stem
    # instead of raising when there is no dash.
    stem_parts = ckpt.stem.split("-")
    ckpt_tag = stem_parts[1] if len(stem_parts) > 1 else ckpt.stem

    csv_path = Path(csv_pth)
    # Only strip a real ".csv" suffix, so the result can never equal the input
    # path (which would overwrite the input file).
    base = csv_path.stem if csv_path.suffix == ".csv" else csv_path.name
    stem = f"{base}_prompt_{data_name}_{ckpt_tag}_s{seed}_{shard_id}-{num_shards}"

    out_path = csv_path.with_name(f"{stem}.csv")
    version = 2
    # Bump the version suffix until the name is free, so earlier runs survive.
    while out_path.is_file():
        out_path = csv_path.with_name(f"{stem}_v{version}.csv")
        version += 1
    return out_path


@torch.no_grad()
def main(
    shard_id,
    num_shards,
    checkpoint_path,
    csv_pth,
    delimiter,
    batch_size,
    seed=1234,
):
    """Generate an edit instruction for every caption pair in one CSV shard.

    Loads the 7B LLaMA checkpoint, prompts it with each ``txt1``/``txt2`` pair
    of the selected shard and writes the ``input``/``output`` pairs to a new
    CSV located next to ``csv_pth``.

    Args:
        shard_id: Zero-based index of the shard to process.
        num_shards: Number of shards the CSV is split into.
        checkpoint_path: Path of the model checkpoint to load.
        csv_pth: Path of the input CSV with ``txt1`` and ``txt2`` columns.
        delimiter: Forwarded to ``Data``.
        batch_size: Number of prompts generated per batch.
        seed: Random seed applied before sampling; also part of the output
            file name.
    """
    # Resolve the output path first so that a bad path fails before the
    # expensive generation instead of after it.
    out_pth = _build_output_path(csv_pth, checkpoint_path, seed, shard_id, num_shards)

    # Sampling uses temperature/top-k, so the seed has to be applied for the
    # seed recorded in the file name to mean anything.
    L.seed_everything(seed)

    fabric = L.Fabric(devices=1)
    dtype = (
        torch.bfloat16
        if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported()
        else torch.float32
    )

    with EmptyInitOnDevice(
        device=fabric.device,
        dtype=dtype,
    ):
        model = LLaMA.from_name("7B")

    state_dict = lazy_load(checkpoint_path)
    model.load_state_dict(state_dict)
    model.eval()
    model = fabric.setup_module(model)

    tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
    tokenizer = Tokenizer(tokenizer_path)

    dataset = Data(shard_id, num_shards, csv_pth, delimiter, tokenizer)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=4,
        shuffle=False,
        collate_fn=custom_collate_fn,
    )

    outputs = []
    inputs = []
    skipped = 0
    for texts, tokens_in in tqdm(loader, total=len(loader)):
        tokens_in = [i.to(model.device) for i in tokens_in]
        tokens_out = generate_batch(
            model,
            tokens_in,
            max_new_tokens=250,
            max_seq_length=250,
            top_k=200,
            eos_id=13,  # equivalent to \n
            temperature=0.8,
        )
        assert len(tokens_out) == len(texts), f"{len(tokens_out)} != {len(texts)}"
        for input_text, token_out in zip(texts, tokens_out):
            try:
                decoded = tokenizer.decode(token_out)
                # Keep the text that follows the second caption, up to the
                # first newline, with the "->" marker removed.
                instruction = (
                    decoded.split(input_text[1])[1]
                    .replace("->", "")
                    .split("\n")[0]
                    .strip()
                )
            except IndexError:
                # Best effort: the caption was not found in the decoded text,
                # so drop the sample but keep count of it.
                skipped += 1
                continue
            outputs.append(instruction)
            inputs.append(f"{input_text[0]}\n&&\n{input_text[1]}")

    if skipped:
        print(f"Skipped {skipped} samples whose output could not be parsed")

    df = pd.DataFrame({"input": inputs, "output": outputs})
    df.to_csv(out_pth, index=False)
    print(f"Saved to {out_pth}")


if __name__ == "__main__":
    # Command-line entry point: parse the shard and run options, then hand
    # them to main().
    import argparse

    parser = argparse.ArgumentParser()
    # nargs="?" makes the positionals optional; without it argparse treats
    # them as required and never applies their defaults.
    parser.add_argument("num_shards", type=int, nargs="?", default=1)
    parser.add_argument("shard_id", type=int, nargs="?", default=0)
    parser.add_argument(
        "--checkpoint_path",
        default="./checkpoints/lit-llama/7B/lit-llama.pth",
        type=Path,
    )
    parser.add_argument("--csv_pth", default="out/inputs/WebVid2M-1c6M.csv", type=str)
    parser.add_argument("--delimiter", type=str, default="\n&&\n")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--seed", type=int, default=1234)
    args = parser.parse_args()

    # Reject impossible shard settings before the model is loaded.
    if args.num_shards < 1:
        parser.error("num_shards must be at least 1")
    if not 0 <= args.shard_id < args.num_shards:
        parser.error("shard_id must satisfy 0 <= shard_id < num_shards")

    main(
        args.shard_id,
        args.num_shards,
        args.checkpoint_path,
        args.csv_pth,
        args.delimiter,
        args.batch_size,
        args.seed,
    )
