# run_experiment.py
import argparse
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression, SGDRegressor
from tqdm import tqdm
import os
import time
import csv
from copy import deepcopy

import data_utils
import coma_utils as cu
from assignments import (
    create_assignment_matrix_random_switch,
)
import selection as sel
from utils import (
    create_forecaster_results_list_from_reported,
    generate_reported_results_from_assignment,
    log_single_seed_results,
)
from model_utils import generate_model_outputs


def create_forecaster_results_list_from_reported(reported_results_run, num_forecasters):
    """
    Split the pooled "reported_*" arrays into one result dict per forecaster.

    NOTE: this module-level definition rebinds the same-named helper imported
    from `utils` at the top of the file.

    Args:
        reported_results_run: Mapping holding "reported_loss" (indexed as
            [step, forecaster]), "reported_AdaptErr" (same indexing) and
            "reported_piAdapt" (indexed as [step, forecaster, :]).
        num_forecasters: Expected size of the forecaster axis.

    Returns:
        A list of `num_forecasters` dicts with keys "piAdapt", "AdaptErr" and
        "loss", each truncated to the number of rows of "reported_loss".
        An empty list if "reported_loss" is missing or its second dimension
        differs from `num_forecasters`.
    """
    # Guard 1: nothing to split without the loss matrix.
    if "reported_loss" not in reported_results_run:
        return []

    # Guard 2: the pool must contain exactly the expected number of forecasters.
    loss_matrix = reported_results_run["reported_loss"]
    if loss_matrix.shape[1] != num_forecasters:
        return []

    n_rows = loss_matrix.shape[0]

    # One dict per forecaster column; the other arrays are looked up lazily so
    # they are only required when there is at least one forecaster.
    return [
        {
            "piAdapt": reported_results_run["reported_piAdapt"][:n_rows, col, :],
            "AdaptErr": reported_results_run["reported_AdaptErr"][:n_rows, col],
            "loss": loss_matrix[:n_rows, col],
        }
        for col in range(num_forecasters)
    ]


