import torch
from torch import nn
import higher
import numpy as np
import dgl
from dgl.nn.pytorch import GraphConv

from nonneg_sgd import *
from util import *
from flow import *
from gcn import *
from evaluation import *
from tikhonov import *
from bilevel_mlp import *

def get_test_feat_ids_and_priors(G, train, edge_map, priors, index):
    '''
        Extracts feature representation for GNNLearnFlow

        For every edge of G that is not a key/member of train, stores
        edge_map[e] and priors[e] at row index[e] of the outputs.

        G: graph exposing number_of_edges() and edges()
        train: container of edges with known flows (only membership
            and len() are used)
        edge_map: maps an edge to its row in the regularizer output
        priors: maps an edge to its prior flow value
        index: maps each non-train edge to its row in the outputs

        Returns (feat_ids, prior_flows): a long tensor of shape
        (n_test,) and a float64 tensor of shape (n_test, 1), where
        n_test = G.number_of_edges() - len(train). Both are created on
        'cuda:0' when CUDA is available.
    '''
    n_test = G.number_of_edges() - len(train)

    # int64 (not int16): edge ids above 32767 must survive until the
    # conversion to torch.long below.
    feat_ids = np.zeros(n_test, dtype=np.int64)
    prior_flows = np.zeros((n_test, 1))

    for e in G.edges():
        if e not in train:
            i = index[e]
            feat_ids[i] = edge_map[e]
            prior_flows[i, 0] = priors[e]

    # device=None lets torch place the tensors on its default (CPU) device.
    device = 'cuda:0' if torch.cuda.is_available() else None

    return (torch.tensor(feat_ids, device=device, dtype=torch.long),
            torch.tensor(prior_flows, device=device))
        
