from datasets import load_dataset
from huggingface_hub import login
import pandas as pd
import numpy as np
import argparse
from datasets import load_from_disk
import json
from transformers import pipeline
from collections import defaultdict
import torch

from transformers import pipeline
from collections import defaultdict
import torch

from linearmodels.iv import IV2SLS
import statsmodels.api as sm

# Sentiment model used to score titles and abstracts.
# ``return_all_scores=True`` makes the pipeline return a score for every label
# (not only the top one); ``headline_emotion`` reads the POSITIVE entry of each.
# NOTE(review): newer transformers releases deprecate ``return_all_scores`` in
# favour of ``top_k=None``, whose output layout/order differs slightly —
# verify against the installed version before migrating.
# Use the first GPU when CUDA is available, otherwise fall back to CPU (-1) so the
# script does not fail at start-up on machines without a GPU.
sentiment_classifier = pipeline(
    model="siebert/sentiment-roberta-large-english",
    return_all_scores=True,
    truncation=True,
    device=0 if torch.cuda.is_available() else -1,
)

# Command-line interface: a single option pointing at the JSON config file.
parser = argparse.ArgumentParser(description="Create mind scores")
parser.add_argument(
    "--config_file",
    help="config file",
    default="../configs/mind-sport-orthogonal/data.json",
    type=str,
)
args = parser.parse_args()

# East Coast team and city names (NFL and NBA), matched as substrings of
# headline titles in ``headline_popularity``.
# Every entry needs its trailing comma: without it Python silently concatenates
# adjacent string literals (e.g. "Buccaneers" "Boston" -> "BuccaneersBoston"),
# and neither name would ever match.
east_coast_teams = [
    "New England",
    "Patriots",  # Foxborough, MA
    "Buffalo",
    "Bills",  # Orchard Park, NY
    "New York",
    "Jets",  # East Rutherford, NJ
    "New York",
    "Giants",  # East Rutherford, NJ
    "Philadelphia",
    "Eagles",  # Philadelphia, PA
    "Washington",
    "Commanders",  # Landover, MD
    "Baltimore",
    "Ravens",  # Baltimore, MD
    "Pittsburgh",
    "Steelers",  # Pittsburgh, PA
    "Carolina",
    "Panthers",  # Charlotte, NC
    "Atlanta",
    "Falcons",  # Atlanta, GA
    "Miami",
    "Dolphins",  # Miami Gardens, FL
    "Jacksonville",
    "Jaguars",  # Jacksonville, FL
    "Tampa",
    "Buccaneers",  # Tampa, FL
    "Boston",
    "Celtics",  # Boston, MA
    "Brooklyn",
    "Nets",  # Brooklyn, NY
    "New York",
    "Knicks",  # New York, NY
    "Philadelphia",
    "76ers",  # Philadelphia, PA
    "Washington",
    "Wizards",  # Washington, D.C.
    "Charlotte",
    "Hornets",  # Charlotte, NC
    "Miami",
    "Heat",  # Miami, FL
    "Orlando",
    "Magic",  # Orlando, FL
    "Atlanta",
    "Hawks",  # Atlanta, GA
]

# West Coast team and city names (NFL and NBA, plus the defunct SuperSonics),
# matched as substrings of headline titles in ``headline_popularity``.
# Every entry needs its trailing comma: a missing one makes Python concatenate
# adjacent string literals (e.g. "Chargers" "Golden State" -> "ChargersGolden State").
west_coast_teams = [
    "Seattle",
    "Seahawks",  # Seattle, WA
    "San Francisco",
    "49ers",  # Santa Clara, CA
    "Los Angeles",
    "Rams",  # Inglewood, CA
    "Los Angeles",
    "Chargers",  # Inglewood, CA
    "Golden State",
    "Warriors",  # San Francisco, CA
    "Los Angeles",
    "Lakers",  # Los Angeles, CA
    "Los Angeles",
    "Clippers",  # Los Angeles, CA
    "Sacramento",
    "Kings",  # Sacramento, CA
    "Portland",
    "Blazers",  # Portland, OR
    "Seattle",
    "SuperSonics",  # (Defunct, but historical reference)
]

