from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from bayes_models import posterior_predictive_check  # type: ignore
from bayes_models import (IIAModelHandcrafted, IIAModelLeaveOneOut, nll_error,
                          squared_relative_error)
from freq_models import goodness_of_fit_test  # type: ignore
from freq_models import (goodness_of_fit_test_handcrafted,
                         mcfadden_train_tye_tests)
from pandas import DataFrame, Series, concat
from scipy.stats import chi2  # type: ignore
from synthetic_data import synthetic_additive_context  # type: ignore
from synthetic_data import synthetic_IIA, synthetic_simple_context
from tqdm import tqdm  # type: ignore
from utils import load_handcrafted_survey  # type: ignore
from utils import load_two_phase_survey


def model_fit_ppc(
    full_counts: np.ndarray,
    leave_one_out_counts: List[np.ndarray],
    mcmc_params: Optional[Dict[str, Any]] = None,
) -> Tuple[float, Series]:
    """Fit ``IIAModelLeaveOneOut`` and run posterior predictive checks.

    Args:
        full_counts: Count array for the full option sets. Its first two
            dimensions size the model (``shape[0]`` is passed as the first
            model argument, ``shape[1]`` as ``max_options``).
        leave_one_out_counts: One count array per left-out option; entry
            ``i`` is exposed to the check under the key ``"rem_{i}"``.
        mcmc_params: Forwarded unchanged to ``model.fit``.

    Returns:
        A pair ``(aggregated, per_item)``: the result of
        ``posterior_predictive_check`` with ``aggregate_pvals=True``
        followed by the result with ``aggregate_pvals=False``. Both use
        ``squared_relative_error`` as the discrepancy function.
    """
    n_rows = full_counts.shape[0]
    iia_model = IIAModelLeaveOneOut(
        n_rows, max_options=full_counts.shape[1], hp_std_v=2
    )
    fitted = iia_model.fit(full_counts, leave_one_out_counts, mcmc_params)

    # Observed data keyed the way the check expects: the full table first,
    # then one entry per removed option.
    observed = {"full": full_counts}
    for idx, reduced_counts in enumerate(leave_one_out_counts):
        observed[f"rem_{idx}"] = reduced_counts

    def run_check(aggregate: bool):
        return posterior_predictive_check(
            n_rows,
            observed,
            fitted.posterior,
            squared_relative_error,
            aggregate_pvals=aggregate,
        )

    # Aggregated check runs first, then the per-item one.
    return run_check(True), run_check(False)


