import surprise as s
import numpy as np
import pandas as pd
from reclab import data_utils

# Base keyword arguments shared by every model configuration.
DEFAULT_PARAMS = {'n_epochs': 128}

# Supported settings, keyed as '<dataset>_<model>', mapped to the
# hyperparameters chosen for that dataset/model pair. DataGame.__init__
# applies these on top of DEFAULT_PARAMS. The model part selects NMF ('nmf')
# or SVD-style MF with or without bias terms ('mf-biased' / 'mf-unbiased').
SETTINGS_PARAMS = {
    'ml-100k_nmf': {'reg_pu': 0.08, 'reg_qi': 0.08, 'biased': False},
    'ml-100k_mf-biased': {'lr_all': 0.01, 'reg_all': 0.1, 'biased': True},
    'ml-100k_mf-unbiased': {'lr_all': 0.01, 'reg_all': 0.1, 'biased': False},
    'lastfm-360k_nmf': {'reg_pu': 0.05, 'reg_qi': 0.05, 'biased': False},
    # from https://arxiv.org/pdf/2107.00833.pdf
    'lastfm-360k_mf-biased': {
        'lr_all': 0.04777777777777778,
        'reg_all': 0.2277777777777778,
        'biased': True,
    },
    'lastfm-360k_mf-unbiased': {
        'lr_all': 0.04,
        'reg_all': 0.2,
        'biased': False,
    },
}


