import contextlib
import csv
import os

import numpy as np

import coma_utils as cu  # Assuming coma_utils is accessible in this context

# --- DataTransformer Class ---
class DataTransformer:
    """Select feature columns and optionally apply a natural-log transform.

    Parameters
    ----------
    use_log_transform : bool
        When True, features and targets are passed through ``np.log``, but
        only if every value is strictly positive; otherwise the input is
        returned unchanged.
    num_lags : int
        Used only when ``feature_indices`` is None: if positive, keep the
        leading ``num_lags`` columns of X (or all columns, if there are fewer).
    feature_indices : sequence of int or None
        Explicit columns to keep; takes precedence over ``num_lags``. If the
        indices do not fit X, every column is kept instead.

    NOTE(review): ``transform_target`` returns Y unlogged when any value is
    <= 0, but ``inverse_transform_prediction`` always exponentiates when
    ``use_log_transform`` is set -- confirm callers never feed such targets.
    """

    def __init__(self, use_log_transform=False, num_lags=0, feature_indices=None):
        self.use_log_transform = use_log_transform
        self.num_lags = num_lags
        self.feature_indices = feature_indices
        # Set once the configured feature_indices turn out not to fit X.
        self._feature_warning_logged = False

    def _select_features(self, X):
        """Return the configured columns of X.

        A 1-D X is treated as a single row (shape ``(1, n)``) whenever a
        selection is configured; with no selection X is returned untouched.
        """
        if self.feature_indices is None and not (self.num_lags > 0):
            return X

        matrix = X.reshape(1, -1) if X.ndim == 1 else X

        if self.feature_indices is None:
            # Lag mode: leading columns only, capped at the available width.
            width = min(self.num_lags, matrix.shape[1])
            return matrix[:, :width]

        if matrix.shape[1] > max(self.feature_indices):
            return matrix[:, self.feature_indices]

        # Indices out of range: fall back to every column, remembering it.
        self._feature_warning_logged = True
        return matrix

    def _maybe_log(self, values):
        """Return ``np.log(values)`` in log mode when all values are > 0, else values."""
        if self.use_log_transform and not np.any(values <= 0):
            return np.log(values)
        return values

    def transform_features(self, X):
        """Select features from X, then log them if log mode applies."""
        return self._maybe_log(self._select_features(X))

    def transform_target(self, Y):
        """Log Y if log mode is on and every value is strictly positive."""
        return self._maybe_log(Y)

    def inverse_transform_prediction(self, Y_pred_transformed):
        """Undo the log transform (``np.exp``) when log mode is enabled."""
        return np.exp(Y_pred_transformed) if self.use_log_transform else Y_pred_transformed

# --- End DataTransformer Class ---

def create_forecaster_results_list_from_reported(reported_results_run, num_forecasters):
    """
    Transforms reported results for all forecasters into a list of individual forecaster results.

    Args:
        reported_results_run: mapping with ``reported_loss`` (shape
            ``(n_pred, num_forecasters)``), ``reported_AdaptErr`` (indexed
            ``[t, j]``) and ``reported_piAdapt`` (indexed ``[t, j, :]``).
        num_forecasters: expected number of forecaster columns.

    Returns:
        A list with one dict per forecaster holding its ``piAdapt``,
        ``AdaptErr`` and ``loss`` series (slices of the inputs), or ``[]``
        when ``reported_loss`` is missing, not 2-D, or has the wrong number
        of forecaster columns.
    """
    if "reported_loss" not in reported_results_run:
        return []
    loss = np.asarray(reported_results_run["reported_loss"])
    # A non-2-D loss array is malformed input; treat it like a column mismatch
    # instead of letting ``shape[1]`` raise IndexError.
    if loss.ndim != 2 or loss.shape[1] != num_forecasters:
        return []

    n_pred_local = loss.shape[0]
    pi_adapt = reported_results_run["reported_piAdapt"]
    adapt_err = reported_results_run["reported_AdaptErr"]

    return [
        {
            "piAdapt": pi_adapt[:n_pred_local, j, :],
            "AdaptErr": adapt_err[:n_pred_local, j],
            "loss": loss[:n_pred_local, j],
        }
        for j in range(num_forecasters)
    ]