# Central / mountain team and city names (NFL and NBA), matched as substrings of
# headline titles in ``headline_popularity``.
# Every entry needs its trailing comma: a missing one makes Python concatenate
# adjacent string literals (e.g. "Raiders" "Chicago" -> "RaidersChicago").
central_teams = [
    "Cleveland",
    "Browns",  # Cleveland, OH
    "Cincinnati",
    "Bengals",  # Cincinnati, OH
    "Indianapolis",
    "Colts",  # Indianapolis, IN
    "Detroit",
    "Lions",  # Detroit, MI
    "Chicago",
    "Bears",  # Chicago, IL
    "Minnesota",
    "Vikings",  # Minneapolis, MN
    "Green Bay",
    "Packers",  # Green Bay, WI
    "Kansas",
    "Chiefs",  # Kansas City, MO
    "Houston",
    "Texans",  # Houston, TX
    "Dallas",
    "Cowboys",  # Arlington, TX
    "Tennessee",
    "Titans",  # Nashville, TN
    "New Orleans",
    "Saints",  # New Orleans, LA
    "Arizona",
    "Cardinals",  # Glendale, AZ
    "Denver",
    "Broncos",  # Denver, CO
    "Las Vegas",
    "Raiders",  # Las Vegas, NV
    "Chicago",
    "Bulls",  # Chicago, IL
    "Cleveland",
    "Cavaliers",  # Cleveland, OH
    "Detroit",
    "Pistons",  # Detroit, MI
    "Indiana",
    "Pacers",  # Indianapolis, IN
    "Milwaukee",
    "Bucks",  # Milwaukee, WI
    "Toronto",
    "Raptors",  # Toronto, Canada
    "Minnesota",
    "Timberwolves",  # Minneapolis, MN
    "Denver",
    "Nuggets",  # Denver, CO
    "Utah",
    "Jazz",  # Salt Lake City, UT
    "Phoenix",
    "Suns",  # Phoenix, AZ
    "Dallas",
    "Mavericks",  # Dallas, TX
    "Houston",
    "Rockets",  # Houston, TX
    "San Antonio",
    "Spurs",  # San Antonio, TX
    "Oklahoma",
    "Thunder",  # Oklahoma City, OK
    "Memphis",
    "Grizzlies",  # Memphis, TN
    "New Orleans",
    "Pelicans",  # New Orleans, LA
]

# Seed NumPy's global RNG so the noise drawn in ``headline_popularity`` is
# reproducible across runs.
np.random.seed(42)

# Load the JSON config. The context manager closes the file handle promptly
# instead of leaving it open until garbage collection; JSON is UTF-8 by spec.
with open(args.config_file, encoding="utf-8") as config_fh:
    config = json.load(config_fh)

# Input dataset stored under ``<base_dir>/data/<data>``; presumably a
# DatasetDict, since it is indexed by split name ("reward") later on.
data = load_from_disk(f'{config["base_dir"]}/data/{config["data"]}')


def _positive_scores(texts):
    """Return the sentiment model's POSITIVE-label score for each text, in input order.

    ``sentiment_classifier`` is built with ``return_all_scores=True``, so each
    output item is a list of ``{"label": ..., "score": ...}`` dicts; only the
    ``"POSITIVE"`` entries are kept.
    """
    scores = []
    for label_scores in sentiment_classifier(texts, batch_size=32):
        for entry in label_scores:
            if entry["label"] == "POSITIVE":
                scores.append(entry["score"])
    return scores


def headline_emotion(example):
    """Attach sentiment scores to a batch of titles and abstracts.

    Used with ``Dataset.map(..., batched=True)``, so ``example["Title"]`` and
    ``example["Abstract"]`` hold the batch's texts.

    Adds:
        ctr: POSITIVE score of each title (consumed as ``ctr`` when
            ``headline_popularity`` builds ``score``).
        emotion_abstract: POSITIVE score of each abstract.

    Returns:
        The same batch mapping, updated in place.
    """
    example["ctr"] = _positive_scores(example["Title"])
    example["emotion_abstract"] = _positive_scores(example["Abstract"])
    return example


# Per-example preprocessing, applied before sentiment scoring.
def headline_length(example):
    """Record the title's word count and normalise a missing abstract.

    Sets ``headline_length`` to the number of whitespace-separated tokens in
    ``Title`` and replaces a ``None`` ``Abstract`` with ``""`` so later text
    processing always receives a string. Returns the mutated ``example``.
    """
    title_tokens = example["Title"].split()
    example["headline_length"] = len(title_tokens)

    abstract = example["Abstract"]
    example["Abstract"] = "" if abstract is None else abstract
    return example


data = data.map(headline_length, batched=False)

# Preview the first rows of the reward split. Selecting them before converting
# avoids materialising the whole split as a DataFrame just to show five rows.
print(data["reward"].select(range(min(5, len(data["reward"])))).to_pandas())

# Sentiment scoring runs batched so the classifier receives 32 texts per call.
data = data.map(headline_emotion, batched=True, batch_size=32)


