import wandb
import pandas as pd
from pdb import set_trace
import numpy as np
from scipy import stats
from tqdm import tqdm

# Specify your project and entity (username or team name)
project_name = "LLM-Static-v0.6"
# project_name = "LLM-Static-v0.66"   # the one where one advertiser has reward scaling 1 and the other 2.
entity_name = "redacted"

# Only runs created after this timestamp are kept for LLM-Static-v0.66.
# The comparison is a plain string comparison against run.created_at, so it
# assumes created_at is an ISO-8601 timestamp string in the same timezone as
# this literal — TODO confirm against the W&B API.
V066_CREATED_AFTER = "2024-09-01T10:00"

SUPPORTED_PROJECTS = ("LLM-Static-v0.6", "LLM-Static-v0.66")
if project_name not in SUPPORTED_PROJECTS:
    # Fail before any network work instead of hitting a NameError on `runs` later.
    raise ValueError(
        f"Unsupported project_name {project_name!r}; expected one of {SUPPORTED_PROJECTS}"
    )

# Initialize a W&B API client
api = wandb.Api()

# Fetch all runs in the project
runs_all = api.runs(f"{entity_name}/{project_name}")
if project_name == "LLM-Static-v0.6":
    runs = runs_all
else:  # "LLM-Static-v0.66"
    # Keep only runs created AFTER the cutoff; earlier runs in this project
    # belong to a different setup.
    runs = [
        current_run
        for current_run in tqdm(runs_all)
        if current_run.created_at > V066_CREATED_AFTER
    ]


# Prepare lists to hold data
all_data = []  # This will hold a list of dictionaries for DataFrame conversion


def mean_confidence_interval(data, confidence=0.95):
    """
    Return the sample mean and a two-sided Student-t confidence interval.

    The half-width is ``scipy.stats.sem(data)`` (standard error of the mean)
    multiplied by the t critical value with ``len(data) - 1`` degrees of
    freedom at level ``confidence`` (0.95 by default, i.e. a 95% interval).

    Parameters
    ----------
    data : sized sequence of numbers (list, ndarray, pandas Series, ...)
    confidence : float, the two-sided coverage level in (0, 1).

    Returns
    -------
    tuple ``(mean, lower, upper)``. With fewer than two observations the
    bounds are undefined and SciPy yields NaN for them.
    """
    degrees_of_freedom = len(data) - 1
    upper_tail_quantile = (1 + confidence) / 2.
    half_width = stats.sem(data) * stats.t.ppf(upper_tail_quantile, degrees_of_freedom)

    centre = np.mean(data)
    lower, upper = centre - half_width, centre + half_width
    return centre, lower, upper

# History columns every run must provide, in the order they appear in the
# output records. A run missing any of them is skipped, and history rows with
# a NaN in any of them are dropped.
HISTORY_COLUMNS = [
    'samples used',
    'total advertiser participating value gain',
    'total advertiser utility gain zero bid offset',
    'total advertiser utility gain no offset',
    'total payment zero bid offset',
    'total payment no offset',
    'sequence log probability',
    'reference llm log probability',
    'advertiser 0 participating value gain',
    'advertiser 1 participating value gain',
    'advertiser 0 utility gain zero bid offset',
    'advertiser 1 utility gain zero bid offset',
    'advertiser 0 utility gain no offset',
    'advertiser 1 utility gain no offset',
    'advertiser 0 payment zero bid offset',
    'advertiser 1 payment zero bid offset',
    'advertiser 0 payment no offset',
    'advertiser 1 payment no offset',
    'advertiser 0 expected value',
    'advertiser 1 expected value',
    # These two gate which runs/rows are kept, so they are recorded as well.
    'advertiser 0 mentioned',
    'advertiser 1 mentioned',
]
# History column name -> output record key, where the two differ.
OUTPUT_RENAMES = {'reference llm log probability': 'reference LLM log probability'}

# Iterate over each run to collect data
for run in tqdm(runs, total=len(runs), desc="Processing runs"):
    # NOTE(review): the config key mixes a space and an underscore
    # ('use input_expansion'); if it does not match the logged key, this is
    # always None — confirm against the run configs.
    use_input_expansion = run.config.get('use input_expansion', None)

    # Fetch (down-sampled) historical data for the run
    history = run.history(samples=100)  # Adjust as needed

    try:
        filtered_history = history[HISTORY_COLUMNS].dropna()
    except KeyError as e:
        print(f"KeyError encountered: {e}. Skipping this run.")
        continue  # Skip to the next run if any required column is missing

    # One record per retained history row. to_dict("records") keeps each
    # column's own dtype, unlike iterrows, which coerces every row to a
    # single common dtype.
    run_metadata = {
        "run_id": run.id,
        "name": run.name,
        "use_input_expansion": use_input_expansion,
    }
    for record in filtered_history.rename(columns=OUTPUT_RENAMES).to_dict("records"):
        all_data.append({**run_metadata, **record})

# Convert the collected per-row records (one per retained history sample per
# run) into a single DataFrame.
df = pd.DataFrame(all_data)

# Aggregate 'total advertiser participating value gain' per 'samples used',
# separately for runs with and without input expansion.
# NOTE(review): the expansion flag is matched against the *strings* 'true' /
# 'false'; if the run config stores a bool (or the key is missing, giving
# None), the corresponding subset is empty — confirm against the W&B configs.
value_gain_column = 'total advertiser participating value gain'

if df.empty:
    # No records were collected, so df has no columns at all and selecting
    # 'use_input_expansion' would raise KeyError.
    print("No run data was collected; skipping aggregate statistics.")
else:
    grouped_with_expansion = df[df['use_input_expansion'] == 'true'].groupby('samples used')
    grouped_without_expansion = df[df['use_input_expansion'] == 'false'].groupby('samples used')

    mean_gain_with_expansion = grouped_with_expansion[value_gain_column].mean()
    mean_gain_without_expansion = grouped_without_expansion[value_gain_column].mean()

    # Each entry is a (mean, lower, upper) tuple from mean_confidence_interval
    # at its default 95% confidence level.
    confidence_intervals_with_expansion = grouped_with_expansion[value_gain_column].apply(mean_confidence_interval)
    confidence_intervals_without_expansion = grouped_without_expansion[value_gain_column].apply(mean_confidence_interval)

    print(f"Mean Total Advertiser Participating Value Gain with Expansion: {mean_gain_with_expansion}")
    print(f"Mean Total Advertiser Participating Value Gain without Expansion: {mean_gain_without_expansion}")
    print(f"95% CI (mean, lower, upper) with Expansion: {confidence_intervals_with_expansion}")
    print(f"95% CI (mean, lower, upper) without Expansion: {confidence_intervals_without_expansion}")



# Save to CSV. Each known project exports to its historical file name; any
# other project falls back to a name derived from the project, so the
# collected data is never lost to an unbound csv_path at the very end.
csv_paths = {
    "LLM-Static-v0.6": "all_runs_data_4.csv",
    "LLM-Static-v0.66": "all_runs_data_5.csv",
}
csv_path = csv_paths.get(project_name, f"all_runs_data_{project_name}.csv")
df.to_csv(csv_path, index=False)
print(f"All runs data has been saved to {csv_path}")