from collections.abc import Callable
from itertools import product
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import pymc as pm  # type: ignore
from utils import range_without as _range_without  # type: ignore
from xarray import Dataset


class IIAModel:
    """Softmax choice model with one latent utility per (question, option).

    Generative structure built in ``__init__``:
    ``std_v ~ HalfNormal(hp_std_v)``, ``v ~ Normal(0, std_v)`` over
    (question, option), and ``p = softmax(v)`` taken over the option axis.
    ``fit`` attaches multinomial observations and runs sampling.
    """

    def __init__(self, n_questions: int, max_options: int, hp_std_v: float):
        self.n_questions = n_questions
        self.max_options = max_options
        # Observed node created by ``fit``; stays ``None`` until then.
        self.likelihood: Optional[pm.Distribution] = None

        coords = {
            "question": np.arange(n_questions),
            "option": np.arange(max_options),
        }
        dims = ("question", "option")

        with pm.Model(coords=coords) as model:
            std_v = pm.HalfNormal("std_v", sigma=hp_std_v)
            utilities = pm.Normal("v", mu=0, sigma=std_v, dims=dims)
            # Normalise over options (axis=1) so each question's row sums to 1.
            probabilities = pm.math.softmax(utilities, axis=1)
            self.p = pm.Deterministic("p", probabilities, dims=dims)

        self.model: pm.Model = model

    def fit(
        self, data: np.ndarray, chain=4, tune=1000, draws=1000
    ) -> Tuple[Any, Any, Any]:
        """Attach observed counts, then draw prior, posterior and PPC samples.

        Args:
            data: Count matrix matching the ("question", "option") dims. Its
                row sums are used as the multinomial ``n`` of each question.
            chain: Number of MCMC chains, forwarded as ``chains``.
            tune: Number of tuning steps per chain.
            draws: Number of posterior draws per chain.

        Returns:
            ``(prior_pred, trace, post_pred)`` from
            ``pm.sample_prior_predictive``, ``pm.sample`` and
            ``pm.sample_posterior_predictive``.

        NOTE(review): each call registers a variable named ``"counts"`` in the
        same model, so calling ``fit`` twice on one instance is expected to be
        rejected by PyMC. Create a new instance per fit.
        """
        totals = data.sum(axis=1)

        with self.model:
            self.likelihood = pm.Multinomial(
                "counts",
                p=self.p,
                n=totals,
                dims=("question", "option"),
                observed=data,
            )

            prior_pred = pm.sample_prior_predictive()
            trace = pm.sample(chains=chain, draws=draws, tune=tune)
            post_pred = pm.sample_posterior_predictive(trace=trace)

        return prior_pred, trace, post_pred


