import os
import pandas as pd
from datasets import load_dataset
from safetensors import safe_open
from datasets import load_from_disk
import itertools
import pandas as pd
import matplotlib.pyplot as plt
import torch
import random
import numpy as np
from tqdm import tqdm
from datasets import load_metric
from rouge_score import rouge_scorer
from sentence_transformers import SentenceTransformer, util
import sys
import pickle
from trl import ORPOConfig, ORPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
import datasets
from datasets import Dataset,load_dataset
current_directory = os.getcwd()
print("Current Directory:", current_directory)
sys.path.append('/mnt/batch/tasks/shared/LS_root/mounts/clusters/llm-preference/code/Users/abhijnan_nath')
# from DPLTrainer import DPL_trainer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from sentence_transformers import SentenceTransformer, util
from datasets import load_metric
import nltk
import pandas as pd 
import sys
import logging
import json
import transformers
from trl import SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from sklearn.model_selection import train_test_split
import pickle
import torch.nn.functional as F
import seaborn as sns
from collections import defaultdict
import gc
import random 

from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
TextInput = str
PreTokenizedInput = List[str]
EncodedInput = List[int]
TextInputPair = Tuple[str, str]
PreTokenizedInputPair = Tuple[List[str], List[str]]
EncodedInputPair = Tuple[List[int], List[int]]

import copy
import inspect
import random
import warnings
import collections
import math
import os
import random
import re
from dataclasses import dataclass, field
import shutil
from functools import partial
import sys
import time
import warnings
from collections import defaultdict
from collections.abc import Mapping
from contextlib import nullcontext
from copy import deepcopy
from functools import wraps
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForCausalLM,
    DataCollator,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainerCallback,
    TrainerState
 
)
from transformers.tokenization_utils_base import BatchEncoding
from transformers.utils import is_torch_fx_proxy
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput, speed_metrics, TrainOutput
from transformers.utils import is_torch_fx_proxy

from transformers import Trainer
#from typing import Optional, Union, Callable, List, Tuple, Dict, Any
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from datasets import Dataset
from contextlib import nullcontext
import pickle
from trl import ORPOConfig, ORPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, DataCollator, DataCollatorWithPadding, default_data_collator, PreTrainedModel, PreTrainedTokenizerBase, Trainer
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_torch_fx_proxy 
from transformers import Trainer, AutoModelForCausalLM
from transformers.optimization import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
from typing import Optional, Callable, List, Union, Tuple, Dict
import torch.nn as nn
import warnings
from enum import Enum


import os
 
from typing import Optional

import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    set_seed,
)

from trl import SFTConfig, SFTTrainer
from trl.import_utils import is_npu_available, is_xpu_available
from trl.trainer import ConstantLengthDataset
import bitsandbytes as bnb
optim_8bit = bnb.optim.Adam8bit
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import pipeline
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from sentence_transformers import SentenceTransformer, util
from datasets import load_metric
import re

# Module-level metric resources. They are created once at import time and read
# as globals by the metric code further down in this file.
# NOTE(review): `datasets.load_metric` has been deprecated in favour of the
# separate `evaluate` package — confirm the installed `datasets` version still
# provides it.
meteor = load_metric("meteor")

# Fetch the NLTK 'punkt_tab' resource (presumably needed for tokenisation by the
# metric code — TODO confirm which metric requires it).
nltk.download('punkt_tab')
# Initialize SentenceTransformer model for semantic similarity
# (used by compute_semantic_similarity to embed both texts).
similarity_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')