class DataGame:
    """Matrix-factorization model fitted to a rating dataset.

    Either construct with a ``setting`` key from ``SETTINGS_PARAMS`` (e.g.
    ``'ml-100k_mf-biased'``), or construct without one and call
    :meth:`load_dataset` and :meth:`set_model` yourself. Afterwards,
    :meth:`fit_vectors` returns the learned producer/consumer vectors.
    """

    def __init__(self, setting=None, d=10, random_state=None):
        """Optionally load a dataset and configure its model in one step.

        Args:
            setting: ``'<dataset>_<model>'`` key of ``SETTINGS_PARAMS``, or
                None to defer to :meth:`load_dataset` / :meth:`set_model`.
            d: latent dimension, passed to the model as ``n_factors``.
            random_state: passed through to the surprise algorithm.

        Raises:
            KeyError: if ``setting`` is not a key of ``SETTINGS_PARAMS``.
                This is checked before any data is loaded.
        """
        # Define the attributes up front so that evaluate_model / fit_vectors
        # print their guidance message instead of raising AttributeError
        # when setup has been deferred.
        self.data = None
        self.algo = None
        self.consumer_attributes = None
        self.producer_attributes = None
        if setting is None:
            print('Initialize data and model with load_dataset and set_model.')
            return

        dataset_name, model_name = setting.split('_')
        # Fail fast, before the (potentially slow) dataset load.
        if setting not in SETTINGS_PARAMS:
            raise KeyError(
                f'Unknown setting {setting!r}; expected one of '
                f'{sorted(SETTINGS_PARAMS)}')

        # Ratings are mean-centered only for the unbiased MF model
        # (presumably because it has no bias terms to absorb the mean).
        self.load_dataset(dataset_name, model_name == 'mf-unbiased')

        # Copy the defaults: updating DEFAULT_PARAMS in place would leak this
        # setting's parameters into every DataGame created afterwards.
        model_params = dict(DEFAULT_PARAMS)
        model_params.update(SETTINGS_PARAMS[setting])
        model_params.update(n_factors=d, random_state=random_state)
        self.set_model(model_name == 'nmf', **model_params)

    def load_dataset(self, dataset_name='ml-100k', center_ratings=False):
        """Load ratings and attributes and wrap the ratings as a surprise Dataset.

        For ``'lastfm-360k'`` the data is subsampled with a fixed seed (19).
        Zero ratings are dropped and 10% of users are kept. Then a number of
        items equal to 10% of the original item count is sampled from the
        items those users rated. Producer attributes are restricted to the
        sampled items and get gender columns via
        ``add_gender_attributes_lastfm``.

        Args:
            dataset_name: dataset name understood by ``data_utils.get_data``.
            center_ratings: if True, the mean rating is subtracted from every
                rating. The shift is stored in ``self.rating_shift`` and is
                0 otherwise.

        Side effects:
            Sets ``self.data``, ``self.consumer_attributes``,
            ``self.producer_attributes`` and ``self.rating_shift``, and resets
            ``self.algo`` to None.
        """
        rating_data, self.consumer_attributes, self.producer_attributes = (
            data_utils.get_data(dataset_name, load_attributes=True))
        if dataset_name == 'lastfm-360k':
            r = np.random.RandomState(seed=19)

            # Remove 0-valued ratings. A boolean mask, unlike drop(index),
            # cannot remove extra rows when index labels are duplicated.
            rating_data = rating_data[rating_data['rating'] != 0]
            user_ids = np.unique(rating_data['user_id'])
            original_num_items = len(np.unique(rating_data['item_id']))

            # select 10% of users
            density = 0.1
            user_id_subset = r.choice(
                user_ids, size=int(density * len(user_ids)), replace=False)
            rating_data = rating_data[
                rating_data['user_id'].isin(user_id_subset)]
            # No consumer attributes for this data

            # Select items: 10% of the ORIGINAL item count, drawn from the
            # items rated by the sampled users.
            item_ids = np.unique(rating_data['item_id'])
            item_id_subset = r.choice(
                item_ids, size=int(density * original_num_items),
                replace=False)
            rating_data = rating_data[
                rating_data['item_id'].isin(item_id_subset)]

            # Copy before add_gender_attributes_lastfm mutates it, so that the
            # columns are added to an independent frame, not a filtered view.
            self.producer_attributes = self.producer_attributes[
                self.producer_attributes['item_id'].isin(item_id_subset)
            ].copy()
            add_gender_attributes_lastfm(self.producer_attributes)

        self.rating_shift = np.mean(rating_data['rating']) * center_ratings
        # Build a new frame rather than subtracting in place. rating_data may
        # be a filtered slice, or the very object data_utils returned (which
        # may be shared with its caller).
        rating_data = rating_data.assign(
            rating=rating_data['rating'] - self.rating_shift)

        reader = s.Reader(rating_scale=(-np.inf, np.inf))
        self.data = s.Dataset.load_from_df(
            rating_data[['user_id', 'item_id', 'rating']], reader)
        self.algo = None

    def set_model(self, nonneg=False, **model_params):
        """
        Sets the matrix factorization model.

        If ``nonneg`` is True a ``surprise.NMF`` is used, otherwise a
        ``surprise.SVD``. Note that NMF does not accept ``lr_all`` /
        ``reg_all``; it takes ``reg_pu`` / ``reg_qi`` etc. instead.

        model_params should include:
            - n_factors: the latent dimension
            - biased: whether to include biases in the model
            - lr_all and reg_all
            - n_epochs
            - and more: https://surprise.readthedocs.io/en/stable/matrix_factorization.html#surprise.prediction_algorithms.matrix_factorization.SVD
        """
        if nonneg:
            self.algo = s.NMF(**model_params)
        else:
            self.algo = s.SVD(**model_params)

    def evaluate_model(self, verbose=True):
        """
        Returns dict of cross validation results (RMSE per split).

        Uses two random 90/10 train/test splits. Prints a message and returns
        None if no model has been set.
        """
        if self.algo is None:
            print("Must set_model before evaluating.")
            return None
        cv = s.model_selection.split.ShuffleSplit(n_splits=2, test_size=0.1)
        return s.model_selection.cross_validate(
            self.algo, self.data, measures=['RMSE'], cv=cv, verbose=verbose)

    def fit_vectors(self, normalize=False):
        """
        Fits a preference model (based on MF) to rating data.

        The model is trained on the full dataset. Afterwards the producer (and,
        if present, consumer) attribute frames are re-indexed by the model's
        inner ids, so that row ``i`` of a returned array matches
        ``get_info(..., i)``.

        Args:
            normalize: if True, rescale every vector to unit L2 norm
                (an all-zero vector becomes NaN).

        Returns:
            Tuple of the `n x d` producer and `m x d` consumer arrays, or
            ``(None, None)`` (after printing a message) if no model is set.
        """
        if self.algo is None:
            print("Must set_model before fitting.")
            return None, None

        # training model
        trainset = self.data.build_full_trainset()
        self.algo.fit(trainset)
        consumers, producers = self.algo.pu, self.algo.qi

        # normalizing: divide each row by its own L2 norm
        if normalize:
            consumers = consumers / np.linalg.norm(
                consumers, axis=1, keepdims=True)
            producers = producers / np.linalg.norm(
                producers, axis=1, keepdims=True)

        # updating attributes: index them by the model's inner ids
        self.producer_attributes['inner_id'] = [
            trainset.to_inner_iid(item_id)
            for item_id in self.producer_attributes['item_id']]
        self.producer_attributes = (
            self.producer_attributes.set_index('inner_id'))

        if self.consumer_attributes is not None:
            self.consumer_attributes['inner_id'] = [
                trainset.to_inner_uid(user_id)
                for user_id in self.consumer_attributes['user_id']]
            self.consumer_attributes = (
                self.consumer_attributes.set_index('inner_id'))

        return producers, consumers

    def get_info(self, which, inner_id):
        """
        Look up the attribute row for one producer or consumer.

        Args:
          which: 'producer' or 'consumer'
          inner_id:  integer inner id associated with index of
            the producer/consumer vector (see fit_vectors)

        Returns:
          The matching row (via ``.loc``) of the corresponding attribute frame.
        """
        assert which in ('consumer', 'producer')
        if which == 'consumer':
            return self.consumer_attributes.loc[inner_id]
        return self.producer_attributes.loc[inner_id]


