import os
import numpy as np
from time import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from utility.parser import parse_args
from utility.norm import build_sim, build_knn_normalized_graph, get_dense_laplacian
# Parse command-line options once, at import time of this module.
# NOTE(review): `args` is not referenced by any code visible in this chunk;
# presumably it is read elsewhere in the file or by importers — confirm before
# removing. Parsing argv on import makes this module awkward to import from
# other entry points (e.g. tests or notebooks).
args = parse_args()

class VBPR(nn.Module):
    """VBPR-style model: latent ID embeddings concatenated with projected image features.

    Users have two tables: a latent ID vector and a visual-preference vector.
    Items have a latent ID vector plus a linear projection of precomputed image
    features. ``forward`` concatenates each pair along the last dimension, so
    both returned matrices have width ``2 * embedding_dim``.

    Args:
        n_users: Number of users (rows of both user tables).
        n_items: Number of items (rows of the item table).
        embedding_dim: Width of every embedding table and of the image projection.
        image_weight: Accepted for interface compatibility; unused here.
        text_weight: Accepted for interface compatibility; unused here.
        image_feats: 2-D array-like of image features, one row per item. Row ``i``
            is concatenated with item ``i``'s ID embedding, so it must have
            ``n_items`` rows.
        adj: Accepted for interface compatibility; unused here.
        edge_index: Accepted for interface compatibility; unused here.

    Raises:
        ValueError: If ``image_feats`` is missing or is not 2-D.
    """

    def __init__(self, n_users, n_items, embedding_dim, image_weight, text_weight, image_feats=None, adj=None, edge_index=None):
        super().__init__()
        self.n_users = n_users
        self.n_items = n_items
        self.embedding_dim = embedding_dim
        self.user_embedding = nn.Embedding(n_users, embedding_dim)
        self.item_embedding = nn.Embedding(n_items, embedding_dim)
        self.user_image_preference = nn.Embedding(n_users, embedding_dim)

        nn.init.xavier_uniform_(self.user_embedding.weight)
        nn.init.xavier_uniform_(self.item_embedding.weight)
        nn.init.xavier_uniform_(self.user_image_preference.weight)

        if image_feats is None:
            raise ValueError("VBPR requires image_feats (one feature row per item)")
        feats = torch.Tensor(image_feats)
        if feats.dim() != 2:
            raise ValueError(f"image_feats must be 2-D, got shape {tuple(feats.shape)}")
        # A buffer (instead of a hard-coded .cuda()) follows model.to()/.cuda()/.cpu()
        # together with the parameters. persistent=False keeps it out of state_dict,
        # so checkpoints saved by the previous version still load.
        self.register_buffer("image_feats", feats, persistent=False)

        # Created after the xavier inits, as before, so seeded runs consume the
        # RNG in the same order.
        self.image_trs = nn.Linear(feats.shape[1], embedding_dim)

    def forward(self, training=1):
        """Return ``(user_embed, item_embed)``, each of width ``2 * embedding_dim``.

        ``user_embed`` is ``[user ID embedding | user visual preference]``.
        ``item_embed`` is ``[item ID embedding | projected image features]``.
        ``training`` is accepted for interface compatibility and ignored.
        """
        projected_image = self.image_trs(self.image_feats)
        user_embed = torch.cat([self.user_embedding.weight, self.user_image_preference.weight], dim=-1)
        item_embed = torch.cat([self.item_embedding.weight, projected_image], dim=-1)
        return user_embed, item_embed