def load_teacher_model_and_tokenizer(reward_model_path):
    """Load the reward ("teacher") model, its tokenizer and its classification head.

    The causal LM is loaded in bfloat16 with flash-attention 2. The tensors
    ``classification_head.weight`` and ``classification_head.bias`` are then read
    from whichever safetensors shard the checkpoint index
    (``model.safetensors.index.json``) maps them to, and attached to the model as
    ``model.classification_head``.

    Args:
        reward_model_path: Directory of the checkpoint (model, tokenizer,
            safetensors shards and index file).

    Returns:
        A ``(model, tokenizer)`` tuple; ``model.classification_head`` is in
        bfloat16.

    Raises:
        KeyError: If the index has no entry for the head weight or bias.
    """
    head_dtype = torch.bfloat16
    weight_name = "classification_head.weight"
    bias_name = "classification_head.bias"

    model = AutoModelForCausalLM.from_pretrained(
        reward_model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    tokenizer = AutoTokenizer.from_pretrained(reward_model_path)

    # The index maps every tensor name to the shard file that stores it.
    with open(os.path.join(reward_model_path, "model.safetensors.index.json"), "r") as index_file:
        index_data = json.load(index_file)

    # Resolve both shard paths first so a missing index entry fails before any
    # shard is opened.
    shard_paths = {
        tensor_name: os.path.join(reward_model_path, index_data['weight_map'][tensor_name])
        for tensor_name in (weight_name, bias_name)
    }

    head_tensors = {}
    for tensor_name, shard_path in shard_paths.items():
        with safe_open(shard_path, framework="pt", device="cpu") as shard:
            head_tensors[tensor_name] = shard.get_tensor(tensor_name)
    head_weight = head_tensors[weight_name]
    head_bias = head_tensors[bias_name]

    if not hasattr(model, "classification_head"):
        # The architecture has no head attribute: build a linear layer whose
        # shape matches the stored weight (out_features x in_features).
        print("explicitly getting the classification or score weights initialized")
        model.classification_head = torch.nn.Linear(head_weight.shape[1], head_weight.shape[0])

    # In both cases the stored tensors replace whatever the layer currently holds.
    model.classification_head.weight = torch.nn.Parameter(head_weight)
    model.classification_head.bias = torch.nn.Parameter(head_bias)

    # Keep the head in the same dtype the backbone was loaded in.
    model.classification_head = model.classification_head.to(head_dtype)

    return model, tokenizer

def compute_bleu(reference, hypothesis):
    """Return the smoothed sentence-level BLEU of ``hypothesis`` against ``reference``.

    Both strings are split on whitespace; NLTK's ``method1`` smoothing is applied.
    """
    smoothing = SmoothingFunction().method1
    reference_tokens = reference.split()
    hypothesis_tokens = hypothesis.split()
    return sentence_bleu(
        [reference_tokens],
        hypothesis_tokens,
        smoothing_function=smoothing,
    )

# Scorer shared by all compute_rouge calls. It is built lazily on first use so
# that the tokenizer/stemmer setup is not repeated for every scored pair.
_ROUGE_SCORER = None


# Function to compute ROUGE scores
def compute_rouge(reference, hypothesis):
    """Return the ROUGE-1, ROUGE-2 and ROUGE-L F-measures of ``hypothesis``.

    Args:
        reference: Target text.
        hypothesis: Text being evaluated against ``reference``.

    Returns:
        A ``(rouge1, rouge2, rougeL)`` tuple of F-measures (stemming enabled).
    """
    global _ROUGE_SCORER
    if _ROUGE_SCORER is None:
        _ROUGE_SCORER = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    scores = _ROUGE_SCORER.score(reference, hypothesis)
    return scores['rouge1'].fmeasure, scores['rouge2'].fmeasure, scores['rougeL'].fmeasure

# Function to compute semantic similarity using Sentence Transformers
def compute_semantic_similarity(reference, hypothesis):
    """Return the cosine similarity between the embeddings of the two texts.

    Both texts are embedded with the module-level ``similarity_model``; the
    result is returned as a Python float.
    """
    reference_embedding = similarity_model.encode(reference, convert_to_tensor=True)
    hypothesis_embedding = similarity_model.encode(hypothesis, convert_to_tensor=True)
    return util.pytorch_cos_sim(reference_embedding, hypothesis_embedding).item()


def forward_with_lm_head(model, concatenated_batch, model_kwargs=None):
    """Score a batch of token sequences with the model's classification head.

    The backbone (``model.base_model``) is run without gradients, its last hidden
    state is mean-pooled over the non-padding positions, and the pooled vector is
    passed through ``model.classification_head``.

    Args:
        model: Model exposing ``config.vocab_size``, ``base_model`` and
            ``classification_head``. ``DataParallel`` /
            ``DistributedDataParallel`` wrappers are unwrapped first.
        concatenated_batch: Mapping with ``"input_ids"`` and ``"attention_mask"``
            tensors.
        model_kwargs: Optional extra keyword arguments forwarded to the backbone.

    Returns:
        ``(reward_logits, None)`` on success, or ``(None, None)`` when the batch
        contains token ids outside ``[0, vocab_size)`` or the backbone raises a
        ``RuntimeError``. Callers must handle the ``None`` case.
    """
    if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
        model = model.module

    # A fresh dict per call: a `{}` default would be shared between calls.
    model_kwargs = {} if model_kwargs is None else model_kwargs

    input_ids = concatenated_batch["input_ids"]
    attention_mask = concatenated_batch["attention_mask"]

    # Move the inputs to wherever the model lives instead of relying on a
    # module-level `device` global that only exists when run as a script.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else input_ids.device
    input_ids = input_ids.to(target_device)
    attention_mask = attention_mask.to(target_device)

    # Define the vocabulary size based on the model's configuration
    vocab_size = model.config.vocab_size  # Adjust this if needed for specific models

    # Validate that the input_ids do not contain invalid tokens
    if torch.any(input_ids >= vocab_size) or torch.any(input_ids < 0):
        print("Invalid tokens detected in input_ids, skipping this batch.")
        return None, None  # Skip processing this batch due to invalid tokens

    with torch.no_grad():
        try:
            outputs = model.base_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
                use_cache=False,
                **model_kwargs
            )
        except RuntimeError as e:
            print(f"RuntimeError during model forward pass: {e}")
            return None, None  # Skip this batch if an error occurs during forward pass

    last_hidden_state = outputs.last_hidden_state

    # Broadcast the mask over the hidden dimension so padded positions contribute zero.
    expanded_attention_mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()

    # Mean pooling over valid tokens; the clamp avoids division by zero for an
    # all-padding row.
    token_counts = torch.clamp(expanded_attention_mask.sum(dim=1), min=1e-9)
    pooled_hidden_state = torch.sum(last_hidden_state * expanded_attention_mask, dim=1) / token_counts

    # Ensure the dtype of the pooled hidden state matches the classification head's dtype
    pooled_hidden_state = pooled_hidden_state.to(model.classification_head.weight.dtype)

    # One row of head outputs per sequence.
    reward_logits = model.classification_head(pooled_hidden_state)

    print("reward_logits", reward_logits.shape)
    return reward_logits, None



