try:
    import matplotlib.pyplot as plt
    plt.figure(figsize=(1, 1)); plt.plot(); plt.title("make sure plot shows"); plt.show()
except ImportError:
    pass

import torch, dataclasses, warnings, json
import pandas as pd
from typing import Dict, List
from rime.models import (Rand, Pop, EMA, RNN, Transformer, Hawkes, HawkesPoisson,
                         LightFM_BPR, ALS, LogisticMF, BPR, GraphConv, LDA)
from rime.models.zero_shot import BayesLM, ItemKNN
from rime.metrics import (evaluate_item_rec, evaluate_user_rec, evaluate_mtch)
from rime import dataset
from rime.dataset import Dataset
from rime.util import _argsort, cached_property, RandScore, plot_rec_results, plot_mtch_results

from pkg_resources import get_distribution, DistributionNotFound
try:
    __version__ = get_distribution("recurrent-intensity-model-experiments").version
    print("recurrent-intensity-model-experiments (rime)", __version__)
except DistributionNotFound:
    warnings.warn("rime version configuration issues in setuptools_scm")


@dataclasses.dataclass
class ExperimentResult:
    cvx: bool
    online: bool
    _k1: int
    _c1: int
    _kmax: int
    _cmax: int

    item_ppl_baseline: float = None
    user_ppl_baseline: float = None

    item_ppl: float = None  # Deprecated; will be removed
    user_ppl: float = None  # Deprecated; will be removed

    item_rec: Dict[str, Dict[str, float]] = dataclasses.field(default_factory=dict)
    user_rec: Dict[str, Dict[str, float]] = dataclasses.field(default_factory=dict)
    mtch_: Dict[str, List[Dict[str, float]]] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        if self.item_ppl_baseline is None:
            warnings.warn("item_ppl -> item_ppl_baseline", DeprecationWarning)
            self.item_ppl_baseline = self.item_ppl
        if self.user_ppl_baseline is None:
            warnings.warn("user_ppl -> user_ppl_baseline", DeprecationWarning)
            self.user_ppl_baseline = self.user_ppl

    def print_results(self):
        print('\nitem_rec')
        print(pd.DataFrame(self.item_rec).T)
        print('\nuser_rec')
        print(pd.DataFrame(self.user_rec).T)

        mtch1 = self.get_mtch_(self._k1, self._c1)
        if mtch1 is not None:
            print('\nmtch_')
            print(mtch1.T)

    def save_results(self, fn):
        with open(fn, 'w') as fp:
            json.dump(dataclasses.asdict(self), fp)

    def get_mtch_(self, k=None, c=None, name="mtch_"):
        y = {}
        for method, x in getattr(self, name).items():
            x = pd.DataFrame(x)
            if k is not None and c is not None:
                y[method] = x.set_index(['k', 'c']).loc[(k, c)]
            elif k is not None:
                y[method] = x.set_index(['k', 'c']).loc[k].sort_index().T
            elif c is not None:
                y[method] = x.set_index(['c', 'k']).loc[c].sort_index().T
            else:
                raise ValueError("either k or c must be provided")
        return pd.concat(y, axis=1) if len(y) else None


