import networkx as nx
import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import scipy
from scipy.optimize import leastsq
from scipy.sparse import *
import matplotlib.pyplot as plt
import dgl
from dgl.nn.pytorch import GraphConv
import higher
from torch.optim.sgd import SGD

from tikhonov import *
from util import *
from mlp import *
from gcn import *
from flow import *
from evaluation import *
from bilevel_mlp import *
from bilevel_gcn import *

class FlowEstimator:
    '''
        Generic class for flow estimation

        Holds the inputs shared by every estimator in this module; the
        concrete estimators override train() and predict().
    '''
    def __init__(self, G, features, params, verbose=False):
        '''
            Store the graph, the features, the hyper-parameter mapping
            and the verbosity flag on the instance without copying them.
        '''
        self.G, self.features = G, features
        self.params, self.verbose = params, verbose

    def train(self, train_flows, valid_flows):
        '''
            Fit the estimator. Must be provided by subclasses.
        '''
        raise NotImplementedError

    def predict(self, test_flows):
        '''
            Predict flows for test_flows. Must be provided by subclasses.
        '''
        raise NotImplementedError

class MinDiv(FlowEstimator):
    '''
        Divergence minimization

        Builds a least-squares system from the training flows, selects the
        regularization weight with `gss` using the validation flows, and
        fits a Tikhonov model whose solution is used for prediction.
    '''
    def __init__(self, G, features, params, verbose=False):
        super().__init__(G, features, params, verbose)

    def train(self, train_flows, valid_flows):
        '''
            Fit the Tikhonov model.

            Reads 'n_iter', 'lr', 'nonneg', 'min_lamb', 'max_lamb' and
            'priors' from self.params. Sets self.index (returned by
            lsq_matrix_flow) and self.tk (the trained Tikhonov object).

            train_flows: flows used to build the least-squares system
            valid_flows: flows handed to the lamb search
        '''
        n_iter = self.params['n_iter']
        lr = self.params['lr']
        nonneg = self.params['nonneg']
        min_lamb = self.params['min_lamb']
        max_lamb = self.params['max_lamb']
        priors = self.params['priors']

        A, b, self.index = lsq_matrix_flow(self.G, train_flows)

        # gss returns a second value alongside lamb (the original code bound
        # it to `loss`); it is not needed here, so it is discarded.
        search_args = [self.G, self.index, valid_flows, A, b, n_iter, lr, nonneg]
        lamb, _ = gss(fun_opt_lamb, search_args, min_lamb, max_lamb)

        print("lamb = ", lamb)

        # One entry per column of A: the same weight lamb for each of them.
        n_cols = A.shape[1]
        reg_vec = lamb * np.ones((n_cols, 1))
        x_init = initialize_flows(n_cols, zeros=True)
        x_prior = get_prior(self.G, priors, train_flows, self.index)

        self.tk = Tikhonov(A, b, n_iter, lr, nonneg)
        self.tk.train(reg_vec, x_init, x_prior, verbose=self.verbose)

    def predict(self, test_flows):
        '''
            Return the flows for test_flows, built by
            get_dict_flows_from_tensor from the trained solution self.tk.x
            and the index stored during train(). Requires train() to have
            been called first.
        '''
        return get_dict_flows_from_tensor(self.index, self.tk.x, test_flows)


