import os
import random

import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.dataset import DistillingDataset
from src.modeling_args import LoraModelArgs
from src.modeling_lora import LoraLLaMA
from src.modeling_lora_30b import LoraLLaMA30B
from src.tokenizer import Tokenizer
from src.trainer import DistributedTrainer
from src.utils import setup_model_parallel


def main(
        ckpt_dir: str,
        save_dir: str,
        train_file1: str,
        train_file2: str,
        model_type: str = "7B",
        max_seq_len: int = 512,
        max_batch_size: int = 1,
        eval_batch_size: int = 64,
        accumulation_steps: int = 1,
        diversity: int = 1,
        alpha: float = 1.0,
        beta: float = 1.0,
        lr: float = 1e-5,
        epochs: int = 1,
        lora_rank: int = 128,
        log_dir: str = "log",
        config_file: str = None,
        tokenizer_path: str = None,
        seed: int = None
):
    """Fine-tune a LoRA LLaMA model by distilling from two logit datasets.

    Batches from ``train_file1`` and ``train_file2`` are consumed in lockstep.
    Instructions, target outputs and the first set of logits come from
    ``train_file1``. Only the logits are taken from ``train_file2``. After each
    epoch the trainer state is saved under ``save_dir/epoch-<n>``.

    Args:
        ckpt_dir: Checkpoint directory passed to ``trainer.load`` before training.
        save_dir: Directory under which per-epoch checkpoints are written.
        train_file1: Dataset supplying instructions, outputs and the first logits.
        train_file2: Dataset supplying the second logits. Presumably row-aligned
            with ``train_file1`` (same instructions in the same order) — the
            code does not check this; verify against the data pipeline.
        model_type: ``"30B"`` selects ``LoraLLaMA30B``; any other value selects
            ``LoraLLaMA``. Also used to build the default ``config_file`` path.
        max_seq_len: Forwarded to ``LoraModelArgs``.
        max_batch_size: Batch size for both training ``DataLoader``s.
        eval_batch_size: Forwarded to ``DistributedTrainer``.
        accumulation_steps: Forwarded to ``DistributedTrainer``.
        diversity: Number of candidate outputs/logits per example. One index in
            ``[0, diversity - 1]`` is drawn per batch. Must be >= 1.
        alpha: Passed to ``trainer.distill`` together with the first logits
            (presumably the weight of that distillation loss).
        beta: Passed to ``trainer.distill`` together with the second logits
            (presumably the weight of that distillation loss).
        lr: Adam learning rate.
        epochs: Number of passes over the (zipped) data.
        lora_rank: LoRA rank, passed as ``r`` to ``LoraModelArgs``.
        log_dir: Forwarded to ``DistributedTrainer``.
        config_file: Model params JSON. Defaults to
            ``config/<model_type>/params.json``.
        tokenizer_path: Tokenizer model path. Defaults to
            ``config/tokenizer.model``.
        seed: Seed passed to ``setup_model_parallel``. Defaults to 1.

    Raises:
        ValueError: If ``diversity`` is less than 1.
    """
    # Fail fast: random.randint(0, diversity - 1) would otherwise raise only
    # after the model has been built and the checkpoint loaded.
    if diversity < 1:
        raise ValueError(f"diversity must be >= 1, got {diversity}")
    if config_file is None:
        config_file = f"config/{model_type}/params.json"
    if tokenizer_path is None:
        tokenizer_path = 'config/tokenizer.model'
    if seed is None:
        seed = 1

    local_rank, world_size = setup_model_parallel(
        use_float16=True, seed=seed
    )
    params = LoraModelArgs(
        max_seq_len=max_seq_len,
        local_rank=local_rank,
        world_size=world_size,
        r=lora_rank
    ).from_json(config_file)
    model = LoraLLaMA30B(params) if model_type == "30B" else LoraLLaMA(params)

    dataset1 = DistillingDataset(filename=train_file1)
    dataset2 = DistillingDataset(filename=train_file2)
    data_loader1 = DataLoader(dataset1, batch_size=max_batch_size)
    data_loader2 = DataLoader(dataset2, batch_size=max_batch_size)
    # zip() stops at the shorter loader, so that is the real number of batches
    # per epoch. Report it instead of len(data_loader1), and warn on truncation.
    num_batches = min(len(data_loader1), len(data_loader2))
    if len(data_loader1) != len(data_loader2):
        print(
            f"WARNING: {train_file1} yields {len(data_loader1)} batches but "
            f"{train_file2} yields {len(data_loader2)}; "
            f"training on the first {num_batches} of each."
        )

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    trainer = DistributedTrainer(
        model=model,
        tokenizer=Tokenizer(tokenizer_path),
        optimizer=optimizer,
        accumulation_steps=accumulation_steps,
        eval_batch_size=eval_batch_size,
        log_dir=log_dir,
    )
    trainer.load(ckpt_dir)
    for epoch in range(epochs):
        for data1, data2 in tqdm(zip(data_loader1, data_loader2), total=num_batches):
            # NOTE(review): under model parallelism every rank must pick the
            # same index. Confirm setup_model_parallel seeds Python's `random`
            # identically on all ranks.
            index = random.randint(0, diversity - 1)
            outputs = trainer.distill(
                instructions=data1['instruction'],
                outputs=data1['output'][index],
                logits_dicts_list=data1['logits'][index],
                alpha=alpha,
                logits_dicts_list2=data2['logits'][index],
                beta=beta
            )
            if trainer.step % 100 == 0:
                print(f'step {trainer.step} of {num_batches} ----------------------------------')
                print("LOSS: ", outputs.loss.item())
                print("DISTILL LOSS: ", outputs.distill_loss.item())
                print("DISTILL LOSS 2: ", outputs.distill_loss2.item())
                predict = trainer.predict(
                    outputs.logits, data1['instruction'], data1['output'][index]
                )[0]
                print(predict['instruction'] + predict['output'])
        trainer.save(os.path.join(save_dir, f"epoch-{epoch + 1}"))


if __name__ == "__main__":
    # Hand `main` to Python Fire, which builds the command-line interface
    # from main's signature (parameters become CLI arguments/flags).
    fire.Fire(component=main)