class DeepStyle(nn.Module):
    """DeepStyle-style model: item = ID embedding + (projected image - categorical embedding).

    Each item's projected image features minus a learned per-item "categorical"
    embedding gives a style component. That component is added to the item's
    latent ID embedding. Users are represented by their ID embedding alone.

    Args:
        n_users: Number of users.
        n_items: Number of items.
        embedding_dim: Width of every embedding table and of the image projection.
        weight_size: Accepted for interface compatibility; unused here.
        dropout_list: Accepted for interface compatibility; unused here.
        image_feats: 2-D array-like of image features, one row per item
            (combined element-wise with the per-item embedding tables).
        adj: Accepted for interface compatibility; unused here.
        edge_index: Accepted for interface compatibility; unused here.

    Raises:
        ValueError: If ``image_feats`` is missing or is not 2-D.
    """

    def __init__(self, n_users, n_items, embedding_dim, weight_size, dropout_list, image_feats=None, adj=None, edge_index=None):
        super().__init__()
        self.n_users = n_users
        self.n_items = n_items
        self.embedding_dim = embedding_dim
        self.user_embedding = nn.Embedding(n_users, embedding_dim)
        self.item_embedding = nn.Embedding(n_items, embedding_dim)
        self.item_image_categorical_embedding = nn.Embedding(n_items, embedding_dim)

        if image_feats is None:
            raise ValueError("DeepStyle requires image_feats (one feature row per item)")
        feats = torch.Tensor(image_feats)
        if feats.dim() != 2:
            raise ValueError(f"image_feats must be 2-D, got shape {tuple(feats.shape)}")
        # Buffer instead of a hard-coded .cuda(): moves with the module, and
        # persistent=False keeps state_dict identical to the previous version.
        self.register_buffer("image_feats", feats, persistent=False)
        self.image_trs = nn.Linear(feats.shape[1], embedding_dim)

        # Kept after the Linear layer, as before, so seeded runs consume the RNG
        # in the same order.
        nn.init.xavier_uniform_(self.user_embedding.weight)
        nn.init.xavier_uniform_(self.item_embedding.weight)
        nn.init.xavier_uniform_(self.item_image_categorical_embedding.weight)

    def forward(self, training=1):
        """Return ``(user_embed, item_embed)``, both of width ``embedding_dim``.

        ``training`` is accepted for interface compatibility and ignored.
        """
        image_categorical_embedding = self.item_image_categorical_embedding.weight
        projected_image = self.image_trs(self.image_feats)
        item_image_style_embed = projected_image - image_categorical_embedding
        user_embed = self.user_embedding.weight
        item_embed = self.item_embedding.weight + item_image_style_embed
        return user_embed, item_embed


class GCN(nn.Module):
    """Three-layer graph convolution over a joint (users ++ item rows) node set.

    Node inputs are a learnable per-user preference table stacked on top of the
    feature rows passed to ``forward``. After row-wise L2 normalisation, each of
    the three layers computes::

        h     = leaky_relu(adj @ conv(x))           # neighbourhood aggregation
        x_hat = leaky_relu(linear(x)) + id_embedding
        x     = leaky_relu(gate(h) + x_hat)

    Args:
        n_users: Number of users (rows of the preference table).
        adj: Left operand of ``torch.mm``. Its size must equal ``n_users`` plus
            the number of feature rows given to ``forward``, so the shapes line up
            with ``x``. NOTE(review): may be dense or sparse — confirm against the
            caller.
        embedding_dim: Width of node representations in every layer.
    """

    def __init__(self, n_users, adj, embedding_dim):
        super().__init__()
        self.preference = nn.Embedding(n_users, embedding_dim)
        nn.init.xavier_uniform_(self.preference.weight)

        self.conv_embed_1 = nn.Linear(embedding_dim, embedding_dim)
        self.linear_layer1 = nn.Linear(embedding_dim, embedding_dim)
        self.g_layer1 = nn.Linear(embedding_dim, embedding_dim)

        self.conv_embed_2 = nn.Linear(embedding_dim, embedding_dim)
        self.linear_layer2 = nn.Linear(embedding_dim, embedding_dim)
        self.g_layer2 = nn.Linear(embedding_dim, embedding_dim)

        self.conv_embed_3 = nn.Linear(embedding_dim, embedding_dim)
        self.linear_layer3 = nn.Linear(embedding_dim, embedding_dim)
        self.g_layer3 = nn.Linear(embedding_dim, embedding_dim)

        # A non-persistent buffer follows model.to()/.cuda()/.cpu() with the
        # parameters but stays out of state_dict (unchanged checkpoint format).
        self.register_buffer("adj", adj, persistent=False)

    def forward(self, features, id_embedding):
        """Run the three graph layers and return the final node representations.

        Args:
            features: Item-side node features with ``embedding_dim`` columns,
                stacked below the user preference rows.
            id_embedding: Added in every layer. Must broadcast against
                ``(n_users + features.shape[0], embedding_dim)``.

        Returns:
            Tensor of shape ``(n_users + features.shape[0], embedding_dim)``.
        """
        # F.normalize defaults to p=2 over dim=1, i.e. per-node L2 normalisation.
        # No .cuda() here: x already lives on the module's device, and forcing
        # the default CUDA device broke CPU runs and non-default-GPU placement.
        x = F.normalize(torch.cat((self.preference.weight, features), dim=0))

        layers = (
            (self.conv_embed_1, self.linear_layer1, self.g_layer1),
            (self.conv_embed_2, self.linear_layer2, self.g_layer2),
            (self.conv_embed_3, self.linear_layer3, self.g_layer3),
        )
        for conv, linear, gate in layers:
            h = F.leaky_relu(torch.mm(self.adj, conv(x)))
            x_hat = F.leaky_relu(linear(x)) + id_embedding
            x = F.leaky_relu(gate(h) + x_hat)

        return x

