import logging
import copy
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.float_embedding import FloatEmbedding

# Inversion settings are read once, at import time, from 'inversion_config.json'
# resolved against the current working directory. Importing this module raises
# if the file is missing (FileNotFoundError) or lacks 'target_label' (KeyError).
with open('inversion_config.json', 'r') as f:
    inv_config = json.load(f)

# Target value that FRJVE_Model's 'test_source_free' stage measures each
# target-domain user's score against (may also be used by other modules).
target_label = inv_config['target_label']

class LookupEmbedding(torch.nn.Module):
    """Plain user/item lookup tables for integer ``[uid, iid]`` index pairs."""

    def __init__(self, uid_all, iid_all, emb_dim):
        super().__init__()
        # One row of width ``emb_dim`` per user id and per item id.
        self.uid_embedding = torch.nn.Embedding(uid_all, emb_dim)
        self.iid_embedding = torch.nn.Embedding(iid_all, emb_dim)

    def forward(self, x):
        """Embed column 0 of ``x`` as user ids and column 1 as item ids.

        Returns:
            A ``(uid_emb, iid_emb)`` tuple, one embedding row per row of ``x``.
        """
        user_ids, item_ids = x[:, 0], x[:, 1]
        return self.uid_embedding(user_ids), self.iid_embedding(item_ids)

class FloatLookupEmbedding(torch.nn.Module):
    """User lookup table paired with a ``FloatEmbedding`` item table.

    Integer inputs are treated as ``[uid, iid]`` index pairs. Float32 inputs
    carry a uid in column 0 and, in columns ``1:``, whatever ``FloatEmbedding``
    consumes (presumably a relaxed/soft item encoding used for inversion --
    TODO confirm against ``model.float_embedding``).
    """

    def __init__(self, uid_all, iid_all, emb_dim):
        super().__init__()
        self.uid_embedding = torch.nn.Embedding(uid_all, emb_dim)
        self.iid_embedding = FloatEmbedding(iid_all, emb_dim)

    def forward(self, x, stage=None):
        """Embed a batch of users and items.

        Args:
            x: integer tensor of ``[uid, iid]`` pairs, or a float32 tensor whose
                column 0 holds uids and whose columns ``1:`` feed
                ``FloatEmbedding``.
            stage: only consulted for float32 input. ``None`` returns the
                user/item embeddings concatenated along dim 1;
                ``'save_inversed_iid'`` returns the uid index concatenated to
                the summed item embedding along dim 2.

        Returns:
            The concatenated tensor described above.

        Raises:
            ValueError: float32 input with any other ``stage``.
        """
        if x.dtype == torch.float32:
            # Move the ids to wherever the user table lives, rather than
            # guessing "cuda if available" (which breaks a CPU-resident model
            # on a CUDA machine).
            uid_idx = x[:, 0].long().to(self.uid_embedding.weight.device)
            uid_emb = self.uid_embedding(uid_idx)
            iid_emb = self.iid_embedding(x[:, 1:])
            # Collapse the item-side dim 1 and re-insert it as a length-1 slot.
            iid_emb = torch.sum(iid_emb, dim=1).unsqueeze(1)
            # NOTE(review): these concatenations only line up if x[:, 0] keeps
            # a trailing dim, e.g. x shaped (B, 1 + N, 1) -- confirm with callers.
            if stage is None:
                return torch.cat([uid_emb, iid_emb], dim=1)
            # `==`, not `is`: identity only matches interned strings, so a
            # dynamically built stage name used to fall through to None.
            if stage == 'save_inversed_iid':
                return torch.cat([uid_idx.unsqueeze(2), iid_emb], dim=2)
            raise ValueError(f"unknown stage for float input: {stage!r}")
        # Integer input: unsqueeze so each lookup yields one (B, 1, D) slot.
        uid_emb = self.uid_embedding(x[:, 0].unsqueeze(1))
        iid_emb = self.iid_embedding(x[:, 1].unsqueeze(1))
        return torch.cat([uid_emb, iid_emb], dim=1)


class MFModel(torch.nn.Module):
    """Matrix-factorisation scorer over a ``LookupEmbedding``."""

    def __init__(self, uid_all, iid_all, emb_dim):
        super().__init__()
        self.embedding = LookupEmbedding(uid_all, iid_all, emb_dim)

    def forward(self, x):
        """Score each ``[uid, iid]`` row of ``x`` as the dot product of its embeddings.

        Returns:
            One score per row: ``uid_emb * iid_emb`` summed over dim 1.
        """
        user_vec, item_vec = self.embedding.forward(x)
        return (user_vec * item_vec).sum(dim=1)