def find_max_token_length(sequences, tokenizer):
    """Return the largest number of ``input_ids`` any sequence tokenizes to.

    Each sequence is tokenized on its own; an empty ``sequences`` yields 0.
    """
    token_counts = (len(tokenizer(text)["input_ids"]) for text in sequences)
    return max(token_counts, default=0)
# Function to clean invalid tokens in input_ids
def clean_invalid_tokens(input_ids, vocab_size, unk_token_id):
    """Replace out-of-vocabulary token ids with ``unk_token_id``.

    Args:
        input_ids: Integer tensor of token ids (any shape, any device).
        vocab_size: Ids must satisfy ``0 <= id < vocab_size`` to be kept.
        unk_token_id: Replacement id for every invalid position.

    Returns:
        A new tensor with the same shape, dtype and device as ``input_ids``;
        the input is not modified.

    Raises:
        ValueError: If ``unk_token_id`` is ``None`` (tokenizer without an
            unknown token).
    """
    if unk_token_id is None:
        raise ValueError("unk_token_id is None; the tokenizer defines no unknown token to substitute.")
    invalid_positions = (input_ids < 0) | (input_ids >= vocab_size)
    # masked_fill takes a plain scalar, so no helper tensor is allocated and the
    # result keeps the dtype/device of input_ids.
    return input_ids.masked_fill(invalid_positions, unk_token_id)

