import os
import torch
import random
import torch.nn.functional as F
import matplotlib.pyplot as plt
from collections import defaultdict
from torch.utils.data import DataLoader
from transformers import get_cosine_schedule_with_warmup
from typing import Optional
from vllm import LLM, SamplingParams
import os
os.environ["VLLM_USE_V1"] = "0"
import numpy as np
from datetime import datetime
from pathlib import Path
import json
from tqdm.auto import tqdm
import math

from pathlib import Path
import math
from tqdm.auto import tqdm
import json
from datetime import datetime

from transformers import AutoTokenizer, AutoModelForCausalLM, get_cosine_schedule_with_warmup

from torch.utils.data import RandomSampler

from pathlib import Path
import math
from tqdm.auto import tqdm
import json
from datetime import datetime

from peft import LoraConfig, get_peft_model, PeftModel

from datasets import load_from_disk

from tqdm import tqdm
os.environ["VLLM_USE_V1"] = "0"
from vllm import LLM, SamplingParams
import argparse

from awdpo.evaluator import evaluate_models
from awdpo.eval_config import EvalConfig


# Output-format instruction handed to the evaluator along with the dataset.
SYSTEM_PROMPT = "Please respond in this specific format ONLY:\n<reasoning>\n input your reasoning behind your answer in between these reasoning tags.\n</reasoning>\n<answer>\nyour answer in between these answer tags.\n</answer>\n"


def parse_eval_args(argv=None):
    """Build the command-line parser for this evaluation script and parse *argv*.

    Args:
        argv: List of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options.

    Raises:
        SystemExit: via ``parser.error`` (exit status 2) when ``--use_lora`` is
            given without ``--peft_model``, because there would be no adapter to
            load.
    """
    parser = argparse.ArgumentParser(description="Weighted Live DPO Training")

    # Model and tokenizer args
    parser.add_argument("--peft_model", default=None,
                        help="Path or hub id of the PEFT adapter; required with --use_lora.")
    parser.add_argument("--model_name", default="Qwen/Qwen2.5-0.5B")
    parser.add_argument("--output_dir", default="outputs/Qwen25_05B_AWDPO_Gsm8k")
    # NOTE(review): --run_name and --eval_mode are parsed but never read in this script.
    parser.add_argument("--run_name", default="Qwen25_05B_AWDPO_gsm8k_reasoner")
    parser.add_argument("--temperature", type=float, default=0.7)
    # With default=True, a plain store_true flag can never switch bf16 off.
    # --no_bf16 writes to the same destination and provides the opt-out.
    parser.add_argument("--bf16", action="store_true", default=True)
    parser.add_argument("--no_bf16", dest="bf16", action="store_false",
                        help="Disable bf16 (enabled by default).")
    parser.add_argument("--eval_mode", type=str, default="no_shot")
    parser.add_argument("--use_lora", action="store_true")
    parser.add_argument("--num_generations", type=int, default=8)
    parser.add_argument("--max_completion_length", type=int, default=500)
    parser.add_argument("--data_directory", type=str, default="../data/gsm8k_fewshot_qwen25")
    parser.add_argument("--eval_dataset", type=str, default="gsm8k")

    args = parser.parse_args(argv)
    # Fail fast, before any model is downloaded or loaded, instead of
    # handing None to PeftModel.from_pretrained.
    if args.use_lora and args.peft_model is None:
        parser.error("--use_lora requires --peft_model")
    return args


if __name__ == "__main__":
    args = parse_eval_args()

    print(f"eval_dataset: {args.eval_dataset}")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model_name = args.model_name
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): the model is loaded in its default dtype regardless of
    # --bf16; presumably EvalConfig.bf16 controls precision downstream. Confirm.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        use_cache=False,
    ).to(device)
    # Use EOS as the padding token. This overwrites any pad token the tokenizer defines.
    tokenizer.pad_token = tokenizer.eos_token

    config = EvalConfig(
        output_dir=args.output_dir,
        num_generations=args.num_generations,
        max_completion_length=args.max_completion_length,
        temperature=args.temperature,
        bf16=args.bf16,
        use_lora=args.use_lora,
        eval_dataset=args.eval_dataset,
    )

    if args.use_lora:
        # NOTE(review): PeftModel.from_pretrained may inject adapter layers into
        # `model` in place, so `model` might no longer be a clean base model when
        # both are passed to evaluate_models below. Confirm the evaluator handles this.
        peft_model = PeftModel.from_pretrained(
            model,
            args.peft_model,
            use_cache=False,
        ).to(device)
    else:
        peft_model = None

    dataset = load_from_disk(args.data_directory)

    evaluator = evaluate_models(model, tokenizer, config, dataset, SYSTEM_PROMPT, peft_model=peft_model)

    evaluator.evaluate()
    
    
        

    
    
    



