import copy
import random
from functools import wraps

import torch
from torch import nn
import torch.nn.functional as F

#from kornia import augmentation as augs
#from kornia import filters, color
from utils import drop_adj, drop_feature
from graph.argparser import args
# helper functions

def default(val, def_val):
    """Return ``val`` unless it is ``None``; in that case return ``def_val``.

    Only ``None`` triggers the fallback -- other falsy values such as ``0``
    or ``""`` are returned unchanged.
    """
    if val is not None:
        return val
    return def_val

def flatten(t):
    """Collapse every axis after the first into a single axis.

    The size of the leading axis is kept and the remaining size is inferred
    by ``reshape`` (``-1``).
    """
    leading = t.shape[0]
    return t.reshape(leading, -1)

def singleton(cache_key):
    """Build a method decorator that caches the result on the instance.

    The decorated method runs at most once per "empty" cache: its return
    value is stored on ``self`` under the attribute name ``cache_key`` and
    returned directly on later calls. Resetting that attribute to ``None``
    makes the next call recompute the value.

    Args:
        cache_key: Name of the instance attribute used as the cache slot.

    Returns:
        A decorator for instance methods.
    """
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            # A missing attribute is treated like an empty cache, so the
            # owner does not have to pre-initialise the slot to ``None``
            # (and a ``del`` of the slot does not break later calls).
            instance = getattr(self, cache_key, None)
            if instance is not None:
                return instance

            instance = fn(self, *args, **kwargs)
            setattr(self, cache_key, instance)
            return instance
        return wrapper
    return inner_fn

# exponential moving average

class EMA:
    """Exponential-moving-average helper parameterised by a decay ``beta``."""

    def __init__(self, beta):
        super().__init__()
        # Weight given to the previous average; (1 - beta) goes to the new value.
        self.beta = beta

    def update_average(self, old, new):
        """Blend ``old`` and ``new`` as ``old * beta + (1 - beta) * new``.

        When there is no previous average (``old is None``) the new value
        is returned as-is.
        """
        if old is not None:
            decay = self.beta
            return old * decay + (1 - decay) * new
        return new

def update_moving_average(ema_updater, ma_model, current_model):
    """Move the parameters of ``ma_model`` towards those of ``current_model``.

    The two models' ``parameters()`` are walked in lockstep; each parameter
    of ``ma_model`` has its ``.data`` replaced by
    ``ema_updater.update_average(old, new)``. Only what ``parameters()``
    yields is touched, and ``current_model`` is left unmodified.
    """
    paired = zip(current_model.parameters(), ma_model.parameters())
    for source_param, averaged_param in paired:
        averaged_param.data = ema_updater.update_average(
            averaged_param.data, source_param.data
        )

# MLP class for projector and predictor