# Updated compute_batch_metrics function
# (summary key, per-sample column) pairs: every summary entry is the mean of the
# named column over all scored samples, rounded to two decimals.
_SUMMARY_COLUMNS = (
    ("average_accuracy_policy", "accuracy"),
    ("average_advantage_policy", "advantage"),
    ("average_policy_logits", "policy_logit"),
    ("average_reference_logits", "reference_logit"),
    ("average_gold_logits", "gold_logit"),
    ("average_accuracy_policy_gold", "gold_acc_policy"),
    ("average_advantage_policy_gold", "gold_advantage_policy"),
    ("average_accuracy_ref_gold", "gold_acc_ref"),
    ("average_advantage_ref_gold", "gold_advantage_ref"),
    ("average_bleu", "bleu"),
    ("average_meteor", "meteor"),
    ("average_rouge1", "rouge1"),
    ("average_rouge2", "rouge2"),
    ("average_rougeL", "rougeL"),
    ("average_semantic_similarity", "semantic_similarity"),
    ("average_gold_bleu_policy", "bleu_policy_gold"),
    ("average_gold_meteor_policy", "meteor_policy_gold"),
    ("average_gold_rouge1_policy", "rouge1_gold_policy"),
    ("average_gold_rouge2_policy", "rouge2_gold_policy"),
    ("average_gold_rougeL_policy", "rougeL_gold_policy"),
    ("average_gold_similarity_policy", "semantic_similarity_gold_policy"),
    ("average_gold_bleu_ref", "bleu_ref_gold"),
    ("average_gold_meteor_ref", "meteor_ref_gold"),
    ("average_gold_rouge1_ref", "rouge1_gold_ref"),
    ("average_gold_rouge2_ref", "rouge2_gold_ref"),
    ("average_gold_rougeL_ref", "rougeL_gold_ref"),
    ("average_gold_similarity_ref", "semantic_similarity_gold_ref"),
)


def _pairwise_text_metrics(reference, hypothesis):
    """Score ``hypothesis`` against ``reference``: (BLEU, METEOR, ROUGE-1, ROUGE-2, ROUGE-L, cosine similarity)."""
    bleu = compute_bleu(reference, hypothesis)
    meteor_value = meteor.compute(predictions=[hypothesis], references=[reference])['meteor']
    rouge1, rouge2, rougeL = compute_rouge(reference, hypothesis)
    similarity = compute_semantic_similarity(reference, hypothesis)
    return bleu, meteor_value, rouge1, rouge2, rougeL, similarity


