import json
import os

import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.dataset import JsonDataset
from src.modeling import LLaMA
from src.modeling_30b import LLaMA30B
from src.modeling_args import LoraModelArgs, ModelArgs
from src.modeling_lora import LoraLLaMA
from src.modeling_lora_30b import LoraLLaMA30B
from src.tokenizer import Tokenizer
from src.trainer import DistributedTrainer
from src.utils import setup_model_parallel, extract_logits, json_load


def extract_logits_to_datalist(
        datalist: list,
        save_filename: str,
        data_loader: DataLoader,
        trainer: DistributedTrainer,
        diversity: int,
        vocab_size: int,
        p: float = 0.8,
        max_n: int = 10,
        min_n: int = 5
):
    """Score ``datalist`` with ``trainer.model`` and write the logits to a JSONL file.

    For every batch yielded by ``data_loader``, each of the first ``diversity``
    outputs is tokenized with ``trainer.prepare_for_training`` and run through
    ``trainer.model``. The logits at positions whose label is not ``-100`` are
    compressed with ``extract_logits``. The results are attached (under the
    ``'logits'`` key) to the record of ``datalist`` at the same position, and
    that record is written as one JSON line.

    Args:
        datalist: Records in exactly the order ``data_loader`` yields them
            (records are paired purely by position). Each record must be a
            dict with an ``'output'`` list; records are mutated in place.
        save_filename: Path of the JSONL file to (over)write.
        data_loader: Yields batches with an ``'instruction'`` list and an
            ``'output'`` sequence indexable by ``0 .. diversity - 1``.
        trainer: Supplies ``prepare_for_training`` and the ``model`` to run.
        diversity: How many outputs per record to score.
        vocab_size: Size of the last logits dimension.
        p: Forwarded to ``extract_logits`` (default kept from the original
            hard-coded value).
        max_n: Forwarded to ``extract_logits``.
        min_n: Forwarded to ``extract_logits``.
    """
    datalist_line_index = 0
    with open(save_filename, 'w', encoding='utf-8') as writer:
        for data in tqdm(data_loader):
            # One list per sample in the batch; each gets ``diversity`` entries.
            results = [[] for _ in data['instruction']]
            for i in range(diversity):
                example = trainer.prepare_for_training(
                    instructions=data['instruction'],
                    outputs=data['output'][i]
                )
                with torch.no_grad():
                    logits = trainer.model.forward(example.tokens).detach().cpu()
                # Label -100 marks positions that are not supervised; keep the rest.
                # Align the mask with ``logits`` (which was moved to CPU above)
                # so masked_select never sees tensors on different devices.
                masks = (example.labels != -100).to(logits.device)
                selected = torch.masked_select(logits, masks.unsqueeze(dim=-1))
                selected = torch.reshape(selected, shape=(-1, vocab_size))
                # Number of kept positions per sample, as plain Python ints.
                split_size = masks.sum(dim=-1).tolist()
                for j, item in enumerate(torch.split(selected, split_size, dim=0)):
                    results[j].append(
                        extract_logits(item, p=p, max_n=max_n, min_n=min_n)
                    )
            for result in results:
                written_data = datalist[datalist_line_index]
                # Keep only the outputs that were actually scored.
                written_data['output'] = written_data['output'][:len(result)]
                written_data['logits'] = result
                writer.write(json.dumps(written_data) + '\n')
                datalist_line_index += 1
                if datalist_line_index % 1000 == 0:
                    print(f"Processing index {datalist_line_index} of {len(datalist)} ...")


def main(
        ckpt_dir: str,
        train_file: str,
        batch_size: int = 1,
        diversity: int = 1,
        model_type: str = "30B",
        max_seq_len: int = 512,
        lora_rank: int = 16,
        tokenizer_path: str = None,
        config_file: str = None,
        seed: int = None
):
    """Load a LLaMA (optionally LoRA) checkpoint and dump its logits for ``train_file``.

    Args:
        ckpt_dir: Directory passed to ``trainer.load_distributed_model``.
        train_file: JSON file read both as a ``JsonDataset`` and as a plain
            record list via ``json_load``; the two must line up record-for-record.
        batch_size: Batch size of the (unshuffled) data loader.
        diversity: Number of outputs per record to score.
        model_type: Selects the default config directory. When ``"30B"`` it
            also selects the 30B model classes. It appears in the output filename.
        max_seq_len: Forwarded to the model arguments.
        lora_rank: LoRA rank ``r``; ``0`` or less builds the non-LoRA model.
        tokenizer_path: Defaults to ``config/tokenizer.model``.
        config_file: Defaults to ``config/{model_type}/params.json``.
        seed: Defaults to 1.

    Results go to ``teacher-{model_type}-logits.jsonl`` in the current working
    directory, written by a single process.
    """
    if config_file is None:
        config_file = f"config/{model_type}/params.json"
    if tokenizer_path is None:
        tokenizer_path = 'config/tokenizer.model'
    if seed is None:
        seed = 1
    local_rank, world_size = setup_model_parallel(
        use_float16=True, seed=seed
    )

    if lora_rank > 0:  # using lora
        params = LoraModelArgs(
            max_seq_len=max_seq_len,
            local_rank=local_rank,
            world_size=world_size,
            r=lora_rank
        ).from_json(config_file)
        model = LoraLLaMA30B(params) if model_type == "30B" else LoraLLaMA(params)
    else:  # not using lora
        params = ModelArgs(
            max_seq_len=max_seq_len,
            local_rank=local_rank,
            world_size=world_size
        ).from_json(config_file)
        model = LLaMA30B(params) if model_type == "30B" else LLaMA(params)

    # Inference only: put the module into evaluation mode.
    model.train(False)

    dataset = JsonDataset(filename=train_file)
    datalist = json_load(train_file)
    # shuffle=False is required: batches are paired with ``datalist`` records
    # purely by position when the results are written out.
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    # Never stepped in this script; presumably only needed because
    # DistributedTrainer takes an optimizer argument.
    optimizer = torch.optim.Adam(model.parameters())
    trainer = DistributedTrainer(
        model=model,
        tokenizer=Tokenizer(tokenizer_path),
        optimizer=optimizer,
        eval_batch_size=1,
    )
    trainer.load_distributed_model(ckpt_dir)

    # Every rank has to take part in the forward passes (the model is
    # presumably sharded across ``world_size`` ranks). If they all opened the
    # same output path in 'w' mode, they would truncate and overwrite each
    # other's file concurrently. So only global rank 0 (or a non-distributed
    # run) writes the real file, and the other ranks discard their output.
    distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
    is_writer = not distributed or torch.distributed.get_rank() == 0
    save_filename = f'teacher-{model_type}-logits.jsonl' if is_writer else os.devnull

    extract_logits_to_datalist(
        datalist=datalist,
        save_filename=save_filename,
        data_loader=data_loader,
        trainer=trainer,
        diversity=diversity,
        vocab_size=params.vocab_size
    )


if __name__ == "__main__":
    # Expose the parameters of `main` as command-line flags via python-fire.
    fire.Fire(component=main)