class GNNLearnFlow(nn.Module):
    '''
        Learns a GNN whose output is the optimal regularizer
        for a flow estimation problem using bilevel optimization (higher library)

        The outer problem (train) fits the parameters of net; the inner
        problem (forward) runs differentiable SGD steps on a Tikhonov
        loss whose regularization weights come from net.

        NOTE: train() below has a different signature than
        nn.Module.train(mode). torch's nn.Module.eval() delegates to
        self.train(False), so eval()/train(False) cannot be used on
        instances of this class.
    '''
    def __init__(self, G, dgl_G, priors, lamb, net, n_folds, edge_map, inner_n_iter_train, inner_n_iter_pred, outer_n_iter, inner_lr, outer_lr, nonneg=False, early_stop=10):
        '''
            G: graph passed to the flow helpers (lsq_matrix_flow, ...)
            dgl_G: DGL graph fed to net; must carry ndata['feat']
            priors: maps an edge to its prior flow
            lamb: scalar multiplying the output of net
            net: module producing the regularization weights
            n_folds: number of folds passed to generate_folds
            edge_map: maps an edge to its row in the output of net
            inner_n_iter_train: max. inner iterations during training
            inner_n_iter_pred: inner iterations of the final Tikhonov fit
            outer_n_iter: max. outer (Adam) epochs
            inner_lr, outer_lr: inner / outer learning rates
            nonneg: use NONNEGSGD instead of SGD in the inner problem
            early_stop: window size of both early-stopping rules
        '''
        super(GNNLearnFlow, self).__init__()

        self.early_stop = early_stop
        self.inner_n_iter_train = inner_n_iter_train
        self.inner_n_iter_pred = inner_n_iter_pred
        self.outer_n_iter = outer_n_iter
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr

        self.G = G
        self.dgl_G = dgl_G
        self.nonneg = nonneg
        self.net = net
        self.edge_map = edge_map
        self.n_folds = n_folds
        self.priors = priors
        self.lamb = lamb

        self.use_cuda = torch.cuda.is_available()

        if self.use_cuda:
            # net is already registered as a submodule, so it moves too.
            self.cuda()

    def forward(self, A, b, x_init, reg_vec, x_prior, mapp, int_test_flows, verbose=False):
        '''
            Inner problem

            Starting from a detached copy of x_init, takes up to
            inner_n_iter_train differentiable optimizer steps on the
            Tikhonov loss and evaluates the MSE between mapp @ x and
            int_test_flows after every step.

            Returns (x, valid_loss): the last iterate and its
            validation loss. Both stay attached to the autograd graph
            of reg_vec so the outer problem can backpropagate.

            Raises ValueError if inner_n_iter_train < 1, since no
            validation loss could be produced.
        '''
        if self.inner_n_iter_train < 1:
            raise ValueError("inner_n_iter_train must be at least 1 to solve the inner problem")

        x = x_init.clone().detach().requires_grad_(True)

        base_opt_cls = NONNEGSGD if self.nonneg else torch.optim.SGD
        # device=None is get_diff_optim's default, i.e. the CPU path.
        device = 'cuda:0' if self.use_cuda else None
        inner_opt = higher.get_diff_optim(base_opt_cls([x], lr=self.inner_lr), [x], device=device)

        # Only used as a loss here; the iterations / learning rate
        # arguments are passed as 0 because stepping is done by inner_opt.
        self.tk = Tikhonov(A, b, 0, 0., self.nonneg)
        loss_func = nn.MSELoss()
        losses = []

        for epoch in range(self.inner_n_iter_train):
            tik_loss = self.tk(x, reg_vec, x_prior)  #tikhonov loss
            x, = inner_opt.step(tik_loss, params=[x])
            valid_loss = loss_func(torch.sparse.mm(mapp, x), int_test_flows) #validation loss
            losses.append(valid_loss.item())

            if verbose and epoch % 100 == 0:
                print("epoch: ", epoch," inner loss = ", valid_loss.item())

            # Stop once the newest loss exceeds the mean of the
            # early_stop losses that precede it.
            if epoch > self.early_stop and losses[-1] > np.mean(losses[-(self.early_stop+1):-1]):
                break

        return x, valid_loss

    def train(self, train_flows, valid_flows, verbose=False):
        '''
            Outer problem

            Merges train_flows and valid_flows, splits them into
            n_folds folds and, for up to outer_n_iter epochs, takes one
            Adam step on the parameters of net per fold using the
            validation loss returned by the inner problem. The inner
            solution of each fold is kept as a warm start for the next
            epoch.

            Afterwards fits a Tikhonov model on all known flows with
            the learned regularizer. Results are stored in self.tk and
            self.index; nothing is returned.
        '''
        device = 'cuda:0' if self.use_cuda else None

        int_folds = generate_folds({**train_flows, **valid_flows}, self.n_folds) #No extra validation

        # torch is imported by this module; do not rely on a name
        # leaked by one of the star imports.
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.outer_lr)
        train_losses = []

        # One warm-start vector per fold, one row per edge outside the
        # fold's training set.
        n_edges = self.G.number_of_edges()
        X = [torch.zeros((n_edges - len(int_train), 1), dtype=torch.float, device=device)
             for (int_train, _) in int_folds]

        for epoch in range(self.outer_n_iter):
            train_loss = torch.zeros(1, device=device)

            for f, (int_train, int_test) in enumerate(int_folds):
                A, b, index = lsq_matrix_flow(self.G, int_train)
                int_test_flows, mapp = get_fold_flow_data(self.G, int_train, int_test)
                feat_ids, prior_flows = get_test_feat_ids_and_priors(self.G, int_train, self.edge_map, self.priors, index)

                # An optimizer step is taken per fold, so gradients are
                # cleared per fold as well; otherwise each step would
                # also apply the stale gradients of the previous folds.
                self.optimizer.zero_grad()

                #GCN
                all_reg = self.lamb * self.net.forward(self.dgl_G, self.dgl_G.ndata['feat'])
                reg_vec = all_reg[feat_ids]
                x, loss = self(A, b, X[f], reg_vec, prior_flows, mapp, int_test_flows, verbose)  #forward
                X[f] = x.clone().detach()

                # The running total is only logged / compared, so keep
                # it out of the autograd graph.
                train_loss = train_loss + loss.detach()

                loss.backward()
                self.optimizer.step()

            train_losses.append(train_loss.item())

            print("epoch: ", epoch, " outer train loss = ", train_loss.item())

            if epoch > self.early_stop and train_losses[-1] > np.mean(train_losses[-(self.early_stop+1):-1]):
                if verbose:
                    print("Early stopping...")
                break

        # Final fit on every known flow with the learned regularizer.
        known_flows = {**train_flows, **valid_flows}
        A, b, self.index = lsq_matrix_flow(self.G, known_flows)
        feat_ids, prior_flows = get_test_feat_ids_and_priors(self.G, known_flows, self.edge_map, self.priors, self.index)
        all_reg = self.lamb * self.net.forward(self.dgl_G, self.dgl_G.ndata['feat']).detach()
        reg_vec = all_reg[feat_ids]
        x_init = initialize_flows(A.shape[1])

        self.tk = Tikhonov(A, b, self.inner_n_iter_pred, self.inner_lr, self.nonneg)
        self.tk.train(reg_vec, x_init, prior_flows, verbose=False)
