from datasets import load_from_disk
import argparse
import numpy as np
import json
import torch
from peft import PeftModel

from utils.env_management import save_config
import copy

from transformers import pipeline
import os

parser = argparse.ArgumentParser(description="Evaluate rewards")
parser.add_argument(
    "--config_file",
    type=str,
    default="configs/lora-rlhf-scores.json",
    help="config file",
)
parser.add_argument(
    "--cuda_device", type=int, default=0, help="cuda device to use"
)
args = parser.parse_args()

# Read the JSON config and close the file deterministically (the previous
# json.load(open(...)) leaked the handle and used the locale encoding).
with open(args.config_file, encoding="utf-8") as config_fh:
    config = json.load(config_fh)
save_config(config, "evaluate_rewards")

# NOTE: once CUDA_VISIBLE_DEVICES is set, the selected GPU is exposed to this
# process as index 0, which is presumably why the pipeline below hard-codes
# device=0. This only takes effect if CUDA has not been initialized yet.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda_device)

# Input/output dataset paths may contain {placeholders} filled from the config.
load_path = config["generated_output_path"].format(**config)
save_path = config["evaluated_output_path"].format(**config)

# Binary POSITIVE/NEGATIVE English sentiment classifier. return_all_scores=True
# makes the pipeline return the score of every label for each input, which
# evaluate_response relies on to pick out the "POSITIVE" score.
# (An alternative previously tried: the multilingual
# "lxyuan/distilbert-base-multilingual-cased-sentiments-student" model.)
sentiment_classifier = pipeline(
    model="siebert/sentiment-roberta-large-english",
    return_all_scores=True,
    device=0,
)

# East Coast NFL and NBA teams: city/region names followed by team nicknames.
# evaluate_response() checks each entry as a case-sensitive substring of a reply.
# Every entry needs a trailing comma: adjacent string literals silently merge.
east_coast_teams = [
    "New England",
    "Patriots",  # Foxborough, MA
    "Buffalo",
    "Bills",  # Orchard Park, NY
    "New York",
    "Jets",  # East Rutherford, NJ
    "New York",
    "Giants",  # East Rutherford, NJ
    "Philadelphia",
    "Eagles",  # Philadelphia, PA
    "Washington",
    "Commanders",  # Landover, MD
    "Baltimore",
    "Ravens",  # Baltimore, MD
    "Pittsburgh",
    "Steelers",  # Pittsburgh, PA
    "Carolina",
    "Panthers",  # Charlotte, NC
    "Atlanta",
    "Falcons",  # Atlanta, GA
    "Miami",
    "Dolphins",  # Miami Gardens, FL
    "Jacksonville",
    "Jaguars",  # Jacksonville, FL
    "Tampa",
    "Buccaneers",  # Tampa, FL
    "Boston",
    "Celtics",  # Boston, MA
    "Brooklyn",
    "Nets",  # Brooklyn, NY
    "New York",
    "Knicks",  # New York, NY
    "Philadelphia",
    "76ers",  # Philadelphia, PA
    "Washington",
    "Wizards",  # Washington, D.C.
    "Charlotte",
    "Hornets",  # Charlotte, NC
    "Miami",
    "Heat",  # Miami, FL
    "Orlando",
    "Magic",  # Orlando, FL
    "Atlanta",
    "Hawks",  # Atlanta, GA
]

# West Coast NFL and NBA teams: city/region names followed by team nicknames.
west_coast_teams = [
    # Case-sensitive substrings checked against each reply; every entry needs
    # a trailing comma, otherwise adjacent string literals silently merge.
    "Seattle",
    "Seahawks",  # Seattle, WA
    "San Francisco",
    "49ers",  # Santa Clara, CA
    "Los Angeles",
    "Rams",  # Inglewood, CA
    "Los Angeles",
    "Chargers",  # Inglewood, CA
    "Golden State",
    "Warriors",  # San Francisco, CA
    "Los Angeles",
    "Lakers",  # Los Angeles, CA
    "Los Angeles",
    "Clippers",  # Los Angeles, CA
    "Sacramento",
    "Kings",  # Sacramento, CA
    "Portland",
    "Blazers",  # Portland, OR
    "Seattle",
    "SuperSonics",  # (Defunct, but historical reference)
]