def add_gender_attributes_lastfm(df):
    """Expand the lastfm ``gender`` field into count and label columns, in place.

    Each ``gender`` entry is a '/'-separated string of five counts, in the
    order unknown / male / female / other / na. The following columns are
    added to ``df``:

    * ``unknown``, ``male``, ``female``, ``other``, ``na``: numeric counts.
    * ``binary_gender``: 'Female' if female > 0, else 'Male' if male > 0,
      else NaN. A row with both counts positive is labelled 'Female'.
    * ``strict_gender``: 'Mixed' if male > 0 and female > 0; 'Male' if
      male > 0 and female == 0; 'Female' if female > 0 and male == 0;
      otherwise NaN.

    Args:
        df: DataFrame with a string ``gender`` column; mutated in place.
    """
    count_columns = ['unknown', 'male', 'female', 'other', 'na']
    df[count_columns] = df['gender'].str.split('/', expand=True)
    df[count_columns] = df[count_columns].apply(pd.to_numeric)

    male_pos = (df['male'] > 0).to_numpy()
    female_pos = (df['female'] > 0).to_numpy()
    male_zero = (df['male'] == 0).to_numpy()
    female_zero = (df['female'] == 0).to_numpy()

    binary_labels = np.full(len(df), np.nan, dtype=object)
    binary_labels[male_pos] = 'Male'
    # Applied second, so 'Female' wins when both counts are positive.
    binary_labels[female_pos] = 'Female'
    df['binary_gender'] = binary_labels

    strict_labels = np.full(len(df), np.nan, dtype=object)
    strict_labels[male_pos & female_pos] = 'Mixed'
    strict_labels[male_pos & female_zero] = 'Male'
    strict_labels[male_zero & female_pos] = 'Female'
    df['strict_gender'] = strict_labels