class MLP(nn.Module):
    """Two-layer perceptron: Linear -> PReLU -> Linear.

    Args:
        dim: Size of the input features.
        hidden_size: Width of the hidden layer.
        projection_size: Size of the output features.
    """

    def __init__(self, dim, hidden_size, projection_size):
        super().__init__()
        # TODO: a BatchNorm1d(hidden_size) after the first Linear is still
        # pending; a plain ReLU was previously tried in place of PReLU.
        layers = [
            nn.Linear(dim, hidden_size),
            nn.PReLU(),
            nn.Linear(hidden_size, projection_size),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the result."""
        out = self.net(x)
        return out

# main class
class IGSD(nn.Module):
    """Online/target self-distillation wrapper around a graph encoder.

    An *online* encoder is trained by gradient descent while a *target*
    encoder -- a deep copy of the online one, created lazily on first use --
    is only changed through :meth:`update_moving_average`. Two views of each
    input are encoded by both networks. Online outputs go through
    ``online_projector`` and then ``predictor``; target outputs go through
    ``online_projector`` only, under ``torch.no_grad()``.

    The encoder is called as ``encoder(adj, feat)`` and must return a
    2-tuple; only the first element is used here.

    The two views are selected by the module-level ``args.aug_type``:

    * ``'diff'``: view one is ``(adj, feat)`` and view two is ``(diff, feat)``.
    * anything else: both views are independent random augmentations of
      ``(adj, feat)`` produced by :meth:`graph_aug`.

    Args:
        net: The encoder module used as the online encoder.
        graph_size: Unused; kept for interface compatibility.
        hidden_layer: Unused; kept for interface compatibility.
        projection_size: Input and hidden width of the projector MLP
            (must match the encoder's output feature size).
        projection_hidden_size: Output width of the projector MLP and
            input/output width of the predictor MLP.
        augment_fn: Unused; kept for interface compatibility.
        moving_average_decay: Decay ``beta`` of the target-encoder EMA.
    """

    def __init__(self, net, graph_size, hidden_layer = -2, projection_size = 256, projection_hidden_size = 4096, augment_fn = None, moving_average_decay = 0.99):
        super().__init__()
        self.online_encoder = net
        # Created lazily by _get_target_encoder(); None means "not built yet".
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)
        # Note the MLP argument order is (dim, hidden_size, projection_size).
        self.online_projector = MLP(projection_size, projection_size, projection_hidden_size)
        predict_size = args.predict_size
        self.predictor = MLP(projection_hidden_size, predict_size, projection_hidden_size)

    # adj - [B,N,N], feat - [B,N,feat_dim]
    def graph_aug(self, adj, feat):
        """Return a randomly augmented ``(adj, feat)`` pair.

        Both ``drop_adj`` and ``drop_feature`` are driven by
        ``args.drop_prob``.
        """
        return drop_adj(adj, args.drop_prob), drop_feature(feat, args.drop_prob)

    @singleton('target_encoder')
    def _get_target_encoder(self):
        """Return the target encoder, deep-copying the online one on first use."""
        target_encoder = copy.deepcopy(self.online_encoder)
        return target_encoder

    def reset_moving_average(self):
        """Discard the target encoder so the next use re-copies the online one."""
        del self.target_encoder
        self.target_encoder = None

    def update_moving_average(self):
        """Apply one EMA step from the online encoder to the target encoder.

        Raises:
            AssertionError: If the target encoder has not been created yet.
        """
        assert self.target_encoder is not None, 'target encoder has not been created yet'
        # Resolves to the module-level function, not this method.
        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

    def loss_fn(self, x, y):
        """Return ``2 - 2 * cos(x, y)`` computed along the last dimension.

        Both inputs are L2-normalised along the last dimension first; the
        usual broadcasting rules apply to the remaining dimensions.
        """
        x = F.normalize(x, dim=-1, p=2)
        y = F.normalize(y, dim=-1, p=2)
        return 2 - 2 * (x * y).sum(dim=-1)

    def _paired_predictions(self, adj, feat, diff=None):
        """Encode two views with the online and target networks.

        Shared by :meth:`forward` and :meth:`neg_loss` so that both support
        every ``args.aug_type`` identically.

        Returns:
            ``(online_pred_one, online_pred_two, target_pred_one,
            target_pred_two)``. The target tensors are produced under
            ``torch.no_grad()``.
        """
        if args.aug_type == 'diff':
            # NOTE(review): ``diff`` is passed to the encoder as given; the
            # caller is expected to supply it in this mode.
            adj_one, feat_one = adj, feat
            adj_two, feat_two = diff, feat
        else:
            # Two independent random augmentations of the same input.
            adj_one, feat_one = self.graph_aug(adj, feat)
            adj_two, feat_two = self.graph_aug(adj, feat)

        online_proj_one, _ = self.online_encoder(adj_one, feat_one)
        online_proj_two, _ = self.online_encoder(adj_two, feat_two)

        online_pred_one = self.online_projector(online_proj_one)
        online_pred_two = self.online_projector(online_proj_two)

        online_pred_one = self.predictor(online_pred_one)
        online_pred_two = self.predictor(online_pred_two)

        with torch.no_grad():
            target_encoder = self._get_target_encoder()
            target_proj_one, _ = target_encoder(adj_one, feat_one)
            target_proj_two, _ = target_encoder(adj_two, feat_two)

            # The target branch reuses the online projector (no predictor).
            target_pred_one = self.online_projector(target_proj_one)
            target_pred_two = self.online_projector(target_proj_two)

        return online_pred_one, online_pred_two, target_pred_one, target_pred_two

    def forward(self, adj, feat, diff=None, mask=None):
        """Return the symmetric consistency loss between the two views.

        Each online prediction is compared, via :meth:`loss_fn`, with the
        target projection of the *other* view; the two terms are summed and
        averaged over all remaining dimensions.

        Args:
            adj: Adjacency input forwarded to the encoder.
            feat: Feature input forwarded to the encoder.
            diff: Second adjacency-like input, used as view two when
                ``args.aug_type == 'diff'``; ignored otherwise.
            mask: Unused.

        Returns:
            A scalar tensor.
        """
        online_pred_one, online_pred_two, target_pred_one, target_pred_two = \
            self._paired_predictions(adj, feat, diff)

        loss_one = self.loss_fn(online_pred_one, target_pred_two.detach())
        loss_two = self.loss_fn(online_pred_two, target_pred_one.detach())

        loss = loss_one + loss_two
        return loss.mean()

    def neg_loss(self, adj, feat, diff=None, mask=None):
        """Return a log-sum-exp style loss over cross-sample pairs.

        Every online prediction is compared with every target projection of
        the other view by broadcasting (``unsqueeze(0)`` against
        ``unsqueeze(1)``); the pairwise :meth:`loss_fn` values are divided
        by ``args.alpha``, exponentiated, summed over the last dimension and
        the two directions are added before taking the log.

        If ``args.neg_prob < 1`` the resulting entries are randomly
        sub-sampled: an entry is kept when a uniform draw is
        ``>= args.neg_prob``.

        Args:
            adj: Adjacency input forwarded to the encoder.
            feat: Feature input forwarded to the encoder.
            diff: Second adjacency-like input, used as view two when
                ``args.aug_type == 'diff'``; ignored otherwise.
            mask: Unused.

        Returns:
            A scalar tensor. When sub-sampling removes every entry, a zero
            that is still attached to the autograd graph is returned
            instead of NaN.
        """
        online_pred_one, online_pred_two, target_pred_one, target_pred_two = \
            self._paired_predictions(adj, feat, diff)

        loss1 = (self.loss_fn(online_pred_one.unsqueeze(dim=0), target_pred_two.unsqueeze(dim=1)) / args.alpha).exp().sum(-1)
        loss2 = (self.loss_fn(online_pred_two.unsqueeze(dim=0), target_pred_one.unsqueeze(dim=1)) / args.alpha).exp().sum(-1)

        loss = (loss1 + loss2).log()

        if args.neg_prob < 1:
            # Allocate on the loss tensor's own device so the mask can never
            # end up on a different device than the data.
            prob = torch.rand(loss.shape, device=loss.device)
            loss = torch.masked_select(loss, prob.ge(args.neg_prob))

        if loss.numel() == 0:
            # mean() of an empty tensor is NaN; return a graph-attached zero.
            return loss.sum()

        # Zero the diagonal before averaging.
        # NOTE(review): presumably meant to drop the same-sample (positive)
        # pairs; whether the diagonal lines up with them depends on the
        # encoder's output rank and on the sub-sampling above -- confirm.
        identity = torch.eye(loss.shape[0], device=loss.device)
        return (loss - identity * loss).mean()

    def embed(self, adj, diff, feat, mask=None):
        """Return a detached embedding combining both views.

        The online encoder is run on ``(adj, feat)`` and ``(diff, feat)``;
        each output is summed over dimension 1 and the two sums are added.
        Projector and predictor are not applied.

        NOTE: the positional order here is ``(adj, diff, feat)``, unlike
        ``forward``/``neg_loss`` which take ``(adj, feat, diff)``.

        Args:
            adj: Adjacency input forwarded to the encoder.
            diff: Second adjacency-like input forwarded to the encoder.
            feat: Feature input forwarded to the encoder.
            mask: Unused.
        """
        online_l_one, _ = self.online_encoder(adj, feat)
        online_l_two, _ = self.online_encoder(diff, feat)

        online_proj_one = online_l_one.sum(1)
        online_proj_two = online_l_two.sum(1)

        return (online_proj_one + online_proj_two).detach()