class MLPFlowPred(FlowEstimator):
    '''
        Flow estimator using MLP

        Trains a `Net` on the feature tensors of the training flows and
        predicts by running the trained network on the test features.
    '''
    def __init__(self, G, features, params, verbose=False):
        super().__init__(G, features, params, verbose)

    def train(self, train_flows, valid_flows):
        '''
            Build the feature/target tensors and train self.net.

            Reads 'n_iter', 'lr', 'early_stop' and 'n_hidden' from
            self.params.
        '''
        n_iter, lr, early_stop, n_hidden = (
            self.params[key]
            for key in ('n_iter', 'lr', 'early_stop', 'n_hidden')
        )

        # Features first (train, valid), then targets (train, valid).
        train_x, valid_x = (
            get_tensor_features(self.G, self.features, flows)
            for flows in (train_flows, valid_flows)
        )
        train_y, valid_y = (
            get_tensor_flows(self.G, flows)
            for flows in (train_flows, valid_flows)
        )

        # Input width is taken from the first feature entry's leading dimension.
        first_key = list(self.features.keys())[0]
        n_features = self.features[first_key].shape[0]

        self.net = Net(n_features, n_hidden, n_iter, lr, early_stop, torch.nn.ReLU())
        self.net.train(train_x, train_y, valid_x, valid_y, verbose=self.verbose)

    def predict(self, test_flows):
        '''
            Run the trained network on the features of test_flows and
            convert the output with get_dict_flows.
        '''
        test_x = get_tensor_features(self.G, self.features, test_flows)
        return get_dict_flows(self.G, self.net.forward(test_x), test_flows)

class GCNFlowPred(FlowEstimator):
    '''
        Flow estimator using GCN

        Converts the graph to a DGL graph, trains a `GCN` on it and
        predicts by indexing the network output with the ids of the
        requested flows.
    '''
    def __init__(self, G, features, params, verbose=False):
        super().__init__(G, features, params, verbose)

    def train(self, train_flows, valid_flows):
        '''
            Build the DGL graph and train self.net.

            Reads 'n_iter', 'lr', 'early_stop', 'n_hidden' and 'nonneg'
            from self.params. Sets self.G_dgl, self.edge_map and self.net.
        '''
        n_iter, lr, early_stop, n_hidden, nonneg = (
            self.params[key]
            for key in ('n_iter', 'lr', 'early_stop', 'n_hidden', 'nonneg')
        )

        self.G_dgl, self.edge_map = create_dgl_graph(self.G, self.features, nonneg)
        lamb_max = dgl.laplacian_lambda_max(self.G_dgl)

        # Ids first (train, valid), then target tensors (train, valid).
        train_ids, valid_ids = (
            get_feat_ids(self.G, flows, self.edge_map)
            for flows in (train_flows, valid_flows)
        )
        train_y, valid_y = (
            get_tensor_flows(self.G, flows)
            for flows in (train_flows, valid_flows)
        )

        # Input width is taken from the first feature entry's leading dimension.
        first_key = list(self.features.keys())[0]
        n_features = self.features[first_key].shape[0]

        self.net = GCN(n_features, n_hidden, n_iter, lr, lamb_max, early_stop, torch.nn.ReLU())
        self.net.train(
            self.G_dgl,
            self.edge_map,
            train_ids,
            train_y,
            valid_ids,
            valid_y,
            verbose=self.verbose,
        )

    def predict(self, test_flows):
        '''
            Run the network over the whole DGL graph using the data stored
            under ndata['feat'], keep the rows selected by the ids of
            test_flows and convert them with get_dict_flows.
        '''
        test_ids = get_feat_ids(self.G, test_flows, self.edge_map)
        all_outputs = self.net.forward(self.G_dgl, self.G_dgl.ndata['feat'])
        return get_dict_flows(self.G, all_outputs[test_ids], test_flows)