def run_base_aci_algorithms(
    Y_full,  # Full Y series
    alpha,
    q_init,
    nu,
    eps,
    tinit,
    all_model_raw_outputs,  # List of dicts, each with ["model_name", "raw_predictions"]
):
    """
    Runs the Adaptive Conformal Inference (ACI) algorithm for multiple models
    using pre-generated raw model predictions.

    For every prediction step and every model, a symmetric interval
    [pred - q, pred + q] is formed, the miscoverage indicator is recorded and
    the half-width q is updated as max(0, q + nu_t * (err - alpha)).
    After each step nu_t is reset to 0.1 times the largest finite absolute
    residual among the most recent (at most `tinit + 1`) steps of that model,
    falling back to `nu` when no finite residual is available.

    Args:
        Y_full: Full target series; entries from index `tinit` on are scored.
        alpha: Target miscoverage level used in the quantile update.
        q_init: Initial interval half-width shared by all models.
        nu: Step size for the first update, and fallback step size afterwards.
        eps: Unused; kept for interface compatibility.
        tinit: Index in `Y_full` of the first predicted observation; also the
            look-back used for the adaptive step size.
        all_model_raw_outputs: List of dicts with "raw_predictions" (length
            len(Y_full) - tinit) and, optionally, "model_name".

    Returns:
        A list with one dict per model (same order as the input), each holding
        arrays over the prediction steps: "q_t", "AdaptErr", "piAdapt"
        (shape (n_pred, 2)), "raw_pred" and "raw_score".
        An empty list when `tinit` is not strictly between 0 and len(Y_full).

    Raises:
        ValueError: If a model's "raw_predictions" length is not
            len(Y_full) - tinit.
    """
    num_algorithms = len(all_model_raw_outputs)
    T = len(Y_full)
    if not (0 < tinit < T):
        # Safety check for tinit if function called directly. Return the same
        # type as the success path (a list) so callers can index/len() it.
        return []

    n_pred = T - tinit  # Number of prediction steps (>= 1 after the check above)

    for i, model_output in enumerate(all_model_raw_outputs):
        if len(model_output["raw_predictions"]) != n_pred:
            raise ValueError(
                f"Model {model_output.get('model_name', i)} has mismatched raw_predictions length. Expected {n_pred}, got {len(model_output['raw_predictions'])}"
            )

    quantiles_t = np.full(num_algorithms, float(q_init))
    # ar_model_adaptive_nus[i] stores the nu to be used for model i's *current* quantile update.
    ar_model_adaptive_nus = np.full(num_algorithms, float(nu))

    results = [
        {
            "q_t": np.full(n_pred, np.nan),
            "AdaptErr": np.full(n_pred, np.nan),
            "piAdapt": np.full((n_pred, 2), np.nan),
            "raw_pred": np.full(n_pred, np.nan),
            "raw_score": np.full(n_pred, np.nan),
        }
        for _ in range(num_algorithms)
    ]

    for t_pred_idx in range(n_pred):
        # Index in the original Y_full array for the true value.
        Y_t_actual = Y_full[tinit + t_pred_idx]
        # Start of the score window used to adapt nu; max() clamps it to the
        # first prediction step while fewer than `tinit` steps have elapsed.
        window_start = max(0, t_pred_idx - tinit)

        for i, model_output_data in enumerate(all_model_raw_outputs):
            res_i = results[i]

            pred_t = model_output_data["raw_predictions"][t_pred_idx]
            res_i["raw_pred"][t_pred_idx] = pred_t

            # A non-finite prediction is scored as infinitely wrong.
            score_t = np.abs(Y_t_actual - pred_t) if np.isfinite(pred_t) else np.inf
            res_i["raw_score"][t_pred_idx] = score_t

            q_curr = quantiles_t[i]
            res_i["q_t"][t_pred_idx] = q_curr
            lo, hi = pred_t - q_curr, pred_t + q_curr
            res_i["piAdapt"][t_pred_idx, :] = [lo, hi]

            # Degenerate intervals (non-finite or inverted bounds) count as a miss.
            err_t = (
                1.0 - float(lo <= Y_t_actual <= hi)
                if np.isfinite(lo) and np.isfinite(hi) and hi >= lo
                else 1.0
            )
            res_i["AdaptErr"][t_pred_idx] = err_t

            # Quantile update uses the nu computed at the previous step.
            quantiles_t[i] = max(0, q_curr + ar_model_adaptive_nus[i] * (err_t - alpha))

            # For the *next* prediction step, rescale nu from the recent raw
            # scores of model i (current step included); infinite scores are
            # ignored so a single bad prediction cannot blow up the step size.
            scores_in_window = res_i["raw_score"][window_start : t_pred_idx + 1]
            finite_scores_in_window = scores_in_window[np.isfinite(scores_in_window)]
            if finite_scores_in_window.size > 0:
                ar_model_adaptive_nus[i] = 0.1 * np.max(finite_scores_in_window)
            else:
                ar_model_adaptive_nus[i] = float(nu)  # Default to initial nu from args

    return results