class IIAModelHandcrafted:
    """IIA model for two hand-picked reduced choice sets, "A" and "B".

    A single utility matrix ``v`` over (question, option) is shared.
    ``p_A`` is the softmax over the columns ``_range_without(options, 3)``
    and ``p_B`` over ``_range_without(options, 2)``. Presumably these are
    all options except index 3 and index 2 respectively; see
    ``utils.range_without``.
    """

    def __init__(self, n_questions: int, options: int, hp_std_v: float):
        self.n_questions = n_questions
        self.options = options
        # (counts_A, counts_B) observed nodes, created by ``fit``.
        self.likelihoods: Optional[
            Tuple[pm.Distribution, pm.Distribution]
        ] = None

        # Both reduced sets drop exactly one option.
        n_reduced = options - 1
        coords = {
            "question": np.arange(n_questions),
            "option": np.arange(options),
            "option_A": np.arange(n_reduced),
            "option_B": np.arange(n_reduced),
        }

        with pm.Model(coords=coords) as model:
            std_v = pm.HalfNormal("std_v", sigma=hp_std_v)
            v = pm.Normal("v", mu=0, sigma=std_v, dims=("question", "option"))

            cols_a = _range_without(options, 3)
            prob_a = pm.Deterministic(
                "p_A",
                pm.math.softmax(v[:, cols_a], axis=1),
                dims=("question", "option_A"),
            )
            cols_b = _range_without(options, 2)
            prob_b = pm.Deterministic(
                "p_B",
                pm.math.softmax(v[:, cols_b], axis=1),
                dims=("question", "option_B"),
            )
            self.ps = {"p_A": prob_a, "p_B": prob_b}

        self.model: pm.Model = model

    def fit(
        self,
        res_A: np.ndarray,
        res_B: np.ndarray,
        chain: int = 4,
        tune: int = 1000,
        draws: int = 1000,
    ) -> Tuple[Any, Any, Any]:
        """Attach counts for sets A and B, then sample.

        Args:
            res_A: Counts matching ("question", "option_A"). The row sums
                give the multinomial ``n`` for each question.
            res_B: Counts matching ("question", "option_B").
            chain: Number of MCMC chains, forwarded as ``chains``.
            tune: Number of tuning steps per chain.
            draws: Number of posterior draws per chain.

        Returns:
            ``(prior_pred, trace, post_pred)``.
        """
        totals_a = res_A.sum(axis=1)
        totals_b = res_B.sum(axis=1)

        with self.model:
            counts_a = pm.Multinomial(
                "counts_A",
                p=self.ps["p_A"],
                n=totals_a,
                dims=("question", "option_A"),
                observed=res_A,
            )
            counts_b = pm.Multinomial(
                "counts_B",
                p=self.ps["p_B"],
                n=totals_b,
                dims=("question", "option_B"),
                observed=res_B,
            )
            self.likelihoods = (counts_a, counts_b)

            prior_pred = pm.sample_prior_predictive()
            trace = pm.sample(chains=chain, draws=draws, tune=tune)
            post_pred = pm.sample_posterior_predictive(trace=trace)

        return prior_pred, trace, post_pred


class IIAModelLeaveOneOut:
    """Joint IIA model of a full choice set and its leave-one-out subsets.

    A single utility matrix ``v`` over (question, option_full) is shared by
    every condition:

    * ``p_full`` is the softmax of ``v`` over all ``max_options`` options.
    * ``p_rem_i`` (one per ``i`` in ``range(max_options)``) is the softmax
      of the columns ``_range_without(max_options, i)``. Presumably this is
      every option except ``i``; see ``utils.range_without``.

    A softmax over a subset of logits equals the full softmax renormalised
    over that subset. The model therefore encodes independence of
    irrelevant alternatives.
    """

    def __init__(self, n_questions: int, max_options: int, hp_std_v: float):
        self.n_questions = n_questions
        self.max_options = max_options
        # Observed nodes created by ``fit``; ``None`` until then.
        self.likelihoods: Optional[Dict[str, pm.Distribution]] = None

        with pm.Model(
            coords={
                "question": np.arange(n_questions),
                "option_full": np.arange(max_options),
                # Each leave-one-out condition has one option fewer.
                **{
                    f"option_rem_{i}": np.arange(max_options - 1)
                    for i in range(max_options)
                },
            }
        ) as model:

            std_v = pm.HalfNormal("std_v", sigma=hp_std_v)
            v = pm.Normal(
                "v", mu=0, sigma=std_v, dims=["question", "option_full"]
            )

            self.ps = {
                "p_full": pm.Deterministic(
                    "p_full",
                    pm.math.softmax(v, axis=1),
                    dims=["question", "option_full"],
                ),
                **{
                    # Select columns using max_options, not a fixed count, so
                    # the width matches the ``option_rem_{i}`` coordinate.
                    f"p_rem_{i}": pm.Deterministic(
                        f"p_rem_{i}",
                        pm.math.softmax(
                            v[:, _range_without(max_options, i)], axis=1
                        ),
                        dims=["question", f"option_rem_{i}"],
                    )
                    for i in range(max_options)
                },
            }
        self.model: pm.Model = model

    def fit(
        self,
        res_full: np.ndarray,
        res_rems: List[np.ndarray],
        mcmc_params: Optional[Dict[str, Any]] = None,
    ):
        """Attach the observed counts and draw posterior samples.

        Args:
            res_full: Counts for the full choice set, matching the
                ("question", "option_full") dims. Row sums give the
                multinomial ``n`` for each question.
            res_rems: ``res_rems[i]`` holds the counts for the condition that
                uses ``p_rem_{i}``, matching ("question", "option_rem_{i}").
                It may have at most ``max_options`` entries, because only
                that many ``p_rem_{i}`` variables exist.
            mcmc_params: Keyword arguments forwarded verbatim to
                ``pm.sample``. Defaults to none.

        Returns:
            The result of ``pm.sample``. Unlike the other models' ``fit``,
            this method draws no prior or posterior predictive samples.
        """
        if mcmc_params is None:
            mcmc_params = {}
        n_participants_full = res_full.sum(axis=1)
        n_participants_rem = [r.sum(axis=1) for r in res_rems]

        with self.model:

            self.likelihoods = {
                "likelihood_full": pm.Multinomial(
                    "counts_full",
                    n=n_participants_full,
                    p=self.ps["p_full"],
                    observed=res_full,
                    dims=("question", "option_full"),
                ),
                **{
                    f"counts_{i}": pm.Multinomial(
                        f"counts_rem_{i}",
                        n=n_rem,
                        p=self.ps[f"p_rem_{i}"],
                        observed=res_rem,
                        dims=("question", f"option_rem_{i}"),
                    )
                    for i, (res_rem, n_rem) in enumerate(
                        zip(res_rems, n_participants_rem)
                    )
                },
            }

            trace = pm.sample(**mcmc_params)

        return trace