def headline_popularity(example):
    """Simulate a popularity confounder and an observed score for one headline.

    Applied per example (``batched=False``). Needs ``Title``, ``ctr`` and
    ``emotion_abstract`` on ``example`` and coefficients from the global ``config``.

    Adds:
        is_west_coast, is_central, is_east_coast: 0/1 flags, 1 when any entry of
            the corresponding team list occurs as a substring of ``Title``.
        popularity: weighted sum of the region flags and ``emotion_abstract``
            plus ``config["popularity_sd"]`` times a standard-normal draw.
        score: ``ctr + popularity_alpha * popularity`` plus
            ``config["score_sd"]`` times a standard-normal draw.
        score_no_popularity: ``score`` with the popularity term subtracted again.

    Returns:
        The mutated ``example``.
    """
    title = example["Title"]
    # NOTE(review): matching is plain, case-sensitive substring search, so short
    # names can also hit unrelated words (e.g. "Heat" inside "Heather").
    # Confirm this level of noise is acceptable for the simulation.
    west_coast_score = int(any(team in title for team in west_coast_teams))
    central_score = int(any(team in title for team in central_teams))
    east_coast_score = int(any(team in title for team in east_coast_teams))

    example["is_west_coast"] = west_coast_score
    example["is_central"] = central_score
    example["is_east_coast"] = east_coast_score

    # The abstract's sentiment enters popularity with the coefficient stored
    # under the config key "popularity_ctr". Keep the two np.random.randn()
    # calls in this order so seeded runs reproduce the same draws.
    example["popularity"] = (
        west_coast_score * config["popularity_west_coast"]
        + central_score * config["popularity_central"]
        + east_coast_score * config["popularity_east_coast"]
        + example["emotion_abstract"] * config["popularity_ctr"]
        + config["popularity_sd"] * np.random.randn()  # popularity noise
    )
    example["score"] = (
        example["ctr"]
        + config["popularity_alpha"] * example["popularity"]
        + np.random.randn() * config["score_sd"]  # score noise
    )

    example["score_no_popularity"] = (
        example["score"] - config["popularity_alpha"] * example["popularity"]
    )
    return example


data = data.map(headline_popularity, batched=False)

data_reward_df = data["reward"].to_pandas()

# Explicit intercept column, passed as the only exogenous regressor below.
data_reward_df["const"] = 1

start_time = pd.Timestamp.now()

# 2SLS: ``popularity`` is the endogenous regressor, instrumented by the
# ``is_east_coast`` flag; the intercept is the only exogenous variable.
iv = IV2SLS(
    dependent=data_reward_df["score"],
    exog=data_reward_df[["const"]],
    endog=data_reward_df["popularity"],
    instruments=data_reward_df["is_east_coast"],
).fit()

end_time = pd.Timestamp.now()

# Timestamp subtraction yields a Timedelta whose str() looks like
# "0 days 00:00:01.234000"; convert to float seconds so the label is accurate.
print(f"IV regression took {(end_time - start_time).total_seconds()} seconds")
print(iv.summary)

def _fit_ols_and_report(regressors):
    """Regress ``score`` on the given columns plus an intercept; print and return the fit."""
    design = sm.add_constant(data_reward_df[regressors])
    fitted = sm.OLS(data_reward_df["score"], design).fit()
    print(fitted.summary())
    return fitted


# Plain OLS fits of ``score``: on title sentiment (``ctr``) alone, on
# ``popularity`` alone, and on both together.
ols_model = _fit_ols_and_report(["ctr"])
ols_model = _fit_ols_and_report(["popularity"])
ols_model_with_popularity = _fit_ols_and_report(["ctr", "popularity"])

def clean_scores(example):
    """Remove the IV-estimated popularity effect from ``score``.

    Sets ``score_clean = score - popularity * beta_hat``, where ``beta_hat`` is
    the fitted ``popularity`` coefficient of the 2SLS model ``iv``.
    Returns the mutated ``example``.
    """
    popularity_coef = iv.params["popularity"]
    example["score_clean"] = example["score"] - popularity_coef * example["popularity"]
    return example


data = data.map(clean_scores)

# Drop intermediate columns before writing the dataset out.
# NOTE(review): remove_columns raises if a listed column is absent from a split,
# so this assumes the loaded data already carries all three — confirm.
data = data.remove_columns(["abs_diff", "sentiment_abstract", "sentiment_title"])

# ``data_path`` is a template whose placeholders are filled from the config itself.
output_path = config["data_path"].format(**config)
data.save_to_disk(output_path, max_shard_size="1000MB")

print(f"Data saved to {output_path}")