class Experiment:
    """ Produce item_rec / user_rec metrics;
    then sweeps through multipliers for relevance-diversity curve,
    interpreting mult<1 as item min-exposure and mult>=1 as user max-limit
    """
    def __init__(
        self, D, V=None, *V_extra,
        mult=[],  # [0, 0.1, 0.2, 0.5, 1, 3, 10, 30, 100],
        models_to_run=None,
        model_hyps={},
        device="cuda" if torch.cuda.is_available() else "cpu",
        cvx=False,
        online=False,
        tie_break=0,
        cache=None,
        results=None,
        **mtch_kw
    ):

        self.D = D
        self.V = V
        self.V_extra = V_extra

        self.mult = mult

        if models_to_run is None:
            models_to_run = self.registered.keys()
        self.models_to_run = models_to_run

        self.model_hyps = model_hyps
        self.device = device

        if online:
            if not cvx:
                warnings.warn("online requires cvx, resetting cvx to True")
                cvx = True
            assert V is not None, "online cvx is trained with explicit valid_mat"

        self.tie_break = tie_break
        if cache is not None:
            self.update_cache(cache)
        self.mtch_kw = mtch_kw

        if results is None:
            results = ExperimentResult(
                cvx, online,
                _k1=self.D.default_item_rec_top_k,
                _c1=self.D.default_user_rec_top_c,
                _kmax=len(self.D.item_in_test),
                _cmax=len(self.D.user_in_test),
                item_ppl_baseline=self.D.item_ppl_baseline,
                user_ppl_baseline=self.D.user_ppl_baseline,
            )
        self.results = results

        # pass-through references
        self.__dict__.update(self.results.__dict__)
        self.print_results = self.results.print_results
        self.get_mtch_ = self.results.get_mtch_

    def metrics_update(self, name, S, T=None):
        target_csr = self.D.target_csr
        score_mat = S

        if self.online:
            valid_mat = T
        elif self.cvx:
            valid_mat = score_mat
        else:
            valid_mat = None

        self.item_rec[name] = evaluate_item_rec(
            target_csr, score_mat, self._k1, device=self.device) \
            if getattr(self, "_run_item_rec", True) else None
        self.user_rec[name] = evaluate_user_rec(
            target_csr, score_mat, self._c1, device=self.device) \
            if getattr(self, "_run_user_rec", True) else None

        print(pd.DataFrame({
            'item_rec': self.item_rec[name],
            'user_rec': self.user_rec[name],
        }).T)

        if len(self.mult):
            self.mtch_[name] = self._mtch_update(target_csr, score_mat, valid_mat, name)

    def _mtch_update(self, target_csr, score_mat, valid_mat, name):
        """ assign user/item matches and return evaluation results.
        """
        confs = []
        for m in self.mult:
            if m < 1:
                # lower-bound is interpreted as item min-exposure
                confs.append((self._k1, self._c1 * m, 'lb'))
            else:
                # upper-bound is interpreted as user max-limit
                confs.append((self._k1 * m, self._c1, 'ub'))

        mtch_kw = self.mtch_kw.copy()
        if self.cvx:
            mtch_kw['valid_mat'] = valid_mat
            mtch_kw['prefix'] = f"{name}-{self.online}"
        else:
            mtch_kw['argsort_ij'] = _argsort(score_mat, device=self.device)

        out = []
        for k, c, constraint_type in confs:
            res = evaluate_mtch(
                target_csr, score_mat, k, c, constraint_type=constraint_type,
                cvx=self.cvx, device=self.device,
                item_prior=1 + self.D.item_in_test['_hist_len'].values,
                **mtch_kw
            )
            res.update({'k': k, 'c': c})
            out.append(res)

        return out

    @cached_property
    def registered(self):
        registered = {
            "Rand": lambda D: Rand().transform(D),
            "Pop": lambda D: self._pop.transform(D),
            "EMA": lambda D: EMA(D.horizon).transform(D) * self._pop_item.transform(D),
            "Hawkes": lambda D: self._hawkes.transform(D) * self._pop_item.transform(D),
            "HP": lambda D: self._hawkes_poisson.transform(D) * self._pop_item.transform(D),

            "RNN": lambda D: self._rnn.transform(D),
            "RNN-Pop": lambda D: self._rnn.transform(D) * Pop(1, 0).transform(D),
            "RNN-EMA": lambda D: self._rnn.transform(D) * EMA(D.horizon).transform(D),
            "RNN-Hawkes": lambda D: self._rnn.transform(D) * self._hawkes.transform(D),
            "RNN-HP": lambda D: self._rnn.transform(D) * self._hawkes_poisson.transform(D),

            "Transformer": lambda D: self._transformer.transform(D),
            "Transformer-Pop": lambda D: self._transformer.transform(D) * Pop(1, 0).transform(D),
            "Transformer-EMA": lambda D: self._transformer.transform(D) * EMA(D.horizon).transform(D),
            "Transformer-Hawkes": lambda D: self._transformer.transform(D) * self._hawkes.transform(D),
            "Transformer-HP": lambda D: self._transformer.transform(D) * self._hawkes_poisson.transform(D),

            "BPR-Item": lambda D: self._bpr_item.transform(D),
            "BPR-User": lambda D: self._bpr_user.transform(D),
            "BPR": lambda D: self._bpr.transform(D),

            "GraphConv-Base": lambda D: self._graph_conv_base.transform(D),
            "GraphConv-Extra": lambda D: self._graph_conv_extra.transform(D),

            "LDA": lambda D: self._lda.transform(D),

            "ALS": lambda D: self._als.transform(D),
            "LogisticMF": lambda D: self._logistic_mf.transform(D),

            "BayesLM-0": lambda D: self._bayes_lm_0.transform(D),
            "BayesLM-1": lambda D: self._bayes_lm_1.transform(D),

            "ItemKNN-0": lambda D: self._item_knn_0.transform(D),
            "ItemKNN-1": lambda D: self._item_knn_1.transform(D),
        }

        # disable models due to missing inputs

        if not ('TEST_START_TIME' in self.D.user_in_test and '_hist_ts' in self.D.user_in_test
                and self.D.horizon < float("inf")):
            warnings.warn("disabling temporal models due to missing TEST_START_TIME, _hist_ts or horizon")
            for model in ['EMA', 'Hawkes', 'HP', 'RNN-EMA', 'RNN-Hawkes', 'RNN-HP',
                           'Transformer-EMA', 'Transformer-Hawkes', 'Transformer-HP']:
                registered.pop(model, None)

        if self.V is None:
            warnings.warn("disabling HP and GraphConv due to missing validation set")
            for model in ['HP', 'RNN-HP', 'Transformer-HP',
                           'GraphConv-Base', 'GraphConv-Extra']:
                registered.pop(model, None)

        if 'TITLE' not in self.D.training_data.item_df:
            warnings.warn("disabling zero-shot models due to missing item TITLE")
            for model in ['BayesLM-0', 'BayesLM-1', 'ItemKNN-0', 'ItemKNN-1']:
                registered.pop(model, None)

        return registered

    def _validate_run_input(self, models_to_run):
        """ return a dictionary of {model_name: model_str or model_obj} """
        if models_to_run is None:
            models_to_run = [m for m in self.models_to_run if m not in
                             ['BayesLM-0', 'BayesLM-1', 'ItemKNN-0', 'ItemKNN-1']]
        if isinstance(models_to_run, str):
            models_to_run = [models_to_run]
        if isinstance(models_to_run, list):
            for model in models_to_run:
                assert model in self.registered, f"{model} disabled or unregistered"
                print("models to run", models_to_run)
            models_to_run = {k: k for k in models_to_run}
        return models_to_run

    def run(self, models_to_run=None):
        """ models_to_exclude is ignored if models_to_run is explicitly provided """

        models_to_run = self._validate_run_input(models_to_run)

        for model_name, model_obj in models_to_run.items():
            print("running", model_name)
            if isinstance(model_obj, str):
                transform_fn = self.registered[model_obj]
            else:
                transform_fn = model_obj.transform

            S = transform_fn(self.D)
            if self.D.prior_score is not None:
                S = S + self.D.prior_score
            if self.tie_break:
                warnings.warn("Using experimental RandScore class")
                S = S + RandScore.like(S) * self.tie_break

            if self.online:
                V = self.V.reindex(self.D.item_in_test.index, axis=1)
                T = transform_fn(V)
                if V.prior_score is not None:
                    T = T + V.prior_score
                if self.tie_break:
                    warnings.warn("Using experimental RandScore class")
                    T = T + RandScore.like(T) * self.tie_break
            else:
                T = None
            self.metrics_update(model_name, S, T)

    @cached_property
    def _pop(self):
        return Pop().fit(self.D.training_data)

    @cached_property
    def _pop_item(self):
        return Pop(user_rec=False, item_rec=True).fit(self.D.training_data)

    @cached_property
    def _rnn(self):
        return RNN(
            self.D.training_data.item_df, **self.model_hyps.get("RNN", {})
        ).fit(self.D.training_data)

    @cached_property
    def _transformer(self):
        return Transformer(
            self.D.training_data.item_df, **self.model_hyps.get("Transformer", {})
        ).fit(self.D.training_data)

    @cached_property
    def _hawkes(self):
        return Hawkes(self.D.horizon).fit(self.D.training_data)

    @cached_property
    def _hawkes_poisson(self):
        assert self.V is not None, "_hawkes_poisson requires self.V"
        return HawkesPoisson(self._hawkes).fit(self.V)

    @cached_property
    def _bpr_item(self):
        return LightFM_BPR(item_rec=True).fit(self.D.training_data)

    @cached_property
    def _bpr_user(self):
        return LightFM_BPR(user_rec=True).fit(self.D.training_data)

    @cached_property
    def _bpr(self):
        return BPR(**self.model_hyps.get("BPR", {})).fit(self.D.training_data)

    @cached_property
    def _graph_conv_base(self):
        assert self.V is not None, "_graph_conv_base requires self.V"
        return GraphConv(
            self.D, **self.model_hyps.get("GraphConv-Base", {})
        ).fit(self.V)

    @cached_property
    def _graph_conv_extra(self):
        if len(self.V_extra) == 0:
            warnings.warn("without V_extra, we are defaulting to _graph_conv_base")
            return self._graph_conv_base

        return GraphConv(
            self.D, **self.model_hyps.get("GraphConv-Extra", {})
        ).fit(self.V, *self.V_extra)

    @cached_property
    def _lda(self):
        return LDA(
            self.D.training_data, **self.model_hyps.get("LDA", {})
        ).fit(self.D.training_data)

    @cached_property
    def _als(self):
        return ALS().fit(self.D.training_data)

    @cached_property
    def _logistic_mf(self):
        return LogisticMF().fit(self.D.training_data)

    @cached_property
    def _bayes_lm_0(self):
        return BayesLM(self.D.training_data.item_df, item_pop_power=0,
                        **self.model_hyps.get("BayesLM-0", {}))

    @cached_property
    def _bayes_lm_1(self):
        return BayesLM(self.D.training_data.item_df, item_pop_power=1,
                        **self.model_hyps.get("BayesLM-1", {}))

    @cached_property
    def _item_knn_0(self):
        return ItemKNN(self.D.training_data.item_df, item_pop_power=0,
                        **self.model_hyps.get("ItemKNN-0", {}))

    @cached_property
    def _item_knn_1(self):
        return ItemKNN(self.D.training_data.item_df, item_pop_power=1,
                        **self.model_hyps.get("ItemKNN-1", {}))

    def update_cache(self, other):
        for attr in ['registered', '_transformer', '_rnn', '_hawkes', '_hawkes_poisson',
                     '_bpr_item', '_bpr_user', '_als', '_logistic_mf',
                     '_bpr', '_graph_conv_base', '_graph_conv_extra', '_lda']:
            if attr in other.__dict__:
                setattr(self, attr, getattr(other, attr))


def main(name, *args, **kw):
    prepare_fn = getattr(dataset, name)
    D, *V = prepare_fn(*args)
    self = Experiment(D, *V, **kw)
    self.run()
    self.results.print_results()
    return self