def generate_reported_results_from_assignment(
    algo_results_list, n_pred, num_forecasters, assignment_matrix
):
    """
    Generates aggregated results for a pool of forecasters based on an assignment matrix.

    Args:
        algo_results_list: one entry per base algorithm, either None or a
            dict with ``piAdapt`` (rows of length 2, indexed ``[t]`` and
            ``[t, 0]``) and ``AdaptErr`` (indexed ``[t]``).
        n_pred: number of prediction steps.
        num_forecasters: number of forecasters in the pool.
        assignment_matrix: array-like of shape ``(n_pred, num_forecasters)``;
            entry ``[t, j]`` is the index of the base algorithm forecaster
            ``j`` reports at step ``t``.

    Returns:
        Dict with ``reported_piAdapt`` ``(n_pred, num_forecasters, 2)``,
        ``reported_AdaptErr`` / ``reported_loss`` / ``reported_coverage_ind``
        ``(n_pred, num_forecasters)`` and ``assignment_history``. Cells with
        no usable algorithm output stay NaN (loss: +inf). If the matrix has
        the wrong shape or out-of-range indices, every output is NaN/+inf and
        ``assignment_history`` is an empty array.
    """
    assignment_matrix = np.asarray(assignment_matrix)
    num_base_algorithms = len(algo_results_list)

    shape_ok = (
        assignment_matrix.ndim == 2
        and assignment_matrix.shape[0] == n_pred
        and assignment_matrix.shape[1] == num_forecasters
    )
    # np.max / np.min raise on zero-size arrays, so only range-check a
    # non-empty matrix; an empty one (n_pred == 0 or no forecasters) is valid.
    indices_ok = shape_ok and (
        assignment_matrix.size == 0
        or (
            np.max(assignment_matrix) < num_base_algorithms
            and np.min(assignment_matrix) >= 0
        )
    )
    if not indices_ok:
        return {
            "reported_piAdapt": np.full((n_pred, num_forecasters, 2), np.nan),
            "reported_AdaptErr": np.full((n_pred, num_forecasters), np.nan),
            "reported_loss": np.full((n_pred, num_forecasters), np.inf),
            "reported_coverage_ind": np.full((n_pred, num_forecasters), np.nan),
            "assignment_history": np.array([]),
        }

    pi_hist = np.full((n_pred, num_forecasters, 2), np.nan)
    err_hist = np.full((n_pred, num_forecasters), np.nan)
    loss_hist = np.full((n_pred, num_forecasters), np.inf)

    for idx, assign_vec in enumerate(assignment_matrix):
        for j_fc, algo_idx in enumerate(assign_vec):
            # algo_idx is in [0, num_base_algorithms) thanks to the check above.
            algo_res = algo_results_list[algo_idx]
            if algo_res is None:
                continue
            pi_algo = algo_res["piAdapt"]
            # Skip steps the algorithm did not produce or left undefined (NaN).
            if pi_algo.shape[0] <= idx or np.isnan(pi_algo[idx, 0]):
                continue
            pi_t = pi_algo[idx]
            pi_hist[idx, j_fc, :] = pi_t
            err_hist[idx, j_fc] = algo_res["AdaptErr"][idx]
            loss_hist[idx, j_fc] = cu.loss_fun(pi_t)

    return {
        "reported_piAdapt": pi_hist,
        "reported_AdaptErr": err_hist,
        "reported_loss": loss_hist,
        "reported_coverage_ind": 1.0 - err_hist,
        "assignment_history": assignment_matrix,
    }

def _cell_2d(arr, t, k):
    """Return ``arr[t, k]`` if *arr* is 2-D and ``(t, k)`` is in bounds, else NaN."""
    if arr.ndim == 2 and arr.shape[0] > t and arr.shape[1] > k:
        return arr[t, k]
    return np.nan


def _finite_mean(arr):
    """Return the mean of the finite entries of *arr*, or NaN when there are none.

    Unlike ``np.mean`` on an empty selection, this emits no RuntimeWarning.
    """
    finite = arr[np.isfinite(arr)]
    return np.mean(finite) if finite.size > 0 else np.nan


def _average_expected_metrics(method_results):
    """Return ``(avg_exp_miscov, avg_exp_len)`` for one method's results.

    Averages run over the finite entries of ``exp_cov`` / ``exp_len``; a value
    is NaN when *method_results* is falsy (None or empty) or the series has
    no finite entries.
    """
    if not method_results:
        return np.nan, np.nan
    empty = np.array([])
    avg_exp_miscov = 1.0 - _finite_mean(method_results.get("exp_cov", empty))
    avg_exp_len = _finite_mean(method_results.get("exp_len", empty))
    return avg_exp_miscov, avg_exp_len