class FRJVE_Model(torch.nn.Module):
    """Cross-domain recommender with source- and target-domain embedding towers.

    ``src_model`` / ``tgt_model`` hold the user and item tables of each domain,
    ``mapping`` projects source user embeddings into the target space, and
    ``rp_mapping`` projects the source-side vectors fed to the source-free
    stages. ``forward`` dispatches on ``stage``.
    """

    # TODO change
    def __init__(self, field_dims_src, field_dims_tgt, num_fields, emb_dim, topk):
        super().__init__()
        self.num_fields = num_fields  # stored; not read within this class
        self.emb_dim = emb_dim
        self.src_model = FloatLookupEmbedding(field_dims_src['uid_src'], field_dims_src['iid_src'], emb_dim)
        # TODO change
        self.tgt_model = FloatLookupEmbedding(field_dims_tgt['uid_tgt'], field_dims_tgt['iid_tgt'], emb_dim)
        # TODO change
        self.topk = topk  # number of target users that vote in 'test_source_free'
        self.rp_mapping = torch.nn.Linear(emb_dim, emb_dim, False)
        self.mapping = torch.nn.Linear(emb_dim, emb_dim, False)

    @staticmethod
    def _dot(emb):
        """Row-wise dot product of slot 0 (user) and slot 1 (item) of ``emb``."""
        return torch.sum(emb[:, 0, :] * emb[:, 1, :], dim=1)

    def _inversion_forward(self, tower, x, inversion_stage):
        """Shared body of the per-domain training/inversion stages.

        Returns the user-item score when ``inversion_stage`` is None, or the raw
        tower output for ``'save_inversed_iid'``.
        """
        if inversion_stage not in (None, 'save_inversed_iid'):
            raise ValueError(f"unknown inversion_stage: {inversion_stage!r}")
        emb = tower.forward(x, inversion_stage)
        if inversion_stage is None:
            return self._dot(emb)
        return emb

    @staticmethod
    def _lookup_users(table, ids):
        """Embed ``ids`` with ``table`` and flatten to ``(num_ids, emb_dim)``.

        Unlike a bare ``.squeeze()``, this keeps the batch dim when there is a
        single id and the feature dim when ``emb_dim == 1``.
        """
        emb = table(ids.unsqueeze(1))
        return emb.reshape(-1, emb.shape[-1])

    def _test_source_free(self, x, target):
        """Predict ratings for ``x = [iid | src_rate_pre]`` by top-k user voting.

        1. Project ``src_rate_pre`` into the target domain with ``rp_mapping``.
        2. Pick the ``topk`` target users whose score against the projected
           vector is closest to ``target``.
        3. Average those users' dot products with the item embedding.
        """
        device = self.tgt_model.uid_embedding.weight.device
        iid = x[:, 0].long().to(device)
        iid_emb = self.tgt_model.iid_embedding(iid.unsqueeze(1))
        src_rate_pre = self.rp_mapping.forward(x[:, 1:self.emb_dim + 1])
        # The scores are only used to choose indices, so a detached view of the
        # table gives the same result as the previous per-call deepcopy,
        # without copying the whole user table.
        user_table = self.tgt_model.uid_embedding.weight.detach()
        scores = torch.mm(user_table, src_rate_pre.T)  # (num_users, batch)
        target = torch.as_tensor(target, dtype=scores.dtype, device=scores.device)
        distance = torch.abs(scores - target)
        # k smallest distances per column; no full sort needed.
        _, nearest = torch.topk(distance, self.topk, dim=0, largest=False)  # (topk, batch)
        flat_ids = nearest.T.reshape(-1)  # grouped per sample
        topk_uid_emb = self.tgt_model.uid_embedding(flat_ids).view(len(x), self.topk, self.emb_dim)
        voted_rating = torch.sum(topk_uid_emb * iid_emb, dim=2)
        return torch.mean(voted_rating, dim=1)

    def forward(self, x, stage, inversion_stage=None):
        """Run the sub-network selected by ``stage``.

        Stages:
            'train_src' / 'src_inversion': source tower; see ``inversion_stage``.
            'tgt_inversion': target tower; see ``inversion_stage``.
            'train_tgt' / 'test_tgt': target-domain user-item score.
            'train_map': ``(mapping(src_user_emb), tgt_user_emb)`` for user ids ``x``.
            'test_map': target score using the mapped source user embedding.
            'train_source_free': ``x = [uid | src_rate_pre | tgt_rate_pre]``;
                returns ``(tgt_rate_pre, rp_mapping(src_rate_pre))``.
            'test_source_free': ``x = [iid | src_rate_pre]``; top-k voted rating
                against the module-level ``target_label``.

        Raises:
            ValueError: unknown ``stage`` or ``inversion_stage``.
        """
        # TODO change
        if stage in ('train_src', 'src_inversion'):
            return self._inversion_forward(self.src_model, x, inversion_stage)
        # TODO change
        if stage == 'tgt_inversion':
            return self._inversion_forward(self.tgt_model, x, inversion_stage)
        if stage in ('train_tgt', 'test_tgt'):
            return self._dot(self.tgt_model.forward(x))
        if stage == 'train_map':
            src_emb = self.mapping.forward(self._lookup_users(self.src_model.uid_embedding, x))
            tgt_emb = self._lookup_users(self.tgt_model.uid_embedding, x)
            return src_emb, tgt_emb
        if stage == 'test_map':
            uid_emb = self.mapping.forward(self._lookup_users(self.src_model.uid_embedding, x[:, 0]))
            emb = self.tgt_model.forward(x)
            emb[:, 0, :] = uid_emb  # replace the target user slot
            return self._dot(emb)
        if stage == 'train_source_free':
            # input x organized with [uid, src_rate_pre, tgt_rate_pre]
            src_rate_pre = self.rp_mapping.forward(x[:, 1:self.emb_dim + 1])
            tgt_rate_pre = x[:, self.emb_dim + 1:]
            return tgt_rate_pre, src_rate_pre
        if stage == 'test_source_free':
            return self._test_source_free(x, target_label)
        raise ValueError(f"unknown stage: {stage!r}")
