from __future__ import division
from __future__ import print_function

import time

import dgl
import torch
import torch.nn as nn
import numpy as np
import dgl.data

import collections

from data_splitter import split_data

import json


# Seed DGL, PyTorch and NumPy with the same fixed value.
seed = 0
dgl.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

# Dataset name. In the code visible in this file it is only mentioned by a
# commented-out loader call.
dataset = "wyze"

# Device used for the model and the index tensors. Pinned to CPU; the
# commented-out expression would pick CUDA when it is available.
sys_device = 'cpu'#torch.device('cuda' if torch.cuda.is_available() else 'cpu')#

#user_graphs, all_trigger_actions, all_devices, user_device_id_to_node_id = load_data_multi_devices(dataset,sys_device)




def _load_json(path):
    """Parse the JSON file at ``path`` and close the handle as soon as it is read."""
    with open(path, 'r') as handle:
        return json.load(handle)


# The serialized graphs carry no user ids of their own: the i-th graph belongs
# to the i-th key of user_id_dict.json, so the two are paired by position.
user_graphs_list, _ = dgl.load_graphs("wyze_rule/usergraphs.bin")

user_id_dict = _load_json('wyze_rule/user_id_dict.json')
user_ids = list(user_id_dict.keys())

# Indexing (rather than zip) keeps the IndexError when fewer graphs than user
# ids were stored, instead of silently dropping users.
user_graphs = {
    user_id: user_graphs_list[position]
    for position, user_id in enumerate(user_ids)
}

all_trigger_actions = _load_json('wyze_rule/all_trigger_actions.json')
all_devices = _load_json('wyze_rule/all_devices.json')
user_device_id_to_node_id = _load_json('wyze_rule/user_device_id_to_node_id.json')


# split_data comes from the project-local data_splitter module; its five results
# are unpacked into the training graphs and the positive/negative train/test
# graphs (names as given here -- semantics defined by split_data).
train_gs, train_pos_gs, train_neg_gs, test_pos_gs, test_neg_gs = split_data(user_graphs, all_trigger_actions, False)


class MF(nn.Module):
    """Plain matrix factorization: one embedding table per side, dot-product score."""

    def __init__(self, num_users, num_items, emb_dim):
        """Create ``num_users`` x ``emb_dim`` and ``num_items`` x ``emb_dim`` embedding tables."""
        super().__init__()
        # User table first, then item table, so initialisation draws from the
        # RNG in a fixed order.
        self.user_emb = nn.Embedding(num_users, emb_dim)
        self.item_emb = nn.Embedding(num_items, emb_dim)

    def forward(self, user, item):
        """Score aligned index tensors ``user`` and ``item``.

        Looks up both embeddings, multiplies them element-wise and sums the
        product over dimension 1.
        """
        user_vectors = self.user_emb(user)
        item_vectors = self.item_emb(item)
        return torch.sum(user_vectors * item_vectors, dim=1)

def sigmoid_range(x, low, high):
    """Squash ``x`` with a sigmoid and rescale the result from (0, 1) to (low, high)."""
    span = high - low
    squashed = torch.sigmoid(x)
    return squashed * span + low

class MFAdvanced(nn.Module):
    """Matrix factorization with optional bias terms, small uniform init and sigmoid output."""

    def __init__(self, num_users, num_items, emb_dim, init, bias, sigmoid):
        """Build the embedding tables and, depending on the flags, the extras.

        ``bias`` adds one learnable scalar per user, one per item and a global
        offset (all starting at zero). ``init`` re-draws both embedding tables
        uniformly from [0, 0.05). ``sigmoid`` makes ``forward`` pass the score
        through ``sigmoid_range(., 0, 1)``.
        """
        super().__init__()
        self.bias = bias
        self.sigmoid = sigmoid

        self.user_emb = nn.Embedding(num_users, emb_dim)
        self.item_emb = nn.Embedding(num_items, emb_dim)

        if bias:
            self.user_bias = nn.Parameter(torch.zeros(num_users))
            self.item_bias = nn.Parameter(torch.zeros(num_items))
            self.offset = nn.Parameter(torch.zeros(1))

        if init:
            # User table before item table, matching the order the RNG is consumed in.
            for table in (self.user_emb, self.item_emb):
                table.weight.data.uniform_(0., 0.05)

    def forward(self, user, item):
        """Score aligned index tensors ``user`` and ``item``.

        The base score is the element-wise embedding product summed over
        dimension 1; bias terms and the sigmoid are applied when enabled.
        """
        score = torch.sum(self.user_emb(user) * self.item_emb(item), dim=1)

        if self.bias:
            bias_term = self.user_bias[user] + self.item_bias[item] + self.offset
            score = score + bias_term

        if not self.sigmoid:
            return score
        return sigmoid_range(score, 0, 1)
        
def _node_types(graph):
    """Return, per node of ``graph``, the argmax over dimension 1 of its 'feat' data."""
    return torch.argmax(graph.ndata['feat'], axis = 1)


def _edge_rule_keys(graph, node_types):
    """Return one rule key per edge of ``graph``, in edge order.

    A key is ``"<source node type>_<edge type>_<target node type>"``, where the
    node types are read from ``node_types`` at the edge's endpoint ids and the
    edge type from the graph's 'etype' edge data.
    """
    edges = graph.edges()
    sources, targets = edges[0], edges[1]
    edge_types = graph.edata['etype']
    return [
        f'{node_types[sources[i]]}_{edge_types[i]}_{node_types[targets[i]]}'
        for i in range(graph.number_of_edges())
    ]