def generate_reported_results_from_assignment(
    algo_results_list, n_pred, num_forecasters, assignment_matrix
):
    """
    Build the per-forecaster "reported" histories by copying, at every step,
    the output of whichever base algorithm the assignment matrix selects.

    NOTE: this module-level definition rebinds the same-named helper imported
    from `utils` at the top of the file.

    Args:
        algo_results_list: Per-algorithm result dicts (entries may be None),
            each providing "piAdapt" and "AdaptErr" arrays indexed by step.
        n_pred: Number of prediction steps.
        num_forecasters: Number of forecasters in the pool.
        assignment_matrix: Array indexed as [step, forecaster] whose entries
            are indices into `algo_results_list`.

    Returns:
        Dict with "reported_piAdapt" (n_pred, num_forecasters, 2),
        "reported_AdaptErr", "reported_loss" and "reported_coverage_ind"
        (each (n_pred, num_forecasters)) and "assignment_history".
        Cells that cannot be filled keep NaN (inf for the loss). If the
        assignment matrix has the wrong shape or out-of-range indices, every
        cell is left unfilled and "assignment_history" is an empty array.
    """
    n_algos = len(algo_results_list)

    # Same checks, same short-circuit order as a chained `and`, but negated.
    assignment_invalid = (
        assignment_matrix.shape[0] != n_pred
        or assignment_matrix.shape[1] != num_forecasters
        or not (np.max(assignment_matrix) < n_algos)
        or not (np.min(assignment_matrix) >= 0)
    )

    intervals_out = np.full((n_pred, num_forecasters, 2), np.nan)
    errors_out = np.full((n_pred, num_forecasters), np.nan)
    losses_out = np.full((n_pred, num_forecasters), np.inf)

    if assignment_invalid:
        return {
            "reported_piAdapt": intervals_out,
            "reported_AdaptErr": errors_out,
            "reported_loss": losses_out,
            "reported_coverage_ind": np.full((n_pred, num_forecasters), np.nan),
            "assignment_history": np.array([]),
        }

    for step, row in enumerate(assignment_matrix):
        for fc, algo_idx in enumerate(row):
            if not (algo_idx < n_algos):
                continue
            source = algo_results_list[algo_idx]
            if source is None:
                continue
            source_intervals = source["piAdapt"]
            # Skip steps the algorithm never reached or left unset (NaN lower bound).
            if not (source_intervals.shape[0] > step):
                continue
            if np.isnan(source_intervals[step, 0]):
                continue

            intervals_out[step, fc, :] = source_intervals[step]
            errors_out[step, fc] = source["AdaptErr"][step]
            # Loss is whatever cu.loss_fun assigns to the reported interval.
            losses_out[step, fc] = cu.loss_fun(source_intervals[step])

    return {
        "reported_piAdapt": intervals_out,
        "reported_AdaptErr": errors_out,
        "reported_loss": losses_out,
        "reported_coverage_ind": 1.0 - errors_out,
        "assignment_history": assignment_matrix,
    }


