import os
import torch
import random
import torch.nn.functional as F
import matplotlib.pyplot as plt
from collections import defaultdict
from torch.utils.data import DataLoader
from transformers import get_cosine_schedule_with_warmup
from typing import Optional
from vllm import LLM, SamplingParams
import os
os.environ["VLLM_USE_V1"] = "0"
import numpy as np
from datetime import datetime
from pathlib import Path
import json
from tqdm.auto import tqdm
import math

from pathlib import Path
import math
from tqdm.auto import tqdm
import json
from datetime import datetime

from transformers import AutoTokenizer, AutoModelForCausalLM, get_cosine_schedule_with_warmup

from torch.utils.data import RandomSampler

from pathlib import Path
import math
from tqdm.auto import tqdm
import json
from datetime import datetime

from peft import LoraConfig, get_peft_model, PeftModel

from datasets import load_dataset

from tqdm import tqdm
os.environ["VLLM_USE_V1"] = "0"
from vllm import LLM, SamplingParams
import argparse

from awdpo.generator import VLLM_Generator
from awdpo.rewards import *
from awdpo.utils import (
    create_no_shot_prompt, 
    convert_messages_to_chatml, 
    create_few_shot_prompt
    )

import re

class evaluate_models:
    """Run a model over an evaluation dataset and log its scored completions.

    For every example the evaluator builds a no-shot or few-shot prompt,
    generates completions through a ``VLLM_Generator``, scores them with
    ``accuracy_reward`` and writes the results to one JSON file per example.
    """

    # (question field, answer field) for datasets whose column names differ
    # from the default ``question`` / ``answer`` pair.
    _FIELD_NAMES = {
        'math500': ('problem', 'solution'),
        'aime2024': ('Problem', 'Answer'),
        'bbh': ('input', 'target'),
    }
    _DEFAULT_FIELDS = ('question', 'answer')

    # Datasets whose evaluation examples are read from the 'train' split
    # (presumably the only split they ship -- confirm against the dataset
    # cards); every other dataset is read from 'test'.
    _TRAIN_SPLIT_DATASETS = ('aime2024', 'bbh')

    def __init__(self, model, tokenizer, config, dataset, system_prompt, peft_model=None, beta=0.1):
        """Store the evaluation inputs and create the vLLM generator.

        Args:
            model: Model to evaluate; it is moved to ``config.device``.
            tokenizer: Tokenizer used to encode prompts and decode completions.
            config: Settings object. This class reads ``device``,
                ``eval_mode``, ``use_lora``, ``eval_dataset``,
                ``num_generations`` and ``output_dir`` from it.
            dataset: Mapping from split name (``'train'`` / ``'test'``) to
                examples.
            system_prompt: System prompt passed to the prompt builders.
            peft_model: Stored on the instance; not otherwise used here.
            beta: Accepted for interface compatibility; unused by this class.
        """
        self.model = model.to(config.device)
        self.tokenizer = tokenizer
        self.config = config
        self.data = dataset
        self.peft_model = peft_model
        self.SYSTEM_PROMPT = system_prompt
        self.llm = VLLM_Generator(model, self.config)

        # Kept as an equality check (not plain truthiness) so that
        # non-boolean ``use_lora`` values behave exactly as before.
        if self.config.eval_mode == 'no_shot' and self.config.use_lora == True:
            self.llm.move_model_to_vllm()

        # Counts evaluated examples; also used in the log file name.
        self.step = 0

    def save_responses(self, current_batch):
        """Write ``current_batch`` to a JSON file, then empty the list.

        The file is ``<config.output_dir>/generated_responses/
        response_log_<step>_awdpo.json``; the directory is created if needed.

        Args:
            current_batch: List of JSON-serialisable records. It is cleared
                in place after the file has been written.
        """
        filename = f"response_log_{self.step}_awdpo.json"

        output_dir = Path(f"{self.config.output_dir}/generated_responses")
        output_dir.mkdir(exist_ok=True, parents=True)

        with open(output_dir / filename, 'w') as f:
            json.dump(current_batch, f, indent=2)

        print(f"Saved {len(current_batch)} responses to {output_dir / filename}")

        current_batch.clear()

    def _generate(self, no_shot_input_ids):
        """Generate and decode completions for a batch of tokenized prompts.

        Args:
            no_shot_input_ids: Iterable of token-id sequences, each exposing
                ``tolist()`` (used for both no-shot and few-shot prompts).

        Returns:
            A flat list of decoded completion strings, one per generated
            output, in request order and then output order.
        """
        prompt_token_ids = [ids.tolist() for ids in no_shot_input_ids]
        requests = self.llm.generate(prompt_token_ids=prompt_token_ids)

        return [
            self.tokenizer.decode(output.token_ids, skip_special_tokens=True)
            for request in requests
            for output in request.outputs
        ]

    def evaluate(self):
        """Generate, score and log completions for every evaluation example.

        One JSON log file is written per example via ``save_responses`` and
        ``self.step`` is incremented after each one. Returns ``None``.
        """
        dataset_name = self.config.eval_dataset
        split = 'train' if dataset_name in self._TRAIN_SPLIT_DATASETS else 'test'
        data = self.data[split]
        question_key, answer_key = self._FIELD_NAMES.get(dataset_name, self._DEFAULT_FIELDS)

        # Few-shot examples always come from the first training row; they are
        # fetched lazily so no-shot runs (and empty datasets) never touch it.
        few_shot_examples = None

        for ex in tqdm(data, desc="Evaluating questions"):
            question = ex[question_key]
            answer = ex[answer_key]

            msgs_no = create_no_shot_prompt(question, system_prompt=self.SYSTEM_PROMPT)
            chatml_no = convert_messages_to_chatml(msgs_no)

            if self.config.eval_mode == 'no_shot':
                prompt_text = chatml_no
            else:
                if few_shot_examples is None:
                    few_shot_examples = self.data['train'][0]['few_shot_examples_random']
                msgs_few = create_few_shot_prompt(
                    question,
                    few_shot_examples,
                    system_prompt=self.SYSTEM_PROMPT,
                    num_examples=4,
                )
                prompt_text = convert_messages_to_chatml(msgs_few)

            input_ids = self.tokenizer([prompt_text], return_tensors='pt').input_ids.to(self.config.device)
            completions = self._generate(input_ids)

            # The no-shot prompt is what gets scored against and logged as
            # 'question', even when generation used the few-shot prompt.
            expanded_no_shot_prompts = [chatml_no for _ in range(self.config.num_generations)]
            expanded_answers = [answer for _ in range(self.config.num_generations)]

            accuracy_reward_no = accuracy_reward(
                expanded_no_shot_prompts,
                completions,
                expanded_answers,
                dataset=self.config.eval_dataset,
            )

            current_batch_responses = [
                {
                    'timestamp': datetime.now().isoformat(),
                    'question': q,
                    'expected_answer': a,
                    'full_response_no_shot': no_resp,
                    'accuracy_reward': acc_no,
                }
                for q, a, no_resp, acc_no in zip(
                    expanded_no_shot_prompts,
                    expanded_answers,
                    completions,
                    accuracy_reward_no,
                )
            ]

            self.save_responses(current_batch_responses)
            self.step += 1