def _method_row(base_row, prefix, weight_stem, weights, res, has_selection, t, K):
    """Build the CSV row (a dict keyed by column name) for one method at step *t*.

    Args:
        base_row: shared per-step columns (time, seed, Y_true, fc_* perf).
        prefix: column-name prefix, e.g. ``"adacoma_ada"``.
        weight_stem: ``"prior_w"`` (AdaCoMA priors) or ``"w"`` (CoMA weights).
        weights: per-step weight matrix, read as ``weights[t, k]``.
        res: the method's results dict (possibly empty).
        has_selection: whether to emit ``sel_prob_*`` / ``sel_idx`` columns.

    Columns not set here (e.g. metrics when ``res`` has no entry for step *t*)
    are absent, so ``row.get(column)`` yields None and the CSV cell is empty.
    """
    row = dict(base_row)
    for k in range(K):
        row[f"{prefix}_{weight_stem}_{k}"] = _cell_2d(weights, t, k)

    if has_selection:
        p_weights = res.get("p_weights")
        if p_weights is None:
            p_weights = np.empty((0, K))
        for k in range(K):
            row[f"{prefix}_sel_prob_{k}"] = _cell_2d(p_weights, t, k)
        sel_hist = res.get("idx_selected_hist")
        has_sel = sel_hist is not None and sel_hist.shape[0] > t
        row[f"{prefix}_sel_idx"] = sel_hist[t] if has_sel else -1

    cov_maj = res.get("cov_maj")
    if cov_maj is not None and cov_maj.shape[0] > t:
        row[f"{prefix}_miscov"] = 1.0 - cov_maj[t]
        row[f"{prefix}_len"] = res["loss_maj"][t]
        row[f"{prefix}_exp_cov"] = res["exp_cov"][t]
        row[f"{prefix}_exp_len"] = res["exp_len"][t]
    return row


