import os
import datetime
import numpy as np
import pandas as pd
import mlflow
import torch
import ray
from scipy.stats import spearmanr
import pytz
from mmirt.utils.masks import create_balanced_mask, create_balanced_mask_with_fixed_test
import sys

# Select the benchmark from the command line. Flags are checked in the order
# listed (--mathvista before --vqa); with neither flag, MMMU is used.
dataset_name, data_dir = next(
    (
        (name, directory)
        for flag, name, directory in (
            ("--mathvista", "mathvista", "data_mathvista"),
            ("--vqa", "vqa", "data_vqa"),
        )
        if flag in sys.argv
    ),
    ("mmmu", "data_mmmu_0220"),
)

# Pin the process to a single GPU: the first of --cuda0 .. --cuda3 found on the
# command line (checked in that order) wins; with none of them, device 0.
# NOTE: CUDA_VISIBLE_DEVICES is normally only honoured if it is set before CUDA
# is initialised, so keep this ahead of any GPU work.
os.environ['CUDA_VISIBLE_DEVICES'] = next(
    (gpu for gpu in ('0', '1', '2', '3') if f"--cuda{gpu}" in sys.argv),
    '0',
)

# --noguess disables the guessing term (use_guessing=False) and records the
# model as 2PL; otherwise it is 3PL. `pl` only feeds run names and logged params.
use_guessing = "--noguess" not in sys.argv
pl = 3 if use_guessing else 2

# Local wall-clock stamp (month-day and hour-minute) appended to the MLflow
# experiment name so runs started in different minutes get distinct names.
date_now = datetime.datetime.now()
mmdd, hhmm = (date_now.strftime(fmt) for fmt in ('%m%d', '%H%M'))
print(f"Today: {mmdd}, Time: {hhmm}")

# Choose the IRT implementation. Precedence: --asplit, then --noasplit, then
# --nosplit; with no variant flag the asplit model is the default.
if "--asplit" not in sys.argv and "--noasplit" in sys.argv:
    raise NotImplementedError("noasplit 3PL is not implemented yet.")

if "--asplit" not in sys.argv and "--nosplit" in sys.argv:
    # Single (non-split) difficulty/ability parameterisation.
    from mmirt.irt_onelayer_nostlict import Standard3PLIRT
    split_difficulty = split_ability = split_a = False
    irt_name = f"nosplit_{pl}pl"
else:
    # --asplit given explicitly, or no variant flag at all.
    from mmirt.irt_onelayer_asplit_val import Standard3PLIRT
    split_difficulty = split_ability = split_a = True
    irt_name = f"asplit_{pl}pl"

# MLflow experiment name carries the run timestamp; the CSV name does not.
experiment_name = f"delij_{dataset_name}_{irt_name}_{mmdd}_{hhmm}"
csv_name = f"delij_{dataset_name}_{irt_name}"


# Gather every response CSV in the dataset directory. os.listdir returns
# entries in arbitrary order, so the list is sorted: that order decides which
# copy of a duplicated model row survives below and the row order of
# `response_data`, and both should be reproducible across machines.
csv_files = sorted(
    os.path.join(data_dir, file) for file in os.listdir(data_dir) if file.endswith(".csv")
)
if not csv_files:
    raise FileNotFoundError(f"No .csv files found in {data_dir!r}.")

dfs = []
columns = None  # column layout of the first file; every other file must match it exactly
for file in csv_files:
    print(file)
    df = pd.read_csv(file, index_col=0)  # first CSV column is the row index (model id)
    if columns is None:
        columns = list(df.columns)
    elif list(df.columns) != columns:
        raise ValueError(f"{file}'s columns do not match the first file's columns.")
    dfs.append(df)

# Stack all files; if a model id appears more than once, keep its first row.
merged_df = pd.concat(dfs)
merged_df = merged_df[~merged_df.index.duplicated(keep='first')]

response_data = merged_df.values  # rows = models, columns = problems
models = merged_df.index
problems = merged_df.columns

# Training-fraction sweep. np.arange excludes its stop value, so these settings
# give 0.05, 0.10, ..., 0.90 (18 fractions); 0.95 itself is NOT included.
x = 5   # first percentage (inclusive)
y = 95  # stop percentage (exclusive)
r = 5   # step, in percentage points
train_percentages = np.divide(np.arange(x, y, r), 100)


results_list = []  # NOTE(review): appears unused in this script

# Rows of the response matrix are models ("students"); columns are problems
# ("tests").
num_students, num_items = response_data.shape
student_names, test_names = list(models), list(problems)

# Hyperparameters shared by the full-data reference fit and every Ray trial.
lr, max_epochs, batch_size = 1e-3, 500, 256
embedding_dim = 32
theta_max = 6.0
difficulty_base_max = 6.0
difficulty_other_max = 2 * difficulty_base_max / 3  # two thirds of difficulty_base_max (= 4.0)
theta_init = 0.1