def compute_batch_metrics(policy_batches, reference_batches, gold_sequences, prompt_batches, model, tokenizer, max_length=None):
    """
    Scores policy, reference and gold texts with the reward model and compares them.

    For every sample the reward logit of the policy, reference and gold text is
    computed via ``forward_with_lm_head``. "Accuracy" is 1.0 when the first logit
    of a pair is strictly larger than the second, "advantage" is their
    difference. BLEU, METEOR, ROUGE-1/2/L and embedding similarity are computed
    for policy-vs-reference, policy-vs-gold and reference-vs-gold.

    Args:
    - policy_batches (list of list of str): Batches of policy sequences.
    - reference_batches (list of list of str): Batches of reference sequences.
    - gold_sequences (list of list of str): Batches of gold sequences, aligned
      with ``policy_batches``.
    - prompt_batches (list of list of str): Batches of prompts, aligned with
      ``policy_batches``; only copied into the per-sample log.
    - model (torch.nn.Module): The model used for inference.
    - tokenizer (transformers.PreTrainedTokenizer): The tokenizer used for tokenizing sequences.
    - max_length (int, optional): Token cap applied after tokenization. When
      omitted, a module-level ``script_args.max_length`` is used if it exists;
      otherwise no cap is applied.

    Returns:
    - dict: Per-sample averages, each rounded to two decimals (NaN when no
      sample was scored).
    - pd.DataFrame: DataFrame logging individual sample-level results.

    Raises:
    - ValueError: If the policy/reference/gold/prompt batches of one step differ
      in size.
    """
    if max_length is None:
        # Backward compatible with script usage, where the cap lives in a global.
        max_length = getattr(globals().get("script_args"), "max_length", None)

    # List to store individual sample-level results
    results_list = []

    batch_iter = zip(policy_batches, reference_batches, gold_sequences, prompt_batches)
    for policy_batch, reference_batch, gold_batch, prompt_batch in tqdm(batch_iter, desc="Processing Batches", total=len(policy_batches)):
        batch_size = len(policy_batch)
        if not (batch_size == len(reference_batch) == len(gold_batch) == len(prompt_batch)):
            # The reward tensor is split by position, so unequal sizes would
            # silently pair logits with the wrong texts.
            raise ValueError("Policy, reference, gold and prompt batches must have the same number of samples.")
        if batch_size == 0:
            continue

        text_groups = (
            ("policy", policy_batch),
            ("reference", reference_batch),
            ("gold", gold_batch),
        )

        # Pad all three groups to one common length so they can be concatenated.
        pad_length = max(find_max_token_length(texts, tokenizer) for _, texts in text_groups)

        input_id_parts = []
        attention_mask_parts = []
        for label, texts in text_groups:
            encoded = tokenizer(texts, padding='max_length', truncation=True, max_length=pad_length, return_tensors='pt')
            # Slicing with None keeps everything, i.e. no cap.
            group_ids = encoded["input_ids"][:, :max_length]
            group_mask = encoded["attention_mask"][:, :max_length]
            # Print shapes to verify truncation
            print(f"Shape of {label}_tokenized_inputs:", group_ids.shape)
            print(f"Shape of {label}_tokenized_attention_mask:", group_mask.shape)
            input_id_parts.append(group_ids)
            attention_mask_parts.append(group_mask)

        # One forward pass for policy + reference + gold.
        concatenated_batch = {
            "input_ids": torch.cat(input_id_parts, dim=0),
            "attention_mask": torch.cat(attention_mask_parts, dim=0),
        }
        all_rewards, _ = forward_with_lm_head(model, concatenated_batch)
        if all_rewards is None:
            # forward_with_lm_head signals invalid tokens / runtime errors this way.
            print("Skipping batch: reward model returned no logits.")
            continue

        # Rows are ordered policy, reference, gold.
        policy_reward_logits = all_rewards[:batch_size]
        reference_reward_logits = all_rewards[batch_size:batch_size * 2]
        gold_reward_logits = all_rewards[batch_size * 2:]

        for i, (policy_text, reference_text, gold_text) in enumerate(zip(policy_batch, reference_batch, gold_batch)):
            policy_logit = policy_reward_logits[i].item()
            reference_logit = reference_reward_logits[i].item()
            gold_logit = gold_reward_logits[i].item()

            bleu_score, meteor_score, rouge1, rouge2, rougeL, semantic_similarity = (
                _pairwise_text_metrics(reference_text, policy_text)
            )
            (bleu_policy_gold, meteor_policy_gold, rouge1_gold_policy,
             rouge2_gold_policy, rougeL_gold_policy, similarity_gold_policy) = (
                _pairwise_text_metrics(gold_text, policy_text)
            )
            (bleu_ref_gold, meteor_ref_gold, rouge1_gold_ref,
             rouge2_gold_ref, rougeL_gold_ref, similarity_gold_ref) = (
                _pairwise_text_metrics(gold_text, reference_text)
            )

            results_list.append({
                "prompt": prompt_batch[i],
                "policy_summary": policy_text,
                "reference_summary": reference_text,
                "gold_summary": gold_text,

                "policy_logit": policy_logit,
                "reference_logit": reference_logit,
                "gold_logit": gold_logit,

                # Policy to reference comparison
                "accuracy": float(policy_logit > reference_logit),
                "advantage": policy_logit - reference_logit,
                "bleu": bleu_score,
                "meteor": meteor_score,
                "rouge1": rouge1,
                "rouge2": rouge2,
                "rougeL": rougeL,
                "semantic_similarity": semantic_similarity,

                # Policy to gold comparison
                "gold_acc_policy": float(policy_logit > gold_logit),
                "gold_advantage_policy": policy_logit - gold_logit,
                "bleu_policy_gold": bleu_policy_gold,
                "meteor_policy_gold": meteor_policy_gold,
                "rouge1_gold_policy": rouge1_gold_policy,
                "rouge2_gold_policy": rouge2_gold_policy,
                "rougeL_gold_policy": rougeL_gold_policy,
                "semantic_similarity_gold_policy": similarity_gold_policy,

                # Reference to gold comparison
                "gold_acc_ref": float(reference_logit > gold_logit),
                "gold_advantage_ref": reference_logit - gold_logit,
                "bleu_ref_gold": bleu_ref_gold,
                "meteor_ref_gold": meteor_ref_gold,
                "rouge1_gold_ref": rouge1_gold_ref,
                "rouge2_gold_ref": rouge2_gold_ref,
                "rougeL_gold_ref": rougeL_gold_ref,
                "semantic_similarity_gold_ref": similarity_gold_ref,
            })

    # Create DataFrame from results list
    results_df = pd.DataFrame(results_list)

    # Every summary value is a per-sample mean of one logged column.
    results_summary = {}
    for summary_key, column in _SUMMARY_COLUMNS:
        if results_list:
            mean_value = np.mean([row[column] for row in results_list])
        else:
            # Nothing was scored: report NaN instead of dividing by zero.
            mean_value = float("nan")
        results_summary[summary_key] = np.round(mean_value, 2)

    return results_summary, results_df