def _perform_single_seed_logic(config):
    """
    Performs the core computational logic for a single experiment seed.

    Pipeline: seed NumPy's global RNG, load the dataset, generate raw model
    predictions, run the base ACI algorithms twice (once at a reduced
    miscoverage level `alpha_prime`, once at `alpha_final`), map base-model
    outputs to the forecaster pool through an assignment matrix, and run the
    two aggregation procedures from `selection`.

    Args:
        config: Dict read for the keys "seed", "dataset", "tinit",
            "algo_configs", "alpha_final", "q_init", "nu", "eps",
            "splitting_strategy", "assignment_params", "K_forecasters" and
            "eta_hedge". It is also passed whole to the dataset loader.

    Returns:
        A dictionary of results needed for logging (seed, sizes, test targets,
        reported pool results, per-base-model results, aggregation results
        and weights).

    Raises:
        ValueError: If data loading yields no data, if `tinit` is not strictly
            between 0 and the number of observations, or if no assignment
            matrix with `n_pred` rows could be created (including an
            unsupported "splitting_strategy").
    """
    seed = config["seed"]
    # Seeds NumPy's *global* RNG; the seed is additionally passed explicitly
    # to the model-output and assignment helpers below.
    np.random.seed(seed)

    Y_np, X_dict, n_obs = data_utils.load_and_preprocess_dataset(
        config["dataset"], config
    )
    if Y_np is None or n_obs == 0:
        raise ValueError(f"Data loading failed for seed {seed}.")

    tinit = config["tinit"]
    if not (0 < tinit < n_obs):
        raise ValueError(f"Invalid tinit {tinit} for n_obs {n_obs}, seed {seed}.")

    all_model_raw_outputs = generate_model_outputs(
        Y_np, X_dict, config["algo_configs"], tinit, seed
    )

    # TODO: to be changed after debugging. The factor is hard-coded to 0.9;
    # config["alpha_prime_ratio"] is currently NOT used (presumably the
    # intended multiplier -- confirm before re-enabling).
    alpha_prime = config["alpha_final"]*0.9

    # Base learners at the reduced level alpha_prime (inputs to the AdaCOMA aggregation).
    res_adacoma_base_learners = run_base_aci_algorithms(
        Y_np, 
        alpha_prime,
        config["q_init"],
        config["nu"],
        config["eps"],
        tinit,
        all_model_raw_outputs
    )
    # Base learners at the final level alpha_final (inputs to the COMA aggregation).
    res_coma_base_learners = run_base_aci_algorithms(
        Y_np, 
        config["alpha_final"],
        config["q_init"],
        config["nu"],
        config["eps"],
        tinit,
        all_model_raw_outputs
    )

    # Per-base-model summaries, keyed by model name, taken from the
    # alpha_prime run only.
    base_model_results = {}
    for i, model_output_spec in enumerate(all_model_raw_outputs):
        model_name = model_output_spec["model_name"]
        if (
            res_adacoma_base_learners
            and i < len(res_adacoma_base_learners)
            and res_adacoma_base_learners[i] is not None
        ):
            base_model_piAdapt = res_adacoma_base_learners[i]["piAdapt"]
            # Calculate interval lengths: hi - lo
            current_interval_lengths = base_model_piAdapt[:, 1] - base_model_piAdapt[:, 0]
            # Handle cases where lo/hi might be NaN (e.g., if pred_t was NaN)
            # A NaN length should be treated as infinitely wide, consistent with other loss handling
            current_interval_lengths_processed = np.nan_to_num(
                current_interval_lengths, nan=np.inf, posinf=np.inf, neginf=0.0
            )

            base_model_results[model_name] = {
                "cov_maj": 1.0 - res_adacoma_base_learners[i]["AdaptErr"], # Array of coverage indicators (1-error)
                "loss_maj": current_interval_lengths_processed,    # Array of actual interval lengths (for alpha_prime)
                "raw_pred": res_adacoma_base_learners[i]["raw_pred"],
                "raw_score": res_adacoma_base_learners[i]["raw_score"], # Keep original raw_score
            }

    n_pred = len(Y_np) - tinit

    # Build the (step x forecaster) assignment of base models to forecasters.
    # Only "random_switch" is handled here; any other strategy leaves
    # assignment_matrix as None and triggers the ValueError below.
    setting_type = config["splitting_strategy"]
    assign_params = config["assignment_params"].get(setting_type, {})
    K, M = config["K_forecasters"], len(config["algo_configs"])
    assignment_matrix = None
    if setting_type == "random_switch":
        # Falls back to a switch period of 25 when none is configured.
        assignment_matrix = create_assignment_matrix_random_switch(
            n_pred, K, M, assign_params.get("switch_period", 25), seed
        )

    if assignment_matrix is None or assignment_matrix.shape[0] != n_pred:
        raise ValueError(f"Assignment matrix error for seed {seed}.")

    # The same assignment matrix is applied to both ACI runs.
    rep_res_adacoma_pool = generate_reported_results_from_assignment(
        res_adacoma_base_learners, n_pred, K, assignment_matrix
    )
    rep_res_coma_pool = generate_reported_results_from_assignment(
        res_coma_base_learners, n_pred, K, assignment_matrix
    )

    # Targets aligned with the prediction steps.
    y_test_agg = Y_np[tinit : tinit + n_pred]

    loss_mat_adacoma_learners = rep_res_adacoma_pool["reported_loss"]
    # NOTE: despite the "_finite" suffix this only maps NaN -> +inf and
    # -inf -> 0.0; +inf entries stay +inf.
    loss_mat_adacoma_learners_finite = np.nan_to_num(
        loss_mat_adacoma_learners, nan=np.inf, posinf=np.inf, neginf=0.0
    )
    # Prior weights from the two hedging routines in coma_utils.
    weights_prior_adacoma_ada = cu.adahedge(loss_mat_adacoma_learners_finite)["weights"]
    weights_prior_adacoma_hedge = cu.hedge(
        loss_mat_adacoma_learners_finite, eta=config["eta_hedge"]
    )["weights"]

    # Arguments are positional; the aggregation receives the unprocessed loss
    # matrix (not the nan_to_num'd copy used for the priors above).
    adacoma_agg_results = sel.run_adacoma_aggregation(
        create_forecaster_results_list_from_reported(
            rep_res_adacoma_pool, K
        ),  
        y_test_agg,
        n_pred,
        rep_res_adacoma_pool[
            "reported_coverage_ind"
        ],  
        loss_mat_adacoma_learners,  
        weights_prior_adacoma_ada,  
        weights_prior_adacoma_hedge,  
        config["alpha_final"],  # TODO: to be changed after debugging
        config["alpha_final"]*4,  # TODO: to be changed after debugging (hard-coded factor 4)
    )

    loss_mat_coma_learners = rep_res_coma_pool["reported_loss"]
    # The COMA loss matrix is passed both as `loss_matrix` and as `raw_ind_lengths`.
    coma_agg_results, coma_weights_ada, coma_weights_hedge = sel.run_coma_aggregation(
        results_list=create_forecaster_results_list_from_reported(
            rep_res_coma_pool, K
        ),  
        y_test=y_test_agg,
        n_pred=n_pred,
        loss_matrix=loss_mat_coma_learners,  
        raw_ind_coverage=rep_res_coma_pool[
            "reported_coverage_ind"
        ],  
        raw_ind_lengths=loss_mat_coma_learners,  
        eta_hedge=config["eta_hedge"],
    )

    return {
        "seed": seed,
        "n_pred": n_pred,
        "tinit": tinit,
        "K_forecasters": K,
        "y_test_agg": y_test_agg,
        # Pool results from the alpha_prime (AdaCOMA) run.
        "rep_res_individual_learners_pool": rep_res_adacoma_pool,
        "base_model_results": base_model_results,
        "adacoma_results": adacoma_agg_results,
        "adacoma_priors": {
            "adahedge": weights_prior_adacoma_ada,
            "hedge": weights_prior_adacoma_hedge,
        },
        "coma_results": coma_agg_results,
        "coma_aggregation_weights": {
            "adahedge": coma_weights_ada,
            "hedge": coma_weights_hedge,
        },
    }