# Base mask at the largest fraction of the sweep; every trial hands it to
# create_balanced_mask_with_fixed_test. NOTE(review): presumably this fixes the
# held-out cells shared by all trials -- confirm in mmirt.utils.masks.
mask_base, _ = create_balanced_mask(
    response_data,
    percentage=max(train_percentages),
    item_names=problems,
)

# Accuracy leaderboard: mean response per model (row), best first. Tied models
# share the smallest rank (method='min').
mean_scores = merged_df.mean(axis=1)
ranked_models = mean_scores.sort_values(ascending=False)
ranked_df = (
    ranked_models.reset_index()
    .set_axis(['model_id', 'score'], axis=1)
    .assign(rank=lambda frame: frame['score'].rank(ascending=False, method='min').astype(int))
    .sort_values('rank')
)

# Reference fit on every response (all-ones train mask; presumably 1 = use).
mask_non = np.ones_like(response_data)

model_full = Standard3PLIRT(
    response_data=response_data,
    student_names=student_names,
    test_names=test_names,
    train_mask=mask_non,
    split_difficulty=split_difficulty,
    split_ability=split_ability,
    lr=lr,
    batch_size=batch_size,
    max_epochs=max_epochs,
    device="cuda",
    eps=1e-3,
    embedding_dim=embedding_dim,
    theta_max=theta_max,
    difficulty_base_max=difficulty_base_max,
    difficulty_other_max=difficulty_other_max,
    theta_init=theta_init,
    use_guessing=use_guessing,
)
estimates_full = model_full.fit()

# Per-model score: sum of every theta component whose key does not contain
# "effect".
thetas_full = estimates_full["theta"]
theta_sums_full = {
    name: sum(thetas_full[name][part] for part in thetas_full[name] if "effect" not in part)
    for name in thetas_full
}

# Rank 1 = highest summed theta; ranks are consecutive (ties are not shared).
theta_rank_df_full = (
    pd.DataFrame(list(theta_sums_full.items()), columns=['model_id', 'score'])
    .sort_values(by='score', ascending=False)
    .reset_index(drop=True)
)
theta_rank_df_full['rank'] = theta_rank_df_full.index + 1

ray.init()

# Stage the shared inputs in Ray's object store once; the returned ObjectRefs
# are module-level names used when the trials are dispatched.
(
    response_data_id,
    student_names_id,
    test_names_id,
    mask_base_id,
    ranked_df_id,
    theta_rank_df_full_id,
    problems_id,
) = (
    ray.put(obj)
    for obj in (
        response_data,
        student_names,
        test_names,
        mask_base,
        ranked_df,
        theta_rank_df_full,
        list(problems),
    )
)

@ray.remote(num_cpus=0.5)
def run_experiment(train_percentage, rep, condition,
                   response_data, student_names, test_names, mask_base,
                   ranked_df, theta_rank_df_full, problems,
                   mmdd, hhmm,
                   lr, max_epochs, embedding_dim, theta_max,
                   difficulty_base_max, difficulty_other_max, theta_init,
                   batch_size,
                   use_guessing
                   ):
    """Fit one IRT model on a sub-sampled training mask and score its ranking.

    Runs as a Ray task reserving half a CPU. The training mask comes from
    ``create_balanced_mask_with_fixed_test`` applied to ``mask_base`` at
    ``train_percentage``, seeded with ``rep``. Each model's fitted theta
    components (excluding keys containing "effect") are summed, turned into a
    ranking, and compared with Spearman's rho against:

    * ``ranked_df`` -- the accuracy-based ranking, and
    * ``theta_rank_df_full`` -- the ranking from the full-data reference fit.

    ``mmdd`` and ``hhmm`` are accepted but not used. ``split_difficulty`` and
    ``split_ability`` are read from module globals rather than parameters.
    NOTE(review): the task requests no GPUs from Ray (num_gpus unset) yet
    builds the model with device="cuda" -- confirm GPU sharing is intended.

    Returns:
        dict with ``train_percentage``, ``rep``, ``condition``, every entry of
        ``model.evaluate_predictions()``, and the two rank correlations with
        their p-values.
    """
    # Draw this trial's training mask from the shared base mask.
    mask_final, _ = create_balanced_mask_with_fixed_test(
        response_data,
        mask_base,
        train_percentage=train_percentage,
        item_names=problems,
        seed=rep,
    )

    model = Standard3PLIRT(
        response_data=response_data,
        student_names=student_names,
        test_names=test_names,
        train_mask=mask_final,
        split_difficulty=split_difficulty,
        split_ability=split_ability,
        lr=lr,
        batch_size=batch_size,
        max_epochs=max_epochs,
        device="cuda",
        eps=1e-3,
        embedding_dim=embedding_dim,
        theta_max=theta_max,
        difficulty_base_max=difficulty_base_max,
        difficulty_other_max=difficulty_other_max,
        theta_init=theta_init,
        use_guessing=use_guessing,
    )
    estimates = model.fit()
    metrics = model.evaluate_predictions()

    # Per-model score: sum of all theta components except "effect" terms.
    thetas = estimates["theta"]
    theta_sums = {
        name: sum(thetas[name][part] for part in thetas[name] if "effect" not in part)
        for name in thetas
    }

    # Rank 1 = highest summed theta; ranks are consecutive (ties not shared).
    theta_rank_df = (
        pd.DataFrame(list(theta_sums.items()), columns=['model_id', 'score'])
        .sort_values(by='score', ascending=False)
        .reset_index(drop=True)
    )
    theta_rank_df['rank'] = theta_rank_df.index + 1

    def _rank_correlation(left, right, suffixes):
        """Spearman correlation of two rank tables inner-joined on model_id."""
        paired = pd.merge(
            left[['model_id', 'rank']],
            right[['model_id', 'rank']],
            on='model_id',
            suffixes=suffixes,
        )
        first, second = (f"rank{suffix}" for suffix in suffixes)
        return spearmanr(paired[first], paired[second])

    # Trial vs. accuracy ranking, then full-data vs. trial ranking.
    corr_acu_theta, p_value_acu_theta = _rank_correlation(
        theta_rank_df, ranked_df, ('_theta', '_other'))
    corr_theta_theta, p_value_theta = _rank_correlation(
        theta_rank_df_full, theta_rank_df, ('_theta1', '_theta2'))

    return {
        "train_percentage": train_percentage,
        "rep": rep,
        "condition": condition,
        **metrics,
        "corr_acu_theta": corr_acu_theta,
        "corr_theta_theta": corr_theta_theta,
        "p_value_acu_theta": p_value_acu_theta,
        "p_value_theta": p_value_theta,
    }