def _user_rule_pairs(pos_graphs, user_index, rule_index):
    """Return ``[user index, rule index]`` for every edge of every graph in ``pos_graphs``.

    The positive graphs carry no node features, so node types are taken from
    the matching graph in the module-level ``train_gs``.
    """
    pairs = []
    for key in pos_graphs:
        node_types = _node_types(train_gs[key])
        for rule in _edge_rule_keys(pos_graphs[key], node_types):
            pairs.append([user_index[key], rule_index[rule]])
    return pairs


def centralized_MF(emb_dim=32, lr=0.1, num_epochs=200, max_recs=50):
    """Train a matrix-factorization recommender on all users at once and print hit rates.

    Users are the keys of the module-level ``user_graphs``; items are the
    distinct "<source type>_<edge type>_<target type>" rules found on the edges
    of those graphs. The model is fit on the edges of ``train_pos_gs`` with a
    target of 1 for every pair, its loss on the edges of ``test_pos_gs`` is
    printed after every epoch, and finally the hit rate of the test pairs in
    each user's top-k recommendations is printed for k = 5, 10, ..., ``max_recs``.

    Args:
        emb_dim: Embedding size of the user and rule tables.
        lr: Adam learning rate.
        num_epochs: Number of full-batch training steps.
        max_recs: Number of recommendations kept per user (largest k evaluated).

    Raises:
        ValueError: If there are no training pairs or no test pairs.

    Notes:
        * The defaults reproduce the values that were actually in effect before
          (lr 0.1, 200 epochs). The former ``CFG['lr']`` / ``CFG['num_epochs']``
          entries were never read and have been dropped.
        * Rules a user already has in the training set are now really removed
          from that user's recommendations. Previously the filter compared
          tensor objects through a ``set``, which hashes tensors by identity,
          so nothing was ever removed. Hit rates therefore differ from runs of
          the earlier code.
        * NOTE(review): only positive pairs with target 1 are used here;
          ``train_neg_gs`` / ``test_neg_gs`` are not read by this function.
    """
    user_index = {key: idx for idx, key in enumerate(user_graphs)}

    # Number the rules in order of first appearance across all user graphs.
    rule_index = dict()
    for key in user_graphs:
        graph = user_graphs[key]
        for rule in _edge_rule_keys(graph, _node_types(graph)):
            rule_index.setdefault(rule, len(rule_index))

    train_pairs = _user_rule_pairs(train_pos_gs, user_index, rule_index)
    test_pairs = _user_rule_pairs(test_pos_gs, user_index, rule_index)
    if not train_pairs or not test_pairs:
        raise ValueError("centralized_MF needs at least one training pair and one test pair")

    train_user_rule_MF = torch.tensor(train_pairs).to(sys_device)
    test_user_rule_MF = torch.tensor(test_pairs).to(sys_device)

    CFG = {
        'sigmoid': True,
        'bias': True,
        'init': True,
    }

    n_users = len(user_index)
    n_items = len(rule_index)

    model = MFAdvanced(n_users, n_items, emb_dim=emb_dim,
                       init=CFG['init'],
                       bias=CFG['bias'],
                       sigmoid=CFG['sigmoid'],
    )
    model.to(sys_device)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # The index vectors and the all-ones targets do not change between epochs.
    train_user_vector = train_user_rule_MF[:, 0]
    train_item_vector = train_user_rule_MF[:, 1]
    train_targets = torch.ones(train_user_vector.shape).to(sys_device)
    test_user_vector = test_user_rule_MF[:, 0]
    test_item_vector = test_user_rule_MF[:, 1]
    test_targets = torch.ones(test_user_vector.shape).to(sys_device)

    for _ in range(num_epochs):
        model.train()
        preds = model(train_user_vector, train_item_vector)
        loss = loss_fn(preds, train_targets)

        optimizer.zero_grad()
        loss.backward()
        print("train loss", loss)
        optimizer.step()

        model.eval()
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            preds = model(test_user_vector, test_item_vector)
            loss = loss_fn(preds, test_targets)
        print("test loss", loss)

    # Rules each user already has in the training set. Plain ints are stored:
    # tensors hash by object identity and would never match on lookup.
    train_user_rule_dict = collections.defaultdict(set)
    for user, rule in train_pairs:
        train_user_rule_dict[user].add(rule)

    # Rank every rule for every user, drop the user's training rules and keep
    # the best max_recs of what is left.
    recommendation_dict = dict()
    all_items = torch.arange(n_items).to(sys_device)
    model.eval()
    with torch.no_grad():
        for user in range(n_users):
            user_vector = torch.full((n_items,), user, dtype=torch.long).to(sys_device)
            scores = model(user_vector, all_items)
            _, order = torch.sort(scores, descending = True)
            seen = train_user_rule_dict[user]
            unseen = [rule for rule in order.tolist() if rule not in seen]
            recommendation_dict[user] = unseen[:max_recs]

    # Hit rate: share of test pairs whose rule is among the user's top rec_num.
    for rec_num in range(5, max_recs + 1, 5):
        hits = sum(
            1 for user, rule in test_pairs
            if rule in recommendation_dict[user][:rec_num]
        )
        print(hits / len(test_pairs))
        
        

# Module-level call: the experiment runs whenever this file is executed or imported.
centralized_MF()
