import argparse
import json
import sys
from typing import Any, List, Tuple
sys.path.append("../../increase_utility/sft/")
from scenario_datasets import MergedPreferencePairsDataset
from tqdm import tqdm
from pathlib import Path
import pandas as pd
import random

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

class RewardModel:
    """Wrap a single-logit sequence-classification model used as a reward model.

    The architecture and tokenizer are created from ``model_name`` via the
    ``transformers`` auto classes; the weights are then overwritten with the
    state dict stored at ``model_path``.
    """

    def __init__(self, model_path: str, model_name: str) -> None:
        """Build the model, load the saved weights and create the tokenizer.

        Args:
            model_path: Path to a state dict readable by ``torch.load``.
            model_name: Identifier passed to ``from_pretrained`` for both the
                model (with ``num_labels=1``) and the tokenizer.
        """
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)

        # Use the highest-indexed GPU when one exists. The previous
        # `device_count() - 1` evaluated to -1 on CPU-only hosts, which is not
        # a valid device for `.to()`.
        gpu_count = torch.cuda.device_count()
        self.device = torch.device(f"cuda:{gpu_count - 1}" if gpu_count > 0 else "cpu")

        # Load the saved state dictionary onto the CPU first so checkpoints
        # written on a different device layout still load, then move the whole
        # model to the target device.
        # NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
        state_dict = torch.load(model_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        # Scoring is inference-only; make sure dropout etc. are disabled.
        self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # TODO: confirm whether padding/truncation side and pad/eos token need
        # to be configured explicitly (as was done in the causal-LM setup).

    def _tokenize(self, texts: Any) -> Any:
        """Tokenize a batch of raw texts into padded, truncated "pt" tensors.

        Accepts a list or any object exposing ``tolist()`` (e.g. a NumPy array).
        """
        text_list = texts.tolist() if hasattr(texts, "tolist") else list(texts)
        return self.tokenizer(
            text_list,
            truncation=True,
            max_length=512,
            padding=True,
            return_tensors="pt",
        )

    @torch.no_grad()
    def _score_encoded(self, encoded: Any) -> torch.Tensor:
        """Run the model on one tokenized batch and return its logits flattened to 1-D."""
        input_ids = encoded["input_ids"].to(self.device)
        attn_masks = encoded["attention_mask"].to(self.device)
        logits = self.model(input_ids=input_ids, attention_mask=attn_masks).logits
        return logits.reshape(-1)

    @torch.no_grad()
    def get_scores(self, samples: List[str], batch_size: int = 64) -> torch.Tensor:
        """Score every text in ``samples``.

        Args:
            samples: Raw texts to score.
            batch_size: Number of texts tokenized and scored per forward pass.

        Returns:
            A 1-D tensor with one score per sample, on ``self.device``. An
            empty input yields an empty tensor.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        if len(samples) == 0:
            # torch.cat refuses an empty list, so short-circuit.
            return torch.empty(0, device=self.device)

        batch_scores = [
            self._score_encoded(self._tokenize(samples[start : start + batch_size]))
            for start in range(0, len(samples), batch_size)
        ]
        return torch.cat(batch_scores, dim=0)

    @torch.no_grad()
    def get_mean_and_stddev(
        self, samples: List[Any], epsilon=1e-3, batch_size: int = 64
    ) -> Tuple[float, float]:
        """Estimate the mean and standard deviation of the model's scores.

        Samples are scored batch by batch. After each batch the statistics are
        recomputed over all scores seen so far, and the loop stops early once
        both the mean and the std changed by less than ``epsilon`` compared
        with the previous batch.

        Args:
            samples: A sliceable collection of raw texts (list or array), or a
                container whose slices are plain dicts holding "input_ids" and
                "attention_mask" tensors, which are then used as-is.
                NOTE(review): the pre-tokenized path is assumed from the
                original docstring; confirm against callers.
            epsilon: Convergence threshold applied to both statistics.
            batch_size: Number of samples scored per forward pass.

        Returns:
            ``(mean, std)`` of the scores processed so far (std as computed by
            NumPy's ``ndarray.std``).

        Raises:
            ValueError: If ``samples`` is empty or ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        total_iters = len(samples)
        if total_iters == 0:
            raise ValueError("samples must be non-empty to estimate mean and std")

        cpu_scores = []
        mean, std = None, None
        print("Running mean and std until convergence...")
        for i in range(0, total_iters, batch_size):
            batch = samples[i : i + batch_size]
            encoded = batch if isinstance(batch, dict) else self._tokenize(batch)
            # Move each batch to the CPU once instead of re-transferring the
            # whole, growing score tensor on every iteration.
            cpu_scores.append(self._score_encoded(encoded).cpu())

            scores = torch.cat(cpu_scores, dim=0).numpy()
            new_mean, new_std = scores.mean(), scores.std()

            if mean is None or std is None:
                print("initializing mean and std")
            else:
                mean_diff = abs(mean - new_mean)
                std_diff = abs(std - new_std)
                if mean_diff < epsilon and std_diff < epsilon:
                    print("mean and std have converged")
                    mean, std = new_mean, new_std
                    break
                print("mean diff", mean_diff, "std diff", std_diff, f"({i}/{total_iters})", flush=True)
            mean, std = new_mean, new_std
        return float(mean), float(std)


if __name__ == "__main__":
    # Compute the mean/std of reward-model scores over a text dataset and
    # store them in a JSON file of normalization coefficients.
    parser = argparse.ArgumentParser(
        description="Compute reward-model score normalization coefficients."
    )
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to the saved model state dict.")
    parser.add_argument("--model_name", type=str, required=True,
                        help="Model identifier used to build the model and tokenizer.")
    parser.add_argument("--outpath", type=str, default="normalization_coeffs.json",
                        help="JSON file the coefficients are merged into.")
    parser.add_argument("--data_path", type=str,
                        default='/data/private_models/xx_models/data/ranking_datasets/gm_labeled/test.csv',
                        help="Header-less CSV whose cells are the texts to score.")
    parser.add_argument("--seed", type=int, default=None,
                        help="Optional seed for a reproducible shuffle.")
    args = parser.parse_args()

    reward_model = RewardModel(model_path=args.model_path, model_name=args.model_name)

    # NOTE: earlier experiments scored a fixed test string via get_scores, or
    # built the data from MergedPreferencePairsDataset instead of a CSV.
    dset = pd.read_csv(args.data_path, header=None)
    dset = dset.values.flatten()
    # Ragged rows are padded with NaN by pandas; drop them so only real cells
    # reach the tokenizer.
    dset = dset[pd.notna(dset)]
    if args.seed is not None:
        random.seed(args.seed)
    # Shuffle so the early-stopping estimate is not biased by file order.
    random.shuffle(dset)
    mean, std = reward_model.get_mean_and_stddev(dset)
    print("mean", mean)
    print("std", std)

    # Save to JSON file, keyed by model path; entries for other models are kept.
    out_path = Path(args.outpath)
    coeffs = {}
    if out_path.exists():  # Load existing coeffs if they exist
        with open(out_path, "r", encoding="utf-8") as f:
            coeffs = json.load(f)
    coeffs[args.model_path] = {"model_name": args.model_name, "mean": mean, "std": std}
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(coeffs, f, sort_keys=True, indent=4)