class BilMLP(FlowEstimator):
    '''
        Flow estimator using bilevel optimization and MLP

        Wraps an `MLPLearnFlow` object built around a `Net`; predictions
        are read from that object's `tk.x` after training.
    '''
    def __init__(self, G, features, params, verbose=False):
        super().__init__(G, features, params, verbose)

    def train(self, train_flows, valid_flows):
        '''
            Build the network and the MLPLearnFlow object, then train it.

            Reads the inner/outer iteration counts and learning rates,
            'early_stop', 'nonneg', 'n_hidden', 'n_folds', 'priors' and
            'lambda' from self.params. Sets self.lf.
        '''
        (
            inner_n_iter_train,
            inner_n_iter_pred,
            outer_n_iter,
            inner_lr,
            outer_lr,
            early_stop,
            nonneg,
            n_hidden,
            n_folds,
            priors,
            lamb,
        ) = (
            self.params[key]
            for key in (
                'inner_n_iter_train',
                'inner_n_iter_pred',
                'outer_n_iter',
                'inner_lr',
                'outer_lr',
                'early_stop',
                'nonneg',
                'n_hidden',
                'n_folds',
                'priors',
                'lambda',
            )
        )

        # Input width is taken from the first feature entry's leading dimension.
        first_key = list(self.features.keys())[0]
        n_features = self.features[first_key].shape[0]

        # The network is configured with the outer-loop iteration count and rate.
        net = Net(n_features, n_hidden, outer_n_iter, outer_lr, early_stop, torch.nn.ReLU())

        self.lf = MLPLearnFlow(
            self.G,
            self.features,
            priors,
            lamb,
            net,
            n_folds,
            inner_n_iter_train,
            inner_n_iter_pred,
            outer_n_iter,
            inner_lr,
            outer_lr,
            nonneg=nonneg,
            early_stop=early_stop,
        )
        self.lf.train(train_flows, valid_flows, verbose=self.verbose)

    def predict(self, test_flows):
        '''
            Return the flows for test_flows, built by
            get_dict_flows_from_tensor from self.lf.index and self.lf.tk.x.
        '''
        learner = self.lf
        return get_dict_flows_from_tensor(learner.index, learner.tk.x, test_flows)

class BilGCN(FlowEstimator):
    '''
        Flow estimator combining bilevel optimization
        and GCN

        Wraps a `GNNLearnFlow` object built around a `GCN` on the DGL
        version of the graph; predictions are read from that object's
        `tk.x` after training.
    '''
    def __init__(self, G, features, params, verbose=False):
        super().__init__(G, features, params, verbose)

    def train(self, train_flows, valid_flows):
        '''
            Build the DGL graph, the GCN and the GNNLearnFlow object,
            then train it.

            Reads the inner/outer iteration counts and learning rates,
            'early_stop', 'nonneg', 'n_hidden', 'n_folds', 'priors' and
            'lambda' from self.params. Sets self.G_dgl, self.edge_map and
            self.gcn_lf.
        '''
        (
            inner_n_iter_train,
            inner_n_iter_pred,
            outer_n_iter,
            inner_lr,
            outer_lr,
            early_stop,
            nonneg,
            n_hidden,
            n_folds,
            priors,
            lamb,
        ) = (
            self.params[key]
            for key in (
                'inner_n_iter_train',
                'inner_n_iter_pred',
                'outer_n_iter',
                'inner_lr',
                'outer_lr',
                'early_stop',
                'nonneg',
                'n_hidden',
                'n_folds',
                'priors',
                'lambda',
            )
        )

        # Input width is taken from the first feature entry's leading dimension.
        first_key = list(self.features.keys())[0]
        n_features = self.features[first_key].shape[0]

        self.G_dgl, self.edge_map = create_dgl_graph(self.G, self.features, nonneg)
        lamb_max = dgl.laplacian_lambda_max(self.G_dgl)

        # The network is configured with the outer-loop iteration count and rate.
        gcn = GCN(n_features, n_hidden, outer_n_iter, outer_lr, lamb_max, early_stop, torch.nn.ReLU())

        self.gcn_lf = GNNLearnFlow(
            self.G,
            self.G_dgl,
            priors,
            lamb,
            gcn,
            n_folds,
            self.edge_map,
            inner_n_iter_train,
            inner_n_iter_pred,
            outer_n_iter,
            inner_lr,
            outer_lr,
            nonneg=nonneg,
            early_stop=early_stop,
        )
        self.gcn_lf.train(train_flows, valid_flows, verbose=self.verbose)

    def predict(self, test_flows):
        '''
            Return the flows for test_flows, built by
            get_dict_flows_from_tensor from self.gcn_lf.index and
            self.gcn_lf.tk.x.
        '''
        learner = self.gcn_lf
        return get_dict_flows_from_tensor(learner.index, learner.tk.x, test_flows)
