from datasets import load_dataset
from huggingface_hub import login
import pandas as pd
import numpy as np
import argparse
from datasets import load_from_disk
import json
from transformers import pipeline
from collections import defaultdict
import torch

from transformers import pipeline
from collections import defaultdict
import torch

from linearmodels.iv import IV2SLS
import statsmodels.api as sm

# Sentiment model used to score headlines and abstracts. With
# return_all_scores=True every prediction carries a score for each label;
# downstream code reads the "positive" one.
sentiment_classifier = pipeline(
    model="lxyuan/distilbert-base-multilingual-cased-sentiments-student",
    # NOTE(review): newer transformers releases deprecate return_all_scores in
    # favour of top_k=None — confirm against the installed version.
    return_all_scores=True,
    truncation=True,
    # Use the first GPU when present, otherwise fall back to CPU (-1) instead of
    # failing at start-up on machines without CUDA.
    device=0 if torch.cuda.is_available() else -1,
)

# Command-line interface: the only option is the path to the JSON config that
# drives data loading, score simulation and the output location.
parser = argparse.ArgumentParser(description="Create mind scores")
parser.add_argument(
    "--config_file",
    help="config file",
    type=str,
    default="../configs/mind-news-orthogonal/data.json",
)
args = parser.parse_args()

# Keywords that headline_popularity looks for (as substrings) in a headline
# title to flag it as political. Kept as a plain list of strings.
political_names = """
    trump impeachment house state u.s. ukraine president mayor
    court officials campaign election public syria democrats rep.
    vote democratic bill inquiry trial council plan gop
    biden senate support judge presidential security debate health
""".split()

# Seed NumPy's global RNG; the Gaussian noise drawn later with np.random.randn
# (in headline_popularity) depends on it.
np.random.seed(42)

# Read the experiment config and close the file deterministically. JSON is
# UTF-8 by spec, so don't depend on the platform's locale encoding.
with open(args.config_file, encoding="utf-8") as config_fh:
    config = json.load(config_fh)

data = load_from_disk(f'{config["base_dir"]}/data/{config["data"]}')


def _positive_scores(texts, classifier=None):
    """Return the "positive" label score for each text in ``texts``, in order.

    ``classifier`` defaults to the module-level ``sentiment_classifier``. It is
    called as ``classifier(texts, batch_size=32)`` and must yield, per text, an
    iterable of ``{"label": ..., "score": ...}`` dicts.
    """
    if classifier is None:
        classifier = sentiment_classifier
    scores_by_label = defaultdict(list)
    for label_scores in classifier(texts, batch_size=32):
        for entry in label_scores:
            scores_by_label[entry["label"]].append(entry["score"])
    # A missing "positive" label yields an empty list (defaultdict), matching
    # the previous behaviour.
    return scores_by_label["positive"]


def headline_emotion(example, classifier=None):
    """Add positive-sentiment scores for the title and abstract of a batch.

    Intended for ``Dataset.map(..., batched=True)``, so ``example["Title"]`` and
    ``example["Abstract"]`` are lists of strings.

    Writes:
      * ``example["ctr"]`` – positive score of each title. It is used
        downstream as the click-through signal in the simulated score.
      * ``example["emotion_abstract"]`` – positive score of each abstract.

    ``classifier`` is optional; when omitted the module-level
    ``sentiment_classifier`` is used. Returns the same ``example`` mapping.
    """
    example["ctr"] = _positive_scores(example["Title"], classifier)
    example["emotion_abstract"] = _positive_scores(example["Abstract"], classifier)
    return example


# Per-example preprocessing, applied with Dataset.map(..., batched=False).
def headline_length(example):
    """Record the title's word count and replace a missing abstract with "".

    ``headline_length`` is the number of whitespace-separated tokens in
    ``example["Title"]``. The mapping is modified in place and returned.
    """
    word_count = len(example["Title"].split())
    example["headline_length"] = word_count
    # Later passes feed the abstract to the sentiment model, which needs a string.
    if example["Abstract"] is None:
        example["Abstract"] = ""
    return example


