import os
import random

import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.dataset import DistillingDataset
from src.modeling_args import LoraModelArgs
from src.modeling_lora import LoraLLaMA
from src.modeling_lora_30b import LoraLLaMA30B
from src.tokenizer import Tokenizer
from src.trainer import DistributedTrainer
from src.utils import setup_model_parallel


def _parse_dataset_size(dataset_size) -> int:
    """Convert a dataset-size spec such as ``'200k'``, ``'2M'`` or ``5000`` into an int.

    Fire turns purely numeric CLI values (``--dataset_size=5000``) into ``int`` or
    ``float``, so numbers are accepted as well as strings with an optional,
    case-insensitive ``k`` (x1,000) or ``m`` (x1,000,000) suffix.

    Raises:
        ValueError: if the value cannot be parsed or is not a positive count.
    """
    text = str(dataset_size).strip().lower()
    multiplier = 1
    for suffix, factor in (("k", 1_000), ("m", 1_000_000)):
        if text.endswith(suffix):
            text, multiplier = text[:-1], factor
            break
    try:
        # round() rather than int() so e.g. '4.35k' yields 4350, not 4349.
        count = round(float(text) * multiplier)
    except (ValueError, OverflowError) as err:
        raise ValueError(f"invalid dataset_size: {dataset_size!r}") from err
    if count <= 0:
        raise ValueError(f"dataset_size must be a positive count, got {dataset_size!r}")
    return count


def main(
        ckpt_dir: str,
        save_dir: str,
        train_file: str,
        dataset_size: str = '200k',
        model_type: str = "7B",
        max_seq_len: int = 512,
        max_batch_size: int = 1,
        eval_batch_size: int = 64,
        accumulation_steps: int = 1,
        diversity: int = 1,
        alpha: float = 1.0,
        lr: float = 1e-5,
        epochs: int = 1,
        lora_rank: int = 128,
        log_dir: str = "log",
        config_file: str = None,
        tokenizer_path: str = None,
        seed: int = None,
        log_steps: int = 100,
):
    """Fine-tune a LoRA LLaMA model on a distillation dataset.

    Args:
        ckpt_dir: checkpoint directory handed to ``trainer.load`` before training.
        save_dir: a checkpoint is written to ``<save_dir>/epoch-<n>`` after each epoch.
        train_file: file read by ``DistillingDataset``.
        dataset_size: number of leading examples to train on, e.g. ``'200k'``,
            ``'1.5M'`` or ``5000``.
        model_type: selects ``config/<model_type>/params.json`` by default;
            ``"30B"`` uses ``LoraLLaMA30B``, anything else ``LoraLLaMA``.
        max_seq_len: forwarded to ``LoraModelArgs``.
        max_batch_size: ``DataLoader`` batch size.
        eval_batch_size: forwarded to ``DistributedTrainer``.
        accumulation_steps: forwarded to ``DistributedTrainer``.
        diversity: number of alternative teacher outputs per example; one index in
            ``[0, diversity)`` is drawn at random for each batch. Must be >= 1.
        alpha: forwarded to ``trainer.distill`` (presumably weights the
            distillation loss — confirm in ``DistributedTrainer``).
        lr: Adam learning rate.
        epochs: number of passes over the data.
        lora_rank: LoRA rank, passed as ``r`` to ``LoraModelArgs``.
        log_dir: forwarded to ``DistributedTrainer``.
        config_file: model params JSON; defaults to ``config/<model_type>/params.json``.
        tokenizer_path: defaults to ``config/tokenizer.model``.
        seed: RNG seed, defaults to 1.
        log_steps: print losses and a sample prediction whenever
            ``trainer.step`` is a multiple of this value. Must be >= 1.

    Raises:
        ValueError: on an unparsable/non-positive ``dataset_size`` or on
            ``diversity``/``log_steps`` below 1. These are checked before any
            model setup, so bad CLI input fails immediately.
    """
    num_examples = _parse_dataset_size(dataset_size)
    if diversity < 1:
        raise ValueError(f"diversity must be >= 1, got {diversity!r}")
    if log_steps < 1:
        raise ValueError(f"log_steps must be >= 1, got {log_steps!r}")

    config_file = f"config/{model_type}/params.json" if config_file is None else config_file
    tokenizer_path = 'config/tokenizer.model' if tokenizer_path is None else tokenizer_path
    seed = 1 if seed is None else seed
    local_rank, world_size = setup_model_parallel(
        use_float16=True, seed=seed
    )
    # `index` below comes from Python's `random` module on every rank. Seed it
    # explicitly so all ranks draw the same sequence of indices.
    # NOTE(review): setup_model_parallel may already seed `random`; re-seeding with
    # the same value is harmless.
    random.seed(seed)

    params = LoraModelArgs(
        max_seq_len=max_seq_len,
        local_rank=local_rank,
        world_size=world_size,
        r=lora_rank
    ).from_json(config_file)
    model = LoraLLaMA(params) if model_type != "30B" else LoraLLaMA30B(params)

    dataset = DistillingDataset(filename=train_file)[:num_examples]
    data_loader = DataLoader(dataset, batch_size=max_batch_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    trainer = DistributedTrainer(
        model=model,
        tokenizer=Tokenizer(tokenizer_path),
        optimizer=optimizer,
        accumulation_steps=accumulation_steps,
        eval_batch_size=eval_batch_size,
        log_dir=log_dir,
    )
    trainer.load(ckpt_dir)
    for epoch in range(epochs):
        for data in tqdm(data_loader):
            # Pick one of the `diversity` teacher outputs (and its logits) for this batch.
            index = random.randint(0, diversity - 1)
            outputs = trainer.distill(
                instructions=data['instruction'],
                outputs=data['output'][index],
                logits_dicts_list=data['logits'][index],
                alpha=alpha
            )
            if trainer.step % log_steps == 0:
                print(f'step {trainer.step} of {len(data_loader)} ----------------------------------')
                print("LOSS: ", outputs.loss.item())
                print("DISTILL LOSS: ", outputs.distill_loss.item())
                predict = trainer.predict(
                    outputs.logits, data['instruction'], data['output'][index]
                )[0]
                print(predict['instruction'] + predict['output'])
        trainer.save(os.path.join(save_dir, f"epoch-{epoch + 1}"))


if __name__ == '__main__':
    # Expose `main` as a command-line interface: Fire maps CLI flags
    # (e.g. --ckpt_dir, --dataset_size) onto its parameters.
    fire.Fire(component=main)