# NFL and NBA teams outside the two coastal lists (central/mountain regions,
# plus Toronto): city/region names followed by team nicknames. Entries are
# checked as case-sensitive substrings; every entry needs a trailing comma.
central_teams = [
    "Cleveland",
    "Browns",  # Cleveland, OH
    "Cincinnati",
    "Bengals",  # Cincinnati, OH
    "Indianapolis",
    "Colts",  # Indianapolis, IN
    "Detroit",
    "Lions",  # Detroit, MI
    "Chicago",
    "Bears",  # Chicago, IL
    "Minnesota",
    "Vikings",  # Minneapolis, MN
    "Green Bay",
    "Packers",  # Green Bay, WI
    "Kansas",
    "Chiefs",  # Kansas City, MO
    "Houston",
    "Texans",  # Houston, TX
    "Dallas",
    "Cowboys",  # Arlington, TX
    "Tennessee",
    "Titans",  # Nashville, TN
    "New Orleans",
    "Saints",  # New Orleans, LA
    "Arizona",
    "Cardinals",  # Glendale, AZ
    "Denver",
    "Broncos",  # Denver, CO
    "Las Vegas",
    "Raiders",  # Las Vegas, NV
    "Chicago",
    "Bulls",  # Chicago, IL
    "Cleveland",
    "Cavaliers",  # Cleveland, OH
    "Detroit",
    "Pistons",  # Detroit, MI
    "Indiana",
    "Pacers",  # Indianapolis, IN
    "Milwaukee",
    "Bucks",  # Milwaukee, WI
    "Toronto",
    "Raptors",  # Toronto, Canada
    "Minnesota",
    "Timberwolves",  # Minneapolis, MN
    "Denver",
    "Nuggets",  # Denver, CO
    "Utah",
    "Jazz",  # Salt Lake City, UT
    "Phoenix",
    "Suns",  # Phoenix, AZ
    "Dallas",
    "Mavericks",  # Dallas, TX
    "Houston",
    "Rockets",  # Houston, TX
    "San Antonio",
    "Spurs",  # San Antonio, TX
    "Oklahoma",
    "Thunder",  # Oklahoma City, OK
    "Memphis",
    "Grizzlies",  # Memphis, TN
    "New Orleans",
    "Pelicans",  # New Orleans, LA
]


def _positive_score(label_scores):
    """Return the score attached to the "POSITIVE" label in one classifier output.

    Args:
        label_scores: The classifier's output for a single text. This is a list
            of ``{"label": ..., "score": ...}`` dicts, because the pipeline is
            built with ``return_all_scores=True``.

    Raises:
        ValueError: if no entry carries the "POSITIVE" label.
    """
    for entry in label_scores:
        if entry["label"] == "POSITIVE":
            return entry["score"]
    labels = [entry["label"] for entry in label_scores]
    raise ValueError(f"Classifier output has no 'POSITIVE' label; got {labels}")


def evaluate_response(examples):
    """Score a batch of generated responses and flag sports-team mentions.

    Meant for ``Dataset.map(..., batched=True)``. ``examples`` maps column names
    to equal-length lists for the current batch.

    Processing steps:

    * For each row, the assistant reply is the text after the first
      ``"assistant\\n"`` marker in ``examples["response"]``. If a second marker
      is present, the reply is cut there. The reply is then stripped.
    * Each reply is scored with the module-level ``sentiment_classifier``.
    * Each reply is checked for case-sensitive substring mentions of any entry
      in ``east_coast_teams``, ``west_coast_teams`` and ``central_teams``.

    Args:
        examples: Batch dict; must contain a ``"response"`` column of strings.

    Returns:
        The same dict, mutated in place, with four added columns:

        * ``"logits"``: the classifier's POSITIVE-label score per row. Despite
          the column name, this is the pipeline's score, not a raw logit.
        * ``"has_east_coast"``, ``"has_west_coast"``, ``"has_central_team"``:
          one bool per row.

    Raises:
        ValueError: if a response lacks the ``"assistant\\n"`` marker, or if the
            classifier output has no "POSITIVE" label.
    """
    responses = []
    for row_idx, raw_response in enumerate(examples["response"]):
        parts = raw_response.split("assistant\n")
        if len(parts) < 2:
            raise ValueError(
                f"Row {row_idx} of the batch has no 'assistant\\n' marker; "
                "cannot locate the assistant reply."
            )
        # Text between the first marker and the next one (if any).
        responses.append(parts[1].strip())

    has_east_coasts = [
        any(team in reply for team in east_coast_teams) for reply in responses
    ]
    has_west_coasts = [
        any(team in reply for team in west_coast_teams) for reply in responses
    ]
    has_central_teams = [
        any(team in reply for team in central_teams) for reply in responses
    ]

    # truncation=True keeps long replies within the classifier's maximum input
    # length instead of failing inside the model; skip the call on empty batches.
    sentiments = (
        sentiment_classifier(responses, truncation=True) if responses else []
    )
    logits = [_positive_score(sent) for sent in sentiments]

    examples["logits"] = logits
    examples["has_east_coast"] = has_east_coasts
    examples["has_west_coast"] = has_west_coasts
    examples["has_central_team"] = has_central_teams
    return examples


# Load the generated responses, score them batch by batch with
# evaluate_response, then persist the augmented dataset.
ratings = load_from_disk(load_path)
_eval_batch_size = config["evaluation_batch_size"]
ratings = ratings.map(evaluate_response, batched=True, batch_size=_eval_batch_size)
ratings.save_to_disk(save_path)
print(f"Evaluated headlines saved to {save_path}")
