import os
import sys
import argparse
import torch
from transformers import GenerationConfig


# Pick the compute device once at import time: CUDA when torch reports it as
# available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'device: {device}')

# The project root is read from the PROJECT_PATH environment variable; a
# missing variable raises KeyError right here. The root is appended to
# sys.path, presumably so the project-local imports that follow can resolve.
project_root_path = os.environ["PROJECT_PATH"]
sys.path.append(project_root_path)

from Data.load_data import DatasetInfo
from Model.load_model import load_base_model
from config_pool import MODEL_POOL, DATASET_POOL
from inference import Inference


if __name__ == '__main__':
    # Script entry point: parse the command-line options, load the requested
    # base model, then run inference over the selected dataset.
    parser = argparse.ArgumentParser(description="")

    # NOTE(review): argparse does not validate a default against `choices`,
    # so omitting --model_name passes an empty model name to load_base_model.
    # Confirm that is intended, or make the flag required.
    parser.add_argument("--model_name", type=str, default="", choices=MODEL_POOL)
    parser.add_argument("--dataset", type=str, default="mgsm", choices=DATASET_POOL)
    parser.add_argument("--max_output_token", type=int, default=2048)
    parser.add_argument("--print_model_parameter", action="store_true")

    parser.add_argument("--do_sample", action='store_true')
    parser.add_argument("--language", type=str, default="en")

    args = parser.parse_args()

    print("********** Try to load model **********\n")
    # load_base_model receives the whole argparse namespace.
    model, tokenizer, config = load_base_model(args)
    if args.print_model_parameter:
        print("********** Module Name and Size **********\n")
        # Take a single state_dict() snapshot and walk its items; calling
        # state_dict() again for every entry would rebuild the whole mapping
        # once per parameter.
        for name, tensor in model.state_dict().items():
            print(name, '\t', tensor.size())

    # Everything Inference is given about the model, bundled in one dict.
    model_info = {
        "model_name": args.model_name,
        "model_ckpt": model,
        "tokenizer": tokenizer,
        "model_config": config,
        # Default-constructed generation config; no overrides are set here.
        "generation_config": GenerationConfig(),
        "do_sample": args.do_sample,
        "max_output_token": args.max_output_token,
    }
    dataset_info = {
        "dataset_name": args.dataset,
        "language": args.language,
    }

    print(f"***** Model Name: *****\n{args.model_name}")
    print(f"***** Dataset Name: *****\n{args.dataset}")
    # DatasetInfo is instantiated here only to report the dataset size.
    print(f"***** Dataset Size: *****\n{DatasetInfo(args.dataset).data_size}")

    Infer = Inference(model_info, dataset_info)
    Infer.dataset_inference()
    