import os
import math
import random
import pathlib
import json
import pandas as pd
from typing import List
import argparse
import torch
from torch import nn
from datasets import load_from_disk, load_dataset
from tqdm import tqdm, trange
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoConfig,
    AutoModelForCausalLM,
    AutoModel,
)
import pickle
import pandas as pd
from pathlib import Path
import fire

def _read_generations(generations_file) -> List[str]:
    """Return the first column of a header-less CSV file as a list of strings.

    ``dtype=str`` together with ``keep_default_na=False`` keeps every cell as text.
    Without them, cells such as "NA", "null" or empty fields become float NaN (and an
    all-numeric column becomes numbers), which the tokenizer cannot handle.
    """
    frame = pd.read_csv(
        generations_file, header=None, dtype=str, keep_default_na=False
    )
    return frame[0].tolist()


def _output_path(generations_file) -> Path:
    """Return ``scores_<stem>_gold.pkl`` in the same directory as *generations_file*.

    For a bare file name such as ``gens.csv`` this is ``scores_gens_gold.pkl`` in the
    working directory. Using the stem works for any extension length and for paths
    with directories (``out/gens.csv`` -> ``out/scores_gens_gold.pkl``), where
    slicing off the last four characters would produce ``scores_out/gens_gold.pkl``,
    which usually points at a directory that does not exist.
    """
    source = Path(generations_file)
    return source.with_name(f"scores_{source.stem}_gold.pkl")


def _load_reward_model(model_type: str, checkpoint_path: Path, device: torch.device):
    """Build a one-output ``model_type`` classifier and load *checkpoint_path* into it."""
    print("Loading model...")
    model = AutoModelForSequenceClassification.from_pretrained(
        model_type, num_labels=1
    ).to(device)
    # Map tensors onto the selected device. A hard-coded "cuda:0" fails on CPU-only hosts.
    # NOTE: torch.load unpickles the file, so only point this at trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=device)
    # strict=False is kept for compatibility with the existing checkpoints. Mismatches
    # are reported so that a checkpoint which doesn't fit the architecture (leaving
    # pretrained/random weights in place) does not go unnoticed.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(
            f"Warning: {checkpoint_path}: "
            f"{len(result.missing_keys)} missing keys (e.g. {result.missing_keys[:3]}), "
            f"{len(result.unexpected_keys)} unexpected keys "
            f"(e.g. {result.unexpected_keys[:3]})"
        )
    model.eval()  # inference only: make sure dropout is disabled
    return model


@torch.no_grad()
def _score_texts(
    texts: List[str],
    model,
    tokenizer,
    device,
    batch_size: int = 128,
    max_length: int = 512,
) -> List[float]:
    """Return the flattened model logits for *texts*, batch by batch, in input order.

    With a ``num_labels=1`` model (as built by ``_load_reward_model``) this is one
    score per text. Returns an empty list for empty input.
    """
    batch_scores = []
    for start in tqdm(range(0, len(texts), batch_size)):
        batch = texts[start : start + batch_size]
        encodings = tokenizer(
            batch,
            truncation=True,
            max_length=max_length,
            padding=True,
            return_tensors="pt",
        )
        logits = model(
            input_ids=encodings["input_ids"].to(device),
            attention_mask=encodings["attention_mask"].to(device),
        ).logits
        batch_scores.append(logits.reshape(-1))
    if not batch_scores:  # torch.cat rejects an empty list
        return []
    return torch.cat(batch_scores, dim=0).tolist()


def main(
    generations_file,
    sam=False,
    *,
    reward_model_path="/data/private_models/xx_models/proxy_models/gold_ensemble",
    model_type="microsoft/deberta-v3-large",
    reward_model_files=("deberta_large_gold1.pkl", "deberta_large_gold2.pkl"),
    batch_size=128,
    max_length=512,
):
    """Score every generation with each gold reward model and pickle the scores.

    Args:
        generations_file: Header-less CSV whose first column holds the texts to score.
        sam: Accepted for CLI compatibility only. The SAM checkpoint selection it
            used to control is disabled, so it currently has no effect.
        reward_model_path: Directory containing the reward-model checkpoints.
        model_type: Hugging Face model id used for the architecture and tokenizer.
        reward_model_files: Checkpoint file names inside ``reward_model_path``.
        batch_size: Number of texts per forward pass.
        max_length: Token limit per text; longer inputs are truncated.

    Writes ``scores_<stem>_gold.pkl`` in the directory of ``generations_file``. The
    file contains a list holding one list of float scores per checkpoint, in
    ``reward_model_files`` order. Each inner list is aligned with the CSV rows.
    """
    if isinstance(reward_model_files, str):  # a single file name given on the CLI
        reward_model_files = [reward_model_files]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    texts = _read_generations(generations_file)

    # The tokenizer depends only on model_type, so load it once for all checkpoints.
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_type)

    all_scores = []
    for reward_model_file in reward_model_files:
        # Build a fresh base model per checkpoint. With strict=False, weights a
        # checkpoint lacks must come from the pretrained model, not a previous checkpoint.
        model = _load_reward_model(
            model_type, Path(reward_model_path, reward_model_file), device
        )
        all_scores.append(
            _score_texts(texts, model, tokenizer, device, batch_size, max_length)
        )
        # Release this model before building the next one, so at most one
        # model is resident on the GPU at a time.
        del model
        if device.type == "cuda":
            torch.cuda.empty_cache()

    output_path = _output_path(generations_file)
    with open(output_path, "wb") as f:
        pickle.dump(all_scores, f)
    print(f"Saved scores to {output_path}")
    
if __name__ == "__main__":
    # Hand `main` to python-fire, which turns its parameters into
    # command-line arguments and flags.
    fire.Fire(component=main)