# Per-example cleanup first, so every Abstract is a string before the batched
# sentiment pass reads it.
data = data.map(headline_length, batched=False)
reward_preview = data["reward"].to_pandas().head()
print(reward_preview)

# Sentiment scoring of titles and abstracts, 32 rows per batch.
data = data.map(headline_emotion, batched=True, batch_size=32)


def headline_popularity(example):
    """Simulate popularity and score columns for a single example.

    Reads ``Title``, ``ctr`` and ``emotion_abstract`` from ``example``, plus the
    module-level ``config`` and ``political_names``. Writes:

      * ``is_political`` – 1 if any political keyword occurs in the title, else 0.
      * ``popularity`` – ``popularity_political`` * is_political
        + ``popularity_ctr`` * emotion_abstract + Gaussian noise with sd
        ``popularity_sd``.
      * ``score`` – ``score_ctr`` (default 0.5 when absent from the config) * ctr
        + ``popularity_alpha`` * popularity + Gaussian noise with sd ``score_sd``.
      * ``score_no_popularity`` – ``score`` with the popularity term removed.

    Noise comes from NumPy's global RNG, so the output depends on its seed.
    """
    # The keywords are all lower-case ('trump', 'u.s.', ...), so compare against
    # a lower-cased title; otherwise capitalised headlines never match.
    title = example["Title"].lower()
    political_score = int(any(term in title for term in political_names))
    example["is_political"] = political_score

    # add normal random noise
    example["popularity"] = (
        political_score * config["popularity_political"]
        + example["emotion_abstract"] * config["popularity_ctr"]
        + config["popularity_sd"] * np.random.randn()
    )
    example["score"] = (
        example["ctr"] * config.get("score_ctr", 0.5)
        + config["popularity_alpha"] * example["popularity"]
        + np.random.randn() * config["score_sd"]
    )

    example["score_no_popularity"] = (
        example["score"] - config["popularity_alpha"] * example["popularity"]
    )
    return example


data = data.map(headline_popularity, batched=False)

data_reward_df = data["reward"].to_pandas()

# IV2SLS is given this explicit intercept column through ``exog``.
data_reward_df["const"] = 1

start_time = pd.Timestamp.now()

# 2SLS: score on the endogenous popularity, instrumented by is_political.
iv = IV2SLS(
    dependent=data_reward_df["score"],
    exog=data_reward_df[["const"]],
    endog=data_reward_df["popularity"],
    instruments=data_reward_df["is_political"],
).fit()

end_time = pd.Timestamp.now()

# Subtracting Timestamps gives a Timedelta (printed as "0 days 00:00:..."),
# so convert it explicitly before labelling the value as seconds.
elapsed_seconds = (end_time - start_time).total_seconds()
print(f"IV regression took {elapsed_seconds:.3f} seconds")
# NOTE: `summary` on linearmodels results is a property, hence no call.
print(iv.summary)

# Naive OLS baselines, to compare with the IV estimate: score regressed on ctr
# alone, then on popularity alone (both bound to `ols_model`, the last one
# wins), and finally on ctr and popularity together.
for ols_regressors in (["ctr"], ["popularity"]):
    ols_model = sm.OLS(
        data_reward_df["score"],
        sm.add_constant(data_reward_df[ols_regressors]),
    ).fit()
    print(ols_model.summary())

ols_model_with_popularity = sm.OLS(
    data_reward_df["score"],
    sm.add_constant(data_reward_df[["ctr", "popularity"]]),
).fit()
print(ols_model_with_popularity.summary())

def clean_scores(example):
    """Add ``score_clean``: ``score`` minus the IV-estimated popularity effect.

    The coefficient is the ``popularity`` parameter of the module-level 2SLS
    fit ``iv``. The mapping is modified in place and returned.
    """
    popularity_coef = iv.params["popularity"]
    example["score_clean"] = example["score"] - popularity_coef * example["popularity"]
    return example

data = data.map(clean_scores)

# save the ratings to the path templated in the config (formatted with the
# config's own values), sharded at 1000MB.
output_path = config["data_path"].format(**config)
data.save_to_disk(output_path, max_shard_size="1000MB")
print(f"Data saved to {output_path}")