def p_values_comparison(
    full_counts: np.ndarray, leave_one_out_counts: List[np.ndarray]
) -> Tuple[DataFrame, Dict[str, float]]:
    """Run the three IIA tests on one data set and collect their p-values.

    The McFadden-Train-Tye tests (``mtt``), the goodness-of-fit test
    (``gft``) and the posterior predictive check (``ppc``) are all run on
    the same counts.

    Args:
        full_counts: Count array for the full option sets.
        leave_one_out_counts: Count arrays with one option removed each.

    Returns:
        A pair ``(pvals, agg_pvals)``:

        * ``pvals`` is a DataFrame built from the per-item PPC p-values
          (column ``"ppc-pval"``) with the columns ``"mtt-pval"``,
          ``"mtt-stat"``, ``"mtt-gdf"``, ``"gft-pval"``, ``"gft-stat"`` and
          ``"gft-dgf"`` added. Added columns are aligned on the index of
          the PPC series by pandas.
        * ``agg_pvals`` maps ``"min_*"`` to the smallest per-item p-value
          of each test and ``"agg_*"`` to one aggregated p-value per test.
    """
    optimize_kwargs = {
        "step_size": 0.005,
        "stopping_delta": 0.0001,
        "max_iter": 10000,
    }
    mtt_results = DataFrame(
        mcfadden_train_tye_tests(
            full_counts, leave_one_out_counts, optimize_kwargs=optimize_kwargs
        )
    )
    gft_results = DataFrame(
        goodness_of_fit_test(
            full_counts, leave_one_out_counts, optimize_kwargs=optimize_kwargs
        )
    )
    ppc_pval, ppc_pvals = model_fit_ppc(
        full_counts,
        leave_one_out_counts,
        mcmc_params={"draws": 3000, "tune": 2000},
    )

    # Keep a single MTT row per question: after sorting by "rem",
    # drop_duplicates retains the row with the smallest "rem" value.
    mtt_res_first = (
        mtt_results.sort_values("rem")
        .drop_duplicates("question")
        .set_index("question")
        .sort_index()
    )
    mtt_pvals = mtt_res_first["p-val"]
    mtt_stats = mtt_res_first["stat"]
    mtt_dgfs = mtt_res_first["dgf"]
    gft_pvals = gft_results["p-val"]
    gft_stats = gft_results["stat"]
    gft_dgfs = gft_results["dgf"]

    pvals = ppc_pvals.to_frame("ppc-pval")
    pvals["mtt-pval"] = mtt_pvals
    pvals["mtt-stat"] = mtt_stats
    # NOTE(review): "mtt-gdf" looks like a typo for "mtt-dgf" (compare
    # "gft-dgf" below); kept as-is because renaming would change the
    # output schema for downstream consumers.
    pvals["mtt-gdf"] = mtt_dgfs
    pvals["gft-pval"] = gft_pvals
    pvals["gft-stat"] = gft_stats
    pvals["gft-dgf"] = gft_dgfs

    # Aggregate the frequentist tests by summing statistics and degrees of
    # freedom and taking the chi-square survival function of the totals.
    gft_pval = chi2.sf(gft_stats.sum(), df=gft_dgfs.sum())
    mtt_pval = chi2.sf(mtt_stats.sum(), df=mtt_dgfs.sum())
    agg_pvals = {
        "min_ppc": ppc_pvals.min(),
        "min_mtt": mtt_pvals.min(),
        "min_gft": gft_pvals.min(),
        "agg_ppc": ppc_pval,
        "agg_mtt": mtt_pval,
        "agg_gft": gft_pval,
    }
    return pvals, agg_pvals


def goodness_of_fit_type1_error():
    """Collect goodness-of-fit statistics on data drawn by ``synthetic_IIA``.

    For every ``n`` in ``range(10, 50)`` two synthetic data sets are drawn
    with ``synthetic_IIA(n, 100, 2, max_options=4)`` and passed to
    ``goodness_of_fit_test``. The per-test ``"stat"``, ``"dgf"`` and
    ``"n_params"`` entries are summed, and two chi-square survival values
    are recorded: ``"p-val-gdf-u"`` with ``df=dgf`` and ``"p-val-gdf-l"``
    with ``df=dgf - n_params``.

    Returns:
        DataFrame with one row per simulated data set and the columns
        ``n``, ``stat``, ``dgf``, ``n_params``, ``p-val-gdf-u`` and
        ``p-val-gdf-l``.
    """
    fit_options = dict(step_size=0.01, stopping_delta=0.0001, max_iter=1000)
    records = []
    for size in tqdm(range(10, 50)):
        for _ in range(2):
            full, reduced = synthetic_IIA(size, 100, 2, max_options=4)
            per_test = list(
                goodness_of_fit_test(
                    full, reduced, optimize_kwargs=fit_options
                )
            )
            # Totals over all individual tests of this data set.
            total_stat = sum((res["stat"] for res in per_test), 0.0)
            total_dgf = sum(res["dgf"] for res in per_test)
            total_params = sum(res["n_params"] for res in per_test)
            records.append(
                {
                    "n": size,
                    "stat": total_stat,
                    "dgf": total_dgf,
                    "n_params": total_params,
                    "p-val-gdf-u": chi2.sf(total_stat, df=total_dgf),
                    "p-val-gdf-l": chi2.sf(
                        total_stat, df=total_dgf - total_params
                    ),
                }
            )
    return DataFrame(records)