n_repetitions = 10
condition = "split"

# One Ray task per (training fraction, repetition), fraction-major order.
# The shared inputs are passed as ObjectRefs: Ray resolves top-level ObjectRef
# arguments to their values inside the worker, so each large object is stored
# once in the object store. Calling ray.get() on them here would instead copy
# every object back into the driver and re-serialise it into each task.
tasks = [
    run_experiment.remote(
        train_percentage, rep, condition,
        response_data_id,
        student_names_id,
        test_names_id,
        mask_base_id,
        ranked_df_id,
        theta_rank_df_full_id,
        problems_id,
        mmdd, hhmm,
        lr, max_epochs, embedding_dim, theta_max,
        difficulty_base_max, difficulty_other_max, theta_init,
        batch_size, use_guessing,
    )
    for train_percentage in train_percentages
    for rep in range(n_repetitions)
]

# Blocks until every trial has finished; results keep the submission order.
results = ray.get(tasks)

results_df = pd.DataFrame(results)
print("All experiment results:")
print(results_df)

import mlflow

mlflow.set_experiment(experiment_name)

with mlflow.start_run(run_name=experiment_name) as run:
    # Run configuration. train_percentages is converted to plain floats so the
    # logged value reads "[0.05, 0.1, ...]" instead of NumPy scalar reprs
    # (NumPy >= 2 prints each element as "np.float64(0.05)").
    mlflow.log_params({
        "dataset_name": dataset_name,
        "irt_name": irt_name,
        "use_guessing": use_guessing,
        "pl": pl,
        "lr": lr,
        "max_epochs": max_epochs,
        "embedding_dim": embedding_dim,
        "theta_max": theta_max,
        "difficulty_base_max": difficulty_base_max,
        "difficulty_other_max": difficulty_other_max,
        "theta_init": theta_init,
        "batch_size": batch_size,
        "n_repetitions": n_repetitions,
        "train_percentages": [float(p) for p in train_percentages],
        "split_difficulty": split_difficulty,
        "split_ability": split_ability,
        "split_a": split_a,
    })

    # Save the per-trial table locally and attach it to the run. The output
    # directory is created on demand; without it, to_csv would raise only after
    # every trial had already run, and the results would be lost.
    # NOTE(review): csv_name has no timestamp, so each run overwrites the local
    # CSV; the MLflow artifact keeps a per-run copy.
    results_csv = f"result/{csv_name}.csv"
    os.makedirs(os.path.dirname(results_csv), exist_ok=True)
    results_df.to_csv(results_csv, index=False)
    mlflow.log_artifact(results_csv)

    # Summary metrics: mean rank correlations over all trials (pandas skips NaN).
    avg_corr_acu_theta = results_df["corr_acu_theta"].mean()
    avg_corr_theta_theta = results_df["corr_theta_theta"].mean()
    mlflow.log_metric("avg_corr_acu_theta", avg_corr_acu_theta)
    mlflow.log_metric("avg_corr_theta_theta", avg_corr_theta_theta)

print("All experiments are done. And the results are saved.")