def log_single_seed_results(config, data_for_logging): # Renamed from _log_single_seed_results
    """
    Logs the results from a single seed's scenarios to CSV files.

    Writes one detailed CSV per method (AdaCoMA with AdaHedge/Hedge, CoMA with
    AdaHedge/Hedge) plus one for base models into ``config["log_dir"]``, and a
    per-method summary CSV into ``<config["results_dir"]>/summary_logs``.
    Base-model averages are also printed.

    Args:
        config: mapping with ``log_dir``, ``results_dir``, ``dataset`` and
            ``splitting_strategy`` (``seed`` is only read when
            *data_for_logging* is None).
        data_for_logging: dict produced by the seed run, or None (nothing is
            written in that case).
    """
    if data_for_logging is None:
        print(f"No data to log for seed {config.get('seed', 'UNKNOWN')}.")
        return

    seed = data_for_logging["seed"]
    n_pred = data_for_logging["n_pred"]
    tinit = data_for_logging["tinit"]
    K = data_for_logging["K_forecasters"]
    y_test_agg = data_for_logging["y_test_agg"]
    rep_pool = data_for_logging["rep_res_individual_learners_pool"]

    base_model_results = data_for_logging.get("base_model_results", {})
    if base_model_results:
        print(f"\nBase Model Performance for seed {seed}:")
        for model_name, model_perf in base_model_results.items():
            avg_miscov = 1.0 - _finite_mean(model_perf["cov_maj"])
            avg_len = _finite_mean(model_perf["loss_maj"])
            print(
                f"{model_name}: Avg Miscoverage = {avg_miscov:.4f}, Avg Length = {avg_len:.4f}"
            )

    adacoma_results = data_for_logging["adacoma_results"]
    adacoma_priors = data_for_logging["adacoma_priors"]
    coma_results = data_for_logging["coma_results"]
    coma_weights = data_for_logging["coma_aggregation_weights"]

    # (column prefix == writer key, weight-column stem, per-step weights,
    #  results dict, emits selection columns). A results entry stored as None
    #  is treated like a missing one.
    method_specs = [
        ("adacoma_ada", "prior_w", adacoma_priors["adahedge"],
         adacoma_results.get("adacoma_ada") or {}, True),
        ("adacoma_hedge", "prior_w", adacoma_priors["hedge"],
         adacoma_results.get("adacoma_hedge") or {}, True),
        ("coma_ada", "w", coma_weights["adahedge"],
         coma_results.get("coma_ada") or {}, False),
        ("coma_hedge", "w", coma_weights["hedge"],
         coma_results.get("coma_hedge") or {}, False),
    ]
    summary_names = {
        "adacoma_ada": "adacoma_adahedge",
        "adacoma_hedge": "adacoma_hedge",
        "coma_ada": "coma_adahedge",
        "coma_hedge": "coma_hedge",
    }

    detailed_log_dir = config["log_dir"]
    summary_log_dir = os.path.join(config["results_dir"], "summary_logs")
    os.makedirs(detailed_log_dir, exist_ok=True)
    os.makedirs(summary_log_dir, exist_ok=True)

    base_fname = f"{config['dataset']}_{config['splitting_strategy']}_seed{seed}"
    fpaths = {
        "adacoma_ada": os.path.join(
            detailed_log_dir, f"adacoma_adahedge_{base_fname}.csv"
        ),
        "adacoma_hedge": os.path.join(
            detailed_log_dir, f"adacoma_hedge_{base_fname}.csv"
        ),
        "coma_ada": os.path.join(detailed_log_dir, f"coma_adahedge_{base_fname}.csv"),
        "coma_hedge": os.path.join(detailed_log_dir, f"coma_hedge_{base_fname}.csv"),
        "base_models": os.path.join(detailed_log_dir, f"base_models_{base_fname}.csv"),
        "summary": os.path.join(summary_log_dir, f"summary_{base_fname}.csv"),
    }

    header_basic = ["time_step", "seed", "dataset", "splitting_strategy", "Y_true"]
    header_fc_perf = [f"fc_{i}_len" for i in range(K)] + [
        f"fc_{i}_cov" for i in range(K)
    ]
    headers = {}
    for prefix, weight_stem, _, _, has_selection in method_specs:
        cols = header_basic + header_fc_perf + [
            f"{prefix}_{weight_stem}_{i}" for i in range(K)
        ]
        if has_selection:
            cols += [f"{prefix}_sel_prob_{i}" for i in range(K)] + [f"{prefix}_sel_idx"]
        cols += [f"{prefix}_{m}" for m in ("miscov", "len", "exp_cov", "exp_len")]
        headers[prefix] = cols
    headers["base_models"] = [
        "model_name",
        "time_step",
        "seed",
        "dataset",
        "splitting_strategy",
        "Y_true",
        "miscov",
        "len",
        "raw_pred",
    ]

    # ExitStack guarantees every detailed CSV is closed even if writing fails.
    with contextlib.ExitStack() as stack:
        writers = {}
        for key, header in headers.items():
            f = stack.enter_context(open(fpaths[key], "w", newline=""))
            writers[key] = csv.writer(f)
            writers[key].writerow(header)

        if base_model_results:
            for t in range(n_pred):
                y_true = y_test_agg[t] if t < len(y_test_agg) else np.nan
                for model_name, model_perf in base_model_results.items():
                    cov = model_perf["cov_maj"]
                    if cov.shape[0] <= t:
                        continue
                    loss = model_perf["loss_maj"]
                    raw = model_perf.get("raw_pred")
                    row_base = {
                        "model_name": model_name,
                        "time_step": t + tinit,
                        "seed": seed,
                        "dataset": config["dataset"],
                        "splitting_strategy": config["splitting_strategy"],
                        "Y_true": y_true,
                        "miscov": 1.0 - cov[t] if np.isfinite(cov[t]) else np.nan,
                        "len": loss[t] if np.isfinite(loss[t]) else np.nan,
                        "raw_pred": (
                            raw[t] if raw is not None and raw.shape[0] > t else np.nan
                        ),
                    }
                    writers["base_models"].writerow(
                        [row_base.get(h) for h in headers["base_models"]]
                    )

        for t in range(n_pred):
            base_row_data = {
                "time_step": t + tinit,
                "seed": seed,
                "dataset": config["dataset"],
                "splitting_strategy": config["splitting_strategy"],
                "Y_true": y_test_agg[t] if t < len(y_test_agg) else np.nan,
            }
            for k_fc in range(K):
                base_row_data[f"fc_{k_fc}_len"] = _cell_2d(
                    rep_pool["reported_loss"], t, k_fc
                )
                base_row_data[f"fc_{k_fc}_cov"] = _cell_2d(
                    rep_pool["reported_coverage_ind"], t, k_fc
                )

            for prefix, weight_stem, weights, res, has_selection in method_specs:
                row = _method_row(
                    base_row_data, prefix, weight_stem, weights, res, has_selection, t, K
                )
                writers[prefix].writerow([row.get(h) for h in headers[prefix]])

    summary_header = [
        "config_name",
        "avg_exp_miscov",
        "avg_exp_len",
    ]
    summary_data_rows = [
        [summary_names[prefix], *_average_expected_metrics(res)]
        for prefix, _, _, res, _ in method_specs
    ]

    with open(fpaths["summary"], "w", newline="") as f_summary:
        writer_summary = csv.writer(f_summary)
        writer_summary.writerow(summary_header)
        writer_summary.writerows(summary_data_rows)

    print(
        f"Logging completed for seed {seed}. Detailed: {detailed_log_dir}, Summary: {fpaths['summary']}"
    )