def sweep_additive_context():
    """Sweep ``std_c`` over data from ``synthetic_additive_context``.

    For each of the 11 evenly spaced ``std_c`` values in ``[0, 1]``, ten
    data sets are generated and run through ``p_values_comparison``. Only
    the aggregated p-values are kept; each record is printed as soon as it
    is produced.

    Returns:
        DataFrame with one row per simulation: the ``std_c`` value plus the
        aggregated p-values.
    """
    records = []
    for context_std in tqdm(np.linspace(0, 1, 11)):
        for _ in range(10):
            full, reduced = synthetic_additive_context(
                30, 100, 2, context_std, max_options=4
            )
            aggregated = p_values_comparison(full, reduced)[1]
            record = {"std_c": context_std, **aggregated}
            records.append(record)
            print(record)
    return DataFrame(records)


def sweep_simple_context():
    """Sweep ``std_c`` over data from ``synthetic_simple_context``.

    For each of the 11 evenly spaced ``std_c`` values in ``[0, 1]``, ten
    data sets are generated and run through ``p_values_comparison``. Only
    the aggregated p-values are kept; each record is printed as soon as it
    is produced.

    Returns:
        DataFrame with one row per simulation: the ``std_c`` value plus the
        aggregated p-values.
    """
    records = []
    for context_std in tqdm(np.linspace(0, 1, 11)):
        for _ in range(10):
            full, reduced = synthetic_simple_context(
                30, 100, 2, context_std, max_options=4
            )
            aggregated = p_values_comparison(full, reduced)[1]
            record = {"std_c": context_std, **aggregated}
            records.append(record)
            print(record)
    return DataFrame(records)


def simulate_IIA(n_sims: int = 10) -> DataFrame:
    """Run ``p_values_comparison`` repeatedly on data from ``synthetic_IIA``.

    Each repetition draws a fresh data set with
    ``synthetic_IIA(30, 100, 2, max_options=4)``, keeps only the aggregated
    p-values and prints them as soon as they are available.

    Args:
        n_sims: Number of independent simulations to run. Defaults to 10.

    Returns:
        DataFrame with one row of aggregated p-values per simulation.
    """
    rows = []
    # The loop index is only a repetition counter; nothing is swept here.
    for _ in tqdm(range(n_sims)):
        m4, m3 = synthetic_IIA(30, 100, 2, max_options=4)
        _, agg_pvals = p_values_comparison(m4, m3)
        rows.append(agg_pvals)
        print(rows[-1])
    return DataFrame(rows)


def sweep_additive_context_store_all(std_cs, n_sims) -> DataFrame:
    """Sweep ``std_c`` and keep the per-item p-value tables of every run.

    Args:
        std_cs: Iterable of ``std_c`` values passed to
            ``synthetic_additive_context``.
        n_sims: Number of simulations per ``std_c`` value.

    Returns:
        The per-item tables returned by ``p_values_comparison``, each
        tagged with a ``"std_c"`` column and with its index reset into a
        column, concatenated under a fresh integer index. The aggregated
        p-values are discarded.
    """

    def one_run(context_std) -> DataFrame:
        full, reduced = synthetic_additive_context(
            30, 100, 2, context_std, max_options=4
        )
        per_item = p_values_comparison(full, reduced)[0]
        per_item["std_c"] = context_std
        return per_item.reset_index()

    frames = []
    for context_std in tqdm(std_cs):
        frames.extend(one_run(context_std) for _ in range(n_sims))
    return concat(frames, ignore_index=True)


def sweep_simple_context_store_all(std_cs, n_sims) -> DataFrame:
    """Sweep ``std_c`` and keep the per-item p-value tables of every run.

    Args:
        std_cs: Iterable of ``std_c`` values passed to
            ``synthetic_simple_context``.
        n_sims: Number of simulations per ``std_c`` value.

    Returns:
        The per-item tables returned by ``p_values_comparison``, each
        tagged with a ``"std_c"`` column and with its index reset into a
        column, concatenated under a fresh integer index. The aggregated
        p-values are discarded.
    """

    def one_run(context_std) -> DataFrame:
        full, reduced = synthetic_simple_context(
            30, 100, 2, context_std, max_options=4
        )
        per_item = p_values_comparison(full, reduced)[0]
        per_item["std_c"] = context_std
        return per_item.reset_index()

    frames = []
    for context_std in tqdm(std_cs):
        frames.extend(one_run(context_std) for _ in range(n_sims))
    return concat(frames, ignore_index=True)


