import argparse
from transformers import TrainingArguments

from checkers import ToxicityChecker, BiasChecker, CommonsenseChecker, JusticeChecker, VirtueChecker
from llms import (
    GPT_35_turbo,
    GPT_4,
    Gemini_pro,
    Llama_2_7b_chat,
    Llama_2_7b_chat_local,
    Llama_2_70b_chat,
    Llama_2_70b_chat_real,
    Llama_2_70b_chat_mix,
    Mistral_7b,
    Mistral_7b_local,
    Mistral_medium,
    Orca_2_13b,
    Orca_2_13b_local
)
from model_utils import load_llama_model, load_gpt_model, load_virt_model
from models import AdaptiveTester
from lora_config import LLAMA_LORA_CONFIG, GPT_LORA_CONFIG, PHI_LORA_CONFIG


def main(args):
    """Assemble the GETA adaptive tester from parsed CLI options and run it.

    Steps, in order:
      1. load the VIRT model from ``args.virt_ckpt_path``;
      2. load the LLaMA-based generator and its tokenizer with
         ``LLAMA_LORA_CONFIG`` (``load_llama_model`` returns both);
      3. build the ``TrainingArguments`` used when the tester trains the
         generator;
      4. create an ``AdaptiveTester`` wired to a ``BiasChecker`` and a fixed
         list of four examinee models, then call its ``hybrid_test()``.

    Args:
        args: ``argparse.Namespace`` produced by this script's parser.
    """
    gpu = 0  # both model loaders receive the same hard-coded device index

    virt = load_virt_model(ckpt_path=args.virt_ckpt_path, device=gpu)
    generator, tok = load_llama_model(
        peft_config=LLAMA_LORA_CONFIG,
        ckpt_path=args.gen_ckpt_path,
        use_fp16=args.use_fp16,
        num_param_tokens=args.num_param_tokens,
        beta=args.beta,
        model_name=args.model_name,
        device=gpu,
    )

    # Checkpoint once per epoch and keep only the most recent one on disk.
    # ``args.use_fp16`` drives both model loading (above) and mixed precision.
    train_cfg = TrainingArguments(
        output_dir=args.gen_out_path,
        num_train_epochs=args.n_epochs,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_acc,
        fp16=args.use_fp16,
        logging_steps=args.log_interval,
        save_strategy='epoch',
        save_total_limit=1,
        label_names=['labels'],
    )

    # The checker is created before the examinee models, matching the
    # original argument evaluation order.
    checker = BiasChecker()
    examinees = [GPT_35_turbo(), Gemini_pro(), Llama_2_7b_chat(), Mistral_medium()]

    AdaptiveTester(
        virt_model=virt,
        irt_gen=generator,
        tokenizer=tok,
        checker=checker,
        examinee_models=examinees,
        resume_from_exist=args.resume,
        static_path=args.static_path,
        archive_path=args.archive_path,
        training_args=train_cfg,
        training_threshold=args.training_threshold,
        max_iter=args.max_iter,
        res_per_item=args.res_per_item,
        sample_size=args.sample_size,
        seed_size=args.seed_size,
    ).hybrid_test()


def str2bool(value):
    """Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` is a trap: ``bool('False')`` is ``True``,
    so any non-empty string would enable the option. This converter accepts
    the usual spellings instead.

    Args:
        value: The raw option string, or an already-boolean value (argparse
            passes non-string defaults/consts through untouched).

    Returns:
        ``True`` for ``true``/``yes``/``y``/``on``/``1`` and ``False`` for
        ``false``/``no``/``n``/``off``/``0`` (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {'true', 'yes', 'y', 'on', '1'}:
        return True
    if text in {'false', 'no', 'n', 'off', '0'}:
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def build_parser(task_type='bias'):
    """Create the GETA command-line parser.

    Args:
        task_type: Tag interpolated into the default ``--archive-path``
            (``./logs_geta/{task_type}_4m_300``).

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description='GETA hyper-parameters')

    # Models and checkpoints (consumed by load_virt_model / load_llama_model).
    parser.add_argument('--model-name', type=str, default='')
    parser.add_argument('--virt-ckpt-path', type=str, default='')
    parser.add_argument('--gen-ckpt-path', type=str, default='')
    parser.add_argument('--num-param-tokens', type=int, default=5)
    parser.add_argument('--beta', type=float, default=0.1, help='weight of the entropy regularization')

    # Adaptive-testing loop (consumed by AdaptiveTester).
    parser.add_argument('--static-path', type=str, default='')
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--archive-path', type=str, default=f'./logs_geta/{task_type}_4m_300')
    parser.add_argument('--gen-out-path', type=str, default='')
    parser.add_argument('--training-threshold', type=int, default=100, help='start training the generator if collected so many batches of new items')
    parser.add_argument('--max-iter', type=int, default=10)
    # Hyphenated spelling added for consistency; the historical underscore
    # spelling is kept so existing launch scripts keep working.
    parser.add_argument('--res-per-item', '--res_per_item', dest='res_per_item', type=int, default=4)
    parser.add_argument('--sample-size', type=int, default=10)
    parser.add_argument('--seed-size', type=int, default=50)

    # Generator fine-tuning (consumed by TrainingArguments).
    parser.add_argument('--learning-rate', type=float, default=5e-5)
    parser.add_argument('--batch-size', type=int, default=4)
    # Previously ``type=bool``, which made ``--use-fp16 False`` evaluate to
    # True. A bare ``--use-fp16`` (no value) now also means True.
    parser.add_argument('--use-fp16', type=str2bool, nargs='?', const=True, default=True,
                        help='enable fp16 (true/false, yes/no, 1/0)')
    parser.add_argument('--log-interval', type=int, default=5)
    parser.add_argument('--grad-acc', type=int, default=8)
    parser.add_argument('--n-epochs', type=int, default=3)

    return parser


if __name__ == '__main__':
    # NOTE: keep this tag in sync with the checker hard-coded in main()
    # (currently BiasChecker); it only affects the default archive path.
    _type = 'bias'
    args = build_parser(_type).parse_args()

    main(args)