class AdditiveNoiseModelLeaveOneOut:
    """Leave-one-out choice model with per-condition additive utility noise.

    A shared utility matrix ``v`` over (question, option_full) drives all
    conditions:

    * ``p_full`` is the softmax of ``v`` over all ``max_options`` options.
    * ``p_rem_i`` is the softmax of ``v[:, _range_without(max_options, i)]
      + c_rem_i``. Here ``c_rem_i ~ Normal(0, std_c)`` is an independent
      noise matrix for condition ``i``, and ``std_c ~ HalfNormal(hp_std_c)``.

    The context noise lets the subset probabilities deviate from a simple
    renormalisation of ``p_full``, i.e. it relaxes the IIA assumption.
    """

    def __init__(
        self,
        n_questions: int,
        max_options: int,
        hp_std_v: float,
        hp_std_c: float,
    ):
        self.n_questions = n_questions
        self.max_options = max_options
        # Observed nodes created by ``fit``; ``None`` until then.
        self.likelihoods: Optional[Dict[str, pm.Distribution]] = None

        with pm.Model(
            coords={
                "question": np.arange(n_questions),
                "option_full": np.arange(max_options),
                # Each leave-one-out condition has one option fewer.
                **{
                    f"option_rem_{i}": np.arange(max_options - 1)
                    for i in range(max_options)
                },
            }
        ) as model:

            std_v = pm.HalfNormal("std_v", sigma=hp_std_v)
            std_c = pm.HalfNormal("std_c", sigma=hp_std_c)

            v = pm.Normal(
                "v", mu=0, sigma=std_v, dims=["question", "option_full"]
            )

            # One independent noise matrix per leave-one-out condition.
            context_noises = [
                pm.Normal(
                    f"c_rem_{i}",
                    mu=0,
                    sigma=std_c,
                    dims=["question", f"option_rem_{i}"],
                )
                for i in range(max_options)
            ]

            self.ps = {
                "p_full": pm.Deterministic(
                    "p_full",
                    pm.math.softmax(v, axis=1),
                    dims=["question", "option_full"],
                ),
                **{
                    f"p_rem_{i}": pm.Deterministic(
                        f"p_rem_{i}",
                        pm.math.softmax(
                            # Use max_options, not a fixed count, so the
                            # selected width matches ``option_rem_{i}`` and
                            # the shape of ``context_noises[i]``.
                            v[:, _range_without(max_options, i)]
                            + context_noises[i],
                            axis=1,
                        ),
                        dims=["question", f"option_rem_{i}"],
                    )
                    for i in range(max_options)
                },
            }
        self.model: pm.Model = model

    def fit(
        self,
        res_full: np.ndarray,
        res_rems: List[np.ndarray],
        chain: int = 4,
        tune: int = 1000,
        draws: int = 1000,
    ) -> Tuple[Any, Any, Any]:
        """Attach observed counts, then draw prior, posterior and PPC samples.

        Args:
            res_full: Counts matching ("question", "option_full"). Row sums
                give the multinomial ``n`` for each question.
            res_rems: ``res_rems[i]`` holds the counts for the condition that
                uses ``p_rem_{i}``. It may have at most ``max_options``
                entries.
            chain: Number of MCMC chains, forwarded as ``chains``.
            tune: Number of tuning steps per chain.
            draws: Number of posterior draws per chain.

        Returns:
            ``(prior_pred, trace, post_pred)``.
        """
        n_participants_full = res_full.sum(axis=1)
        n_participants_rem = [r.sum(axis=1) for r in res_rems]

        with self.model:

            self.likelihoods = {
                "likelihood_full": pm.Multinomial(
                    "counts_full",
                    n=n_participants_full,
                    p=self.ps["p_full"],
                    observed=res_full,
                    dims=("question", "option_full"),
                ),
                **{
                    f"counts_{i}": pm.Multinomial(
                        f"counts_rem_{i}",
                        n=n_rem,
                        p=self.ps[f"p_rem_{i}"],
                        observed=res_rem,
                        dims=("question", f"option_rem_{i}"),
                    )
                    for i, (res_rem, n_rem) in enumerate(
                        zip(res_rems, n_participants_rem)
                    )
                },
            }

            prior_pred = pm.sample_prior_predictive()
            trace = pm.sample(tune=tune, draws=draws, chains=chain)
            post_pred = pm.sample_posterior_predictive(trace=trace)

        return prior_pred, trace, post_pred