def random_survey_analysis(
    first_phase_fpath: str, second_phase_fpath: str, read_kwargs=None
) -> Tuple[DataFrame, Dict[str, float]]:
    """Run ``p_values_comparison`` on a two-phase survey loaded from disk.

    Args:
        first_phase_fpath: Path forwarded to ``load_two_phase_survey``.
        second_phase_fpath: Path forwarded to ``load_two_phase_survey``.
        read_kwargs: Forwarded unchanged to ``load_two_phase_survey``.

    Returns:
        ``(stats, agg_pvals)`` as produced by ``p_values_comparison``, with
        the rows of ``stats`` relabelled with the index of the first table
        returned by the loader.
    """
    full_table, reduced_tables = load_two_phase_survey(
        first_phase_fpath, second_phase_fpath, read_kwargs
    )
    # The tests operate on integer count arrays.
    full_counts = full_table.to_numpy().astype(int)
    reduced_counts = [
        table.to_numpy().astype(int) for table in reduced_tables
    ]
    per_item, aggregated = p_values_comparison(full_counts, reduced_counts)
    per_item.index = full_table.index
    return per_item, aggregated


def handcrafted_survey_analysis(
    survey_fpath: str, read_kwargs=None
) -> Tuple[DataFrame, Dict[str, float]]:
    """Run goodness-of-fit and PPC tests on a handcrafted survey.

    Args:
        survey_fpath: Path forwarded to ``load_handcrafted_survey``.
        read_kwargs: Forwarded unchanged to ``load_handcrafted_survey``.

    Returns:
        ``(results, agg_pvals)``: ``results`` is the goodness-of-fit table,
        indexed like the first loaded table and extended with a
        ``"ppc-pval"`` column. ``agg_pvals`` holds the smallest per-item
        p-value (``"min_ppc"``, ``"min_gft"``) and one aggregated p-value
        (``"agg_ppc"``, ``"agg_gft"``) for each test.
    """
    fit_options = dict(step_size=0.005, stopping_delta=0.0001, max_iter=10000)
    survey_a, survey_b = load_handcrafted_survey(survey_fpath, read_kwargs)

    def as_int_counts(table):
        return table.to_numpy().astype(int)

    # Frequentist test first, on the raw (uncast) arrays.
    results = DataFrame(
        goodness_of_fit_test_handcrafted(
            survey_a.to_numpy(),
            survey_b.to_numpy(),
            optimize_kwargs=fit_options,
        )
    )
    results.index = survey_a.index

    # Bayesian model; only the middle element of fit()'s triple is used.
    n_rows = survey_a.shape[0]
    iia_model = IIAModelHandcrafted(
        n_rows, options=survey_a.shape[1] + 1, hp_std_v=2
    )
    _, fitted, _ = iia_model.fit(
        as_int_counts(survey_a),
        as_int_counts(survey_b),
        draws=3000,
        tune=2000,
    )
    observed = {"A": as_int_counts(survey_a), "B": as_int_counts(survey_b)}

    def run_check(aggregate: bool):
        return posterior_predictive_check(
            n_rows,
            observed,
            fitted.posterior,
            squared_relative_error,
            aggregate_pvals=aggregate,
        )

    # Aggregated check runs first, then the per-item one.
    overall_ppc = run_check(True)
    per_item_ppc = run_check(False)

    # Positional assignment: the PPC values are attached without index
    # alignment.
    results["ppc-pval"] = per_item_ppc.to_numpy()

    total_stat = results["stat"].sum()
    total_dgf = results["dgf"].sum()
    aggregated = {
        "min_ppc": per_item_ppc.min(),
        "min_gft": results["p-val"].min(),
        "agg_ppc": overall_ppc,
        "agg_gft": chi2.sf(total_stat, df=total_dgf),
    }
    return results, aggregated
