import numpy as np
import io
import re
import string
import tensorflow as tf
import tqdm
import sys
import pickle
from datetime import datetime
from tensorflow.keras import Model
from tensorflow.keras.layers import Dot, Embedding, Flatten
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

from collections import Counter


def generate_cl_trn_dat(seqs, ws, num_ns, vocab_size, save_last=True, SEED=23333):
    """
    build (target, context, label) triplets for contrastive-learning of item emb,
    in a word2vec (skip-gram with negative sampling) fashion
    input: seqs: [[item1,...,itemk]] # list of item sequences
           ws: window size handed to the skip-gram generator
           num_ns: number of negative samples drawn per positive pair
           vocab_size: vocabulary size for the sampling table / skip-grams and
                       upper bound (range_max) of the negative sampler
           save_last: if True the last item of every sequence is dropped before
                      pairs are generated (i.e. it is held out)
                      # set to false for tuning
           SEED: seed passed to the negative sampler
    output: targets: list of n target items
            contexts: list of n int64 tensors of shape (1+num_ns, 1);
                      row 0 is the positive context, the rest are negatives
            labels: list of n int64 tensors of shape (1+num_ns,), each [1, 0, ..., 0]
    """
    # NOTE(review): Keras documents index 0 as a non-word that skipgrams skips,
    # and the log-uniform sampler presumes indices are ordered by decreasing
    # frequency -- confirm the item indexing used by callers matches both.
    seq_tools = tf.keras.preprocessing.sequence
    keep_table = seq_tools.make_sampling_table(vocab_size)
    label_values = [1] + [0] * num_ns

    targets = []
    contexts = []
    labels = []

    for items in seqs:
        usable = items[:-1] if save_last else items

        # positive (target, context) pairs only; negatives are drawn below
        pairs, _ = seq_tools.skipgrams(
            usable,
            vocabulary_size=vocab_size,
            sampling_table=keep_table,
            window_size=ws,
            negative_samples=0,
        )

        for center_item, nearby_item in pairs:
            # (1, 1) tensor holding the true context class
            positive = tf.expand_dims(
                tf.constant([nearby_item], dtype="int64"), 1)

            negatives, _, _ = tf.random.log_uniform_candidate_sampler(
                true_classes=positive,
                num_true=1,
                num_sampled=num_ns,
                unique=True,
                range_max=vocab_size,
                seed=SEED,
                name="negative_sampling",
            )
            # (num_ns,) -> (num_ns, 1) so it stacks under the positive row
            negatives = tf.expand_dims(negatives, 1)

            targets.append(center_item)
            contexts.append(tf.concat([positive, negatives], 0))
            labels.append(tf.constant(label_values, dtype="int64"))

    return targets, contexts, labels





def ml_rating_preprocess(fname:str):
    """
    preprocess movielens-1m rating dat
    input: fname: ML-1M raw rating file loc; every non-blank line holds
                  '::'-separated integers with the user id first, the item id
                  second and the timestamp last
    output: user_hist: dict[user_id:[item_idx1,...,item_idxk]] # each user's items ordered by timestamp,
                                                               # for building seq reco model
            itemid2idx: dict[item_id:item_idx] # converting id to consecutive index that starts from 0,
                                               # most frequently rated item first
    raises: ValueError if a non-blank line contains a non-integer field
            IndexError if a non-blank line has fewer than two fields
    """
    item_counts = Counter()
    user_hist = {}
    with open(fname, 'r') as f:
        for raw in f:
            raw = raw.strip()
            if not raw:
                # tolerate blank / trailing empty lines instead of failing on int('')
                continue
            fields = [int(tok) for tok in raw.split('::')]
            user_id, item_id, timestamp = fields[0], fields[1], fields[-1]
            # frequency must be counted per ITEM (second field), not per user
            item_counts[item_id] += 1
            user_hist.setdefault(user_id, []).append((item_id, timestamp))

    # Sort ascending by count (stable), then reverse so that index 0 goes to the
    # most frequent item; among equal counts the item seen later gets the
    # smaller index.
    ascending = sorted(item_counts, key=item_counts.get)
    itemid2idx = {item_id: idx for idx, item_id in enumerate(reversed(ascending))}

    for user_id, events in user_hist.items():
        chronological = sorted(events, key=lambda event: event[1])
        user_hist[user_id] = [itemid2idx[item_id] for item_id, _ in chronological]

    return user_hist, itemid2idx




def ml_item_label(dat_loc, emb, itemid2idx):
    """
    generating movies' genre classification dataset
    input: dat_loc: movie metadata file loc (read as ISO-8859-1); every non-blank
                    line is '::'-separated with the item id first and the genre
                    field last, genres inside that field separated by '|'
           emb: item embeddings, indexable by item index (emb[i])
           itemid2idx: dict[item_id:item_idx]
    output: emb_clf: np.array stacking emb[i] of every labelled item, ascending item index
            label: list of int class ids aligned with emb_clf; the class of an item
                   is its LAST listed genre, class ids follow the sorted genre names
    Items listed in dat_loc but absent from itemid2idx are skipped, and so are
    indices of itemid2idx that have no line in dat_loc.
    """
    cat = {}
    with open(dat_loc, 'r', encoding = "ISO-8859-1") as f:
        for line in f:
            line = line.strip('\n')
            if not line:
                continue
            itemid = int(line.split("::", 1)[0])
            if itemid not in itemid2idx:
                # metadata may list items that never made it into the index map
                continue
            # take the genre field only, so a '|' elsewhere on the line cannot
            # leak into the label; then keep the last genre
            genre_field = line.rsplit("::", 1)[-1]
            cat[itemid2idx[itemid]] = genre_field.rsplit("|", 1)[-1]

    cat2id = {name: cid for cid, name in enumerate(sorted(set(cat.values())))}

    emb_clf = []
    label = []
    for idx in range(len(itemid2idx)):
        if idx in cat:
            label.append(cat2id[cat[idx]])
            emb_clf.append(emb[idx])

    emb_clf = np.array(emb_clf)

    return emb_clf, label