def squared_relative_error(
    n_participants: np.ndarray, observed: np.ndarray, ps: np.ndarray
) -> np.ndarray:
    """Return the Pearson chi-square discrepancy, one value per row.

    Expected counts are ``ps * n_participants[:, None]``. Each row's value is
    ``sum_k (O_k - E_k)**2 / E_k``.
    """
    totals = n_participants.reshape(-1, 1)
    expected = ps * totals
    residual = observed - expected
    return np.sum(residual**2 / expected, axis=1)


def nll_error(
    n_participants: np.ndarray, observed: np.ndarray, ps: np.ndarray
) -> np.ndarray:
    """Return the per-row multinomial log-likelihood-ratio discrepancy.

    Computes ``sum_k O_k * log(O_k / E_k)`` for each row, where
    ``E = ps * n_participants[:, None]``. When ``n_participants`` equals
    ``observed.sum(axis=1)``, this is the multinomial negative
    log-likelihood of ``observed`` under ``ps`` minus that of the saturated
    model, i.e. half the G statistic.

    Cells with ``O_k == 0`` contribute 0, the limit of ``x * log(x)`` as
    ``x -> 0``.

    Args:
        n_participants: Per-row totals, reshaped to a column.
        observed: Count matrix, one row per question.
        ps: Probability matrix broadcastable against ``observed``.

    Returns:
        One value per row.
    """
    observed = np.asarray(observed)
    expected = ps * n_participants.reshape(-1, 1)
    # 0 * log(0 / E) is mathematically 0, but numpy evaluates it as
    # 0 * -inf = nan, which poisons the summed statistic. Replicated
    # multinomial draws can contain zero counts. For empty cells, keep a
    # ratio of 1 (log 1 = 0) instead of dividing.
    ratio = np.divide(
        observed,
        expected,
        out=np.ones(np.broadcast(observed, expected).shape),
        where=observed > 0,
    )
    nll = observed * np.log(ratio)
    return nll.sum(axis=1)