def run_online_experiment_trial(config):
    """
    Run one trial for the seed in `config`: compute, then log, then summarise.

    Args:
        config: Experiment configuration; must contain "seed" and everything
            `_perform_single_seed_logic` and `log_single_seed_results` read.

    Returns:
        Dict with "status" (always "success" when this function returns),
        "seed" and "n_pred". Any exception raised by the computation or the
        logging step propagates to the caller unchanged.
    """
    print(f"--- Starting Trial for Seed: {config['seed']} ---")

    # Computation and logging are kept as two separate steps.
    trial_output = _perform_single_seed_logic(config)
    log_single_seed_results(config, trial_output)

    print(f"--- Finished Trial for Seed: {config['seed']} (Success) ---")

    summary = dict(
        status="success",
        seed=config["seed"],
        n_pred=trial_output.get("n_pred"),
    )
    return summary


if __name__ == "__main__":
    # Script entry point: parse CLI options, assemble the shared experiment
    # configuration (including the pool of base models), then run one trial
    # per seed and print a summary.
    parser = argparse.ArgumentParser(description="Online Experiment Runner")
    parser.add_argument(
        "--dataset", type=str, default="elec", choices=["elec", "aram"]
    )
    parser.add_argument(
        "--splitting_strategy",
        type=str,
        default="random_switch",
        choices=["random_switch"],
    )
    parser.add_argument("--num_seeds", type=int, default=10)
    parser.add_argument("--start_seed", type=int, default=0)
    parser.add_argument(
        "--log_dir", type=str, default="experiment_results/detailed_logs/"
    )
    parser.add_argument("--results_dir", type=str, default="experiment_results/")
    parser.add_argument("--data_dir", type=str, default="data/")
    parser.add_argument("--tinit", type=int, default=100)
    parser.add_argument("--alpha_final", type=float, default=0.1)
    parser.add_argument("--alpha_prime_ratio", type=float, default=0.8)
    parser.add_argument("--q_init", type=float, default=0.0)
    parser.add_argument("--nu", type=float, default=2.0)
    parser.add_argument("--eps", type=float, default=0.01)
    parser.add_argument("--eta_hedge", type=float, default=0.1)
    parser.add_argument(
        "--K_forecasters", type=int, default=10
    )  # Number of virtual forecasters in the pool
    parser.add_argument("--max_lag", type=int, default=2)

    # Random_switch splitting
    parser.add_argument("--random_switch_period", type=int, default=50)
    args = parser.parse_args()
    # Every CLI option becomes a config entry; derived entries are added below.
    config = vars(args)

    # All datasets will now have use_log_transform = False in their model configs.

    # Define Base ACI Model Configurations
    model_configs_list = []

    # Settings shared by every "online_sgd" model: one pass per update with a
    # constant learning rate and no tolerance-based stopping.
    sgd_base_params_shared = {
        "learning_rate": "constant",
        "max_iter": 1,
        "tol": None,
    }

    model_configs_list.append({
        "name": "Lasso_alpha0.1",
        "model_type": "online_sgd",
        "feature_set": "all",
        "params": {
            "sgd_params": {**sgd_base_params_shared, "eta0": 0.001, "penalty": "l1", "alpha": 0.1},
            "data_transform_params": {"use_log_transform": False}
        },
    })
    model_configs_list.append({
        "name": "Ridge_alpha0.1",
        "model_type": "online_sgd",
        "feature_set": "all",
        "params": {
            "sgd_params": {**sgd_base_params_shared, "eta0": 0.001, "penalty": "l2", "alpha": 0.1},
            "data_transform_params": {"use_log_transform": False}
        },
    })

    # Unpenalised SGD regressors, one per learning rate.
    sgd_etas_other = [0.001, 0.005]
    for eta_val_other in sgd_etas_other:
        model_configs_list.append({
            "name": f"SGD_eta{eta_val_other}",
            "model_type": "online_sgd",
            "feature_set": "all",
            "params": {
                "sgd_params": {**sgd_base_params_shared, "eta0": eta_val_other, "penalty": None, "alpha": 0.0},
                "data_transform_params": {"use_log_transform": False}
            }
        })

    # Rolling-window linear models; each is retrained every quarter window
    # (at least every step).
    rolling_windows_other = [50, 100]
    for rw_win_other in rolling_windows_other:
        model_configs_list.append({
            "name": f"RollingLM_w{rw_win_other}",
            "model_type": "rolling_lm",
            "feature_set": "all",
            "params": {
                "window_size": rw_win_other,
                "retrain_freq": max(1, rw_win_other // 4),
                "data_transform_params": {"use_log_transform": False}
            }
        })

    config["algo_configs"] = model_configs_list
    config["M_base_models"] = len(config["algo_configs"])

    # Assignment parameters
    config["assignment_params"] = {
        "random_switch": {"switch_period": args.random_switch_period},
    }

    print(
        "Effective configuration for runs:",
        {k: v for k, v in config.items() if k != "algo_configs"},
        "\nAlgo_configs_count:",
        len(config["algo_configs"]),
    )

    overall_start_time = time.time()
    successful_seeds = 0
    failed_seeds = 0

    for i in range(args.num_seeds):
        current_seed = args.start_seed + i
        # Deep copy so a trial cannot leak mutations into the next seed's config.
        run_config = deepcopy(config)
        run_config["seed"] = current_seed

        # NOTE: exceptions raised inside a trial are not caught here and abort
        # the whole sweep; the failure branch only handles a falsy or
        # non-"success" return value.
        result = run_online_experiment_trial(run_config)
        if result and result.get("status") == "success":
            successful_seeds += 1
        else:
            failed_seeds += 1
            status_msg = (
                result.get("status", "unknown_failure") if result else "unknown_failure"
            )
            # Report the failure instead of silently discarding its status.
            print(
                f"--- Trial for Seed: {current_seed} did not succeed "
                f"(status: {status_msg}) ---"
            )

    print("\n--- All Experiments Finished ---")
    print(f"Total script time: {time.time() - overall_start_time:.2f} seconds.")
    print(f"Successfully completed seeds: {successful_seeds}")
    print(f"Failed/Skipped seeds: {failed_seeds}")
    print(f"Detailed logs saved in: {os.path.abspath(config['log_dir'])}")