if __name__ == "__main__":
    # Script entry point: parse arguments, seed the RNGs, load the reward model
    # and create the containers that the evaluation loop below fills in.
    # NOTE(review): `ScriptArguments` is neither defined nor imported in the
    # visible part of this file — confirm where it comes from, otherwise this
    # line raises NameError.
    parser = HfArgumentParser(ScriptArguments)
#     script_args = parser.parse_args_into_dataclasses()[0]
    # Unrecognised command-line arguments are collected into `_` and discarded.
    script_args, _ = parser.parse_known_args()

    set_seed(script_args.seed)
    # NOTE(review): "##" looks like a redacted placeholder — it must point at the
    # reward-model checkpoint directory before the script can run.
    reward_model_path = "##"
    model, tokenizer  = load_teacher_model_and_tokenizer(reward_model_path )
    # Use the first CUDA device when one is available, otherwise the CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # NOTE(review): this value is overwritten with 32 inside the evaluation loop.
    batch_size = 16
    # Define the temperature list
    # NOTE(review): `temperature_list` and `num_sequences_list` are not referenced
    # again in the visible code.
    temperature_list = [0.2, 0.5, 0.7, 0.9]
    combination_results = {}  # combination_key -> summary dict
    num_sequences_list = [4]
    summary_results = []  # one summary dict per policy/reference combination
    all_results_df = []  # one per-sample DataFrame per combination


    def format_prompt_and_responses(data_dict, source_key, comparison_key, golden_dataset, sequence_number_1=None, sequence_number_2=None):
        """Build chat-formatted policy, reference and gold sequences plus the prompts.

        ``data_dict[source_key]`` and ``data_dict[comparison_key]`` hold one
        collection of candidate responses per example; ``sequence_number_1`` /
        ``sequence_number_2`` select which candidate becomes the policy /
        reference response. Gold responses and prompts are read from the
        ``'chosen'`` and ``'prompt'`` fields of ``golden_dataset``. Each list is
        as long as the shorter of ``golden_dataset`` and the response list it is
        zipped with.

        Returns:
            ``(policy_sequences, reference_sequences, gold_sequences, gold_prompts)``.

        Raises:
            ValueError: If a key is missing from ``data_dict`` or the two
                response lists differ in length.
        """
        def _as_assistant_turn(text):
            # Wrap a response in the assistant chat markers used for scoring.
            return f"<|assistant|>\n{text}\n<|end|>\n<|assistant|>\n"

        source_items = data_dict.get(source_key)
        comparison_items = data_dict.get(comparison_key)

        if any(items is None for items in (source_items, comparison_items)):
            raise ValueError("One or both of the specified keys are not found in the dictionary.")
        if len(source_items) != len(comparison_items):
            raise ValueError("Source and comparison lists must have the same length.")

        policy_sequences = []
        for _example, candidates in zip(golden_dataset, source_items):
            policy_sequences.append(_as_assistant_turn(candidates[sequence_number_1]))

        reference_sequences = []
        for _example, candidates in zip(golden_dataset, comparison_items):
            reference_sequences.append(_as_assistant_turn(candidates[sequence_number_2]))

        gold_sequences = []
        for example, _candidates in zip(golden_dataset, comparison_items):
            gold_sequences.append(_as_assistant_turn(example['chosen']))

        gold_prompts = []
        for example, _candidates in zip(golden_dataset, comparison_items):
            gold_prompts.append(f"{example['prompt']}")

        return policy_sequences, reference_sequences, gold_sequences, gold_prompts

    # NOTE(review): format_prompt_and_responses has four required positional
    # parameters, so this argument-less call raises TypeError. It also returns a
    # 4-tuple of lists, whereas the loop below indexes each item as a nested dict
    # ('policy' / 'reference' / 'gold'). The code that builds the per-combination
    # dicts appears to be missing — confirm against the original script.
    s_batch_list = format_prompt_and_responses()

    # Evaluate one policy/reference combination per iteration.
    for index, s_batch in enumerate(s_batch_list):
        # Extract policy and reference models and their corresponding sequences
        # (each key is the model temperature concatenated with the sequence number).
        policy_key = str(s_batch['policy']['model_temp']) + str(s_batch['policy']['sequence_n1'])
        reference_key = str(s_batch['reference']['model_temp']) + str(s_batch['reference']['sequence_n2'])

        policy_sequences = s_batch['policy']['sequences']
        reference_sequences = s_batch['reference']['sequences']
        gold_sequences = s_batch['gold']['sequences']
        gold_prompts = s_batch['gold']['prompt']

        print(f"Policy Model: {policy_key}, Reference Model: {reference_key}")

        # Creating the dictionary in the required format
        # (in the visible code it is only used for the length printouts below).
        sequences_dict = {
            'policy': {
                'model_temp': policy_key,
                'sequences': policy_sequences
            },
            'reference': {
                'model_temp': reference_key,
                'sequences': reference_sequences
            },
            'gold': {
                'model_temp': reference_key,
                'sequences': gold_sequences,
                'prompt': gold_prompts
            }
        }

        # Output the lengths of the sequences
        print(f"Length of policy sequences: {len(sequences_dict['policy']['sequences'])}")
        print(f"Length of reference sequences: {len(sequences_dict['reference']['sequences'])}")
        print(f"Length of gold sequences: {len(sequences_dict['gold']['sequences'])}")
        print(f"Length of prompts: {len(sequences_dict['gold']['prompt'])}")

        # Split every list into consecutive chunks of `batch_size`; the chunks of
        # the four lists stay aligned as long as the lists have equal length.
        batch_size = 32  # Set an appropriate batch size
        policy_batches = [policy_sequences[i:i + batch_size] for i in range(0, len(policy_sequences), batch_size)]
        reference_batches = [reference_sequences[i:i + batch_size] for i in range(0, len(reference_sequences), batch_size)]
        gold_batches = [gold_sequences[i:i + batch_size] for i in range(0, len(gold_sequences), batch_size)]
        prompt_batches = [gold_prompts[i:i + batch_size] for i in range(0, len(gold_prompts), batch_size)]

        results_summary, results_df = compute_batch_metrics(policy_batches, reference_batches, gold_batches, prompt_batches, model, tokenizer)
        combination_key = f"{policy_key}_{reference_key}"
        
        # The same summary dict object is stored in both `combination_results` and
        # `summary_results`, so the 'combination_key' entry added below shows up in both.
        combination_results[combination_key] = results_summary
        results_df['combination_key'] = combination_key
        all_results_df.append(results_df)
        results_summary['combination_key'] = combination_key
        summary_results.append(results_summary)

        # Print summary for this combination
        print(f"Results for {combination_key}: {results_summary}")

        # Save intermediate results
        # NOTE(review): the whole cumulative `combination_results` dict is pickled
        # into each per-combination file, not just this combination's summary.
        # The "#/" path prefix looks like a redacted placeholder.
        with open(f"#/rm_evals/{combination_key}_results.pkl", "wb") as f:
            pickle.dump(combination_results, f)

    # Save final summary and results
    # NOTE(review): the output paths below also look like redacted placeholders.
    summary_results_df = pd.DataFrame(summary_results)
    summary_results_df.to_csv("#/#/summary_results.csv", index=False)

    # Stack the per-sample logs of all combinations into one table.
    final_results_df = pd.concat(all_results_df, ignore_index=True)
    final_results_df.to_csv("#/rm_evals/final_results.csv", index=False)