def _posterior_predictive_check_sample(
    n_questions: int,
    data: Dict[str, np.ndarray],
    posterior_sample: Dataset,
    stat_func: Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],
    rng: np.random.Generator,
) -> pd.DataFrame:
    """Evaluate a discrepancy statistic on observed and replicated data.

    For each condition ``conf`` in ``data``, the row totals of the observed
    counts are used to draw replicated counts from
    ``posterior_sample[f"p_{conf}"]``. ``stat_func(n_participants, counts,
    ps)`` is applied to both the observed and the replicated counts, and the
    results are summed across conditions.

    Args:
        n_questions: Number of rows (questions) in every ``data`` array.
        data: Observed count matrices keyed by condition name. Each key must
            match a ``p_<key>`` variable in ``posterior_sample``.
        posterior_sample: A single posterior draw, as selected by
            ``posterior.sel(chain=..., draw=...)`` in the caller.
        stat_func: Discrepancy function called as ``(row_totals, counts,
            ps)``. It returns one value per question.
        rng: Generator used for the replicated multinomial draws.

    Returns:
        DataFrame with columns ``question``, ``stat`` and ``stat_repl``. All
        columns are float64 because they are stacked into one array.
    """
    stat = np.zeros(n_questions)
    stat_repl = np.zeros(n_questions)

    for conf, counts in data.items():
        n_participants = counts.sum(axis=1)

        ps = posterior_sample[f"p_{conf}"].values
        # Replicate with the same per-question totals as the observed data.
        data_repl = rng.multinomial(n_participants, pvals=ps)

        stat += stat_func(n_participants, counts, ps)
        stat_repl += stat_func(n_participants, data_repl, ps)

    return pd.DataFrame(
        np.stack([np.arange(n_questions), stat, stat_repl], axis=-1),
        columns=["question", "stat", "stat_repl"],
    )


def posterior_predictive_check(
    n_questions: int,
    data: Dict[str, np.ndarray],
    posterior: Dataset,
    stat_func: Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],
    aggregate_pvals: bool = False,
    agg: str = "sum",
    rng: Optional[np.random.Generator] = None,
) -> float | pd.Series:
    """Compute posterior predictive p-values for a discrepancy statistic.

    For every (chain, draw) in ``posterior``, replicated data are drawn and
    ``stat_func`` is evaluated on both the observed and the replicated
    counts (see ``_posterior_predictive_check_sample``).

    Args:
        n_questions: Number of questions (rows) in each ``data`` array.
        data: Observed count matrices keyed by condition name, matching
            ``p_<name>`` variables in ``posterior``.
        posterior: Posterior with ``chain`` and ``draw`` coordinates.
        stat_func: Discrepancy called as ``(row_totals, counts, ps)``. It
            returns one value per question.
        aggregate_pvals: If True, aggregate the statistics over questions
            within each draw before comparing, giving a single p-value.
        agg: Pandas aggregation used when ``aggregate_pvals`` is True.
        rng: Generator for the replicated draws. Defaults to a fresh,
            unseeded ``np.random.default_rng()``. Pass a seeded generator for
            reproducible results.

    Returns:
        If ``aggregate_pvals``, the fraction of posterior draws whose
        aggregated replicated statistic strictly exceeds the aggregated
        observed one. Otherwise, a Series indexed by question with the
        per-question fraction.
    """
    if rng is None:
        rng = np.random.default_rng()

    dfs = []
    for chain in posterior.chain.values:
        for draw in posterior.draw.values:

            posterior_sample = posterior.sel(chain=chain, draw=draw)
            df = _posterior_predictive_check_sample(
                n_questions, data, posterior_sample, stat_func, rng
            )
            df["chain"] = chain
            df["draw"] = draw
            dfs.append(df)

    df_concat = pd.concat(dfs)

    if aggregate_pvals:
        df_agg_stat = df_concat.groupby(["chain", "draw"])[
            ["stat", "stat_repl"]
        ].agg(agg)
        return (df_agg_stat["stat_repl"] > df_agg_stat["stat"]).mean()

    df_concat["check"] = df_concat["stat_repl"] > df_concat["stat"]
    return df_concat.groupby("question")["check"].mean()


