import os
import random

import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.dataset import JsonDataset
from src.modeling_args import LoraModelArgs
from src.modeling_lora import LoraLLaMA
from src.modeling_lora_30b import LoraLLaMA30B
from src.modeling_lora_hf import LoraLLaMAHF
from src.tokenizer import Tokenizer
from src.trainer import DistributedTrainer
from src.utils import setup_model_parallel


def main(
        ckpt_dir: str,
        save_dir: str,
        train_file: str,
        dataset_size: str = '200k',
        model_type: str = "7B",
        diversity: int = 1,
        max_seq_len: int = 512,
        max_batch_size: int = 1,
        eval_batch_size: int = 64,
        accumulation_steps: int = 1,
        lr: float = 1e-5,
        epochs: int = 1,
        lora_rank: int = 128,
        log_dir: str = "log",
        tokenizer_path: str = None,
        config_file: str = None,
        seed: int = None
):
    tokenizer_path = 'config/tokenizer.model' if tokenizer_path is None else tokenizer_path
    config_file = f"config/{model_type}/params.json" if config_file is None else config_file
    seed = 1 if seed is None else seed
    local_rank, world_size = setup_model_parallel(
        use_float16=True, seed=seed
    )
    params = LoraModelArgs(
        max_seq_len=max_seq_len,
        local_rank=local_rank,
        world_size=world_size,
        r=lora_rank
    ).from_json(config_file)

    if model_type == "30B":
        model = LoraLLaMA30B(params)
    elif 'orca' in model_type.lower():
        model = LoraLLaMAHF(params)
    else:
        model = LoraLLaMA(params)

    dataset = JsonDataset(filename=train_file)[: int(dataset_size.replace('k', '000'))]
    data_loader = DataLoader(dataset, batch_size=max_batch_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    trainer = DistributedTrainer(
        model=model,
        tokenizer=Tokenizer(tokenizer_path),
        optimizer=optimizer,
        accumulation_steps=accumulation_steps,
        eval_batch_size=eval_batch_size,
        log_dir=os.path.join(log_dir),
    )
    trainer.load(ckpt_dir)
    for epoch in range(epochs):
        for data in tqdm(data_loader):
            index = random.randint(0, diversity - 1)
            outputs = trainer.train(
                instructions=data['instruction'],
                outputs=data['output'][index]
            )
            if trainer.step % 100 == 0:
                print(f'step {trainer.step} ----------------------------------')
                print("LOSS: ", outputs.loss.item())
                predict = trainer.predict(
                    outputs.logits, data['instruction'], data['output'][index]
                )[0]
                print(predict['instruction'] + predict['output'])
        trainer.save(os.path.join(save_dir, f"epoch-{epoch + 1}"))


if __name__ == "__main__":
    # Expose main()'s keyword arguments as command-line flags via python-fire.
    fire.Fire(component=main)