class MMGCN(nn.Module):
    """Visual-only MMGCN: a ``GCN`` over users and items driven by projected image features.

    Image features are linearly projected to ``embedding_dim``. They are fed to a
    ``GCN`` together with the concatenated user/item ID embeddings. The GCN
    output is then split back into user rows and item rows.

    Args:
        n_users: Number of users.
        n_items: Number of items.
        embedding_dim: Width of ID embeddings, the image projection and the GCN.
        weight_size: Accepted for interface compatibility; unused here.
        dropout_list: Accepted for interface compatibility; unused here.
        image_feats: 2-D array-like of image features, one row per item.
        adj: Adjacency passed straight to ``GCN``.
        edge_index: Accepted for interface compatibility; unused here.

    Raises:
        ValueError: If ``image_feats`` is missing or is not 2-D.
    """

    def __init__(self, n_users, n_items, embedding_dim, weight_size, dropout_list, image_feats=None, adj=None, edge_index=None):
        super().__init__()
        self.n_users = n_users
        self.n_items = n_items
        self.embedding_dim = embedding_dim
        self.user_embedding = nn.Embedding(n_users, embedding_dim)
        self.item_embedding = nn.Embedding(n_items, embedding_dim)

        nn.init.xavier_uniform_(self.user_embedding.weight)
        nn.init.xavier_uniform_(self.item_embedding.weight)

        if image_feats is None:
            raise ValueError("MMGCN requires image_feats (one feature row per item)")
        feats = torch.Tensor(image_feats)
        if feats.dim() != 2:
            raise ValueError(f"image_feats must be 2-D, got shape {tuple(feats.shape)}")
        # Buffer instead of a hard-coded .cuda(): moves with the module, and
        # persistent=False keeps state_dict identical to the previous version.
        self.register_buffer("image_feats", feats, persistent=False)

        self.image_trs = nn.Linear(feats.shape[1], embedding_dim)

        self.v_gcn = GCN(n_users, adj, embedding_dim)

    def forward(self, training=1):
        """Return ``(user_rep, item_rep)``: the GCN output split at row ``n_users``.

        ``training`` is accepted for interface compatibility and ignored.
        """
        projected_image = self.image_trs(self.image_feats)
        id_embedding = torch.cat([self.user_embedding.weight, self.item_embedding.weight], dim=0)
        rep = self.v_gcn(projected_image, id_embedding)
        return rep[:self.n_users], rep[self.n_users:]