def user_ic(post_sample, user_answers):
    """Return the total surprisal of one user's observed answers.

    For each answered question, the model probability of the chosen option
    is ``sum_k rep_k * p_k``. The result is ``sum(-log(that probability))``.

    Args:
        post_sample: Single posterior draw exposing ``p`` with a
            ``question`` coordinate.
        user_answers: Rows for one user. Column ``"question_order"`` holds
            the question id. Integer-labelled columns ``0 .. n_options-1``
            are presumably one-hot indicators of the chosen option (verify
            against the data source).

    Returns:
        Scalar sum of surprisals over the user's answers.
    """
    ps = post_sample.p.sel(
        question=user_answers["question_order"].values
    ).values
    # Derive the option columns from the model's option count instead of
    # hard-coding four options, so any ``max_options`` works.
    option_cols = list(range(ps.shape[1]))
    rep = user_answers[option_cols].values
    ics = -np.log((rep * ps).sum(axis=1))
    return ics.sum()


def ph_stat(post_sample, answers_df):
    """Return the spread (max - min) of per-user surprisal on observed answers.

    Each distinct ``user_id`` gets its rows scored with ``user_ic`` under
    ``post_sample``.
    """
    user_column = answers_df["user_id"]
    per_user_ic = [
        user_ic(post_sample, answers_df[user_column == uid])
        for uid in user_column.unique()
    ]
    return max(per_user_ic) - min(per_user_ic)


def user_ic_sim(post_sample, user_answers, rng):
    """Return the total surprisal of simulated answers for one user.

    One option is drawn per answered question from the model probabilities,
    via a single-trial multinomial draw. The surprisal
    ``-log p(chosen option)`` is summed over the user's questions.
    """
    question_ids = user_answers["question_order"].values
    probs = post_sample.p.sel(question=question_ids).values
    n_rows = user_answers.shape[0]
    # One trial per row, so each simulated row is a one-hot choice vector.
    one_hot = rng.multinomial(np.ones(n_rows, dtype=int), probs)
    chosen_prob = np.sum(one_hot * probs, axis=1)
    return np.sum(-np.log(chosen_prob))


def ph_stat_sim(post_sample, answers_df):
    """Return the spread (max - min) of per-user surprisal on simulated answers.

    A fresh, unseeded generator is created per call and shared across all
    users, whose rows are scored with ``user_ic_sim``.
    """
    generator = np.random.default_rng()
    user_column = answers_df["user_id"]

    scores = []
    for uid in user_column.unique():
        user_rows = answers_df[user_column == uid]
        scores.append(user_ic_sim(post_sample, user_rows, generator))

    return max(scores) - min(scores)


def ph_posterior_predictive_check(posterior, df_ans):
    """Tabulate observed vs simulated per-user spread for every posterior draw.

    Iterates draws in the outer position and chains in the inner position.
    Each (draw, chain) contributes one row with ``chain``, ``draw``,
    ``stat`` (from ``ph_stat``) and ``stat_rep`` (from ``ph_stat_sim``).
    """
    records = []
    for draw, chain in product(posterior.draw, posterior.chain):
        sample = posterior.sel(draw=draw, chain=chain)
        observed_stat = ph_stat(sample, df_ans)
        replicated_stat = ph_stat_sim(sample, df_ans)

        record = {
            "chain": int(chain.values),
            "draw": int(draw.values),
            "stat": observed_stat,
            "stat_rep": replicated_stat,
        }
        records.append(record)

    return pd.DataFrame(records)
