import sys
sys.path.append('../')
sys.path.append('../../')

import time
import dgl
import pickle as pkl
from dgl.nn import SAGEConv

from vqgraph.vq import VectorQuantize
from pathlib import Path
from sklearn.decomposition import PCA
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.onnx.symbolic_opset12 import dropout

import numpy as np
import scipy.sparse as sp
from sklearn.metrics.cluster import contingency_matrix

import gurobipy as gp
from gurobipy import GRB

# Default compute device for the module: the first CUDA GPU when one is visible, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def params(model):
    """
    Count the trainable parameters of a torch model.

    Only parameters with ``requires_grad`` set contribute; each contributes its
    element count. A model with no trainable parameters yields 0.
    """
    total = 0
    for tensor in model.parameters():
        if not tensor.requires_grad:
            continue
        total += tensor.numel()
    return total


def normalize(mx):
    """
    Row-normalize a (sparse) matrix so that every non-empty row sums to 1.

    Each row is divided by its sum; rows whose sum is zero are left as all
    zeros instead of producing inf/nan.

    Args:
        mx: scipy sparse matrix (or anything with ``.sum(1)`` that
            ``sp.diags(...).dot`` accepts). Integer and boolean matrices are
            accepted.

    Returns:
        ``diag(1 / rowsum) @ mx``, with zero-sum rows mapped to zero rows.
    """
    rowsum = np.asarray(mx.sum(1)).flatten()
    # Use a float exponent: NumPy refuses integer arrays raised to a negative
    # *integer* power, so `rowsum ** -1` crashed for integer matrices. A Python
    # float scalar does not upcast float32 inputs, so their dtype is preserved.
    # Zero rows intentionally produce inf (zeroed below), so silence that warning.
    with np.errstate(divide="ignore"):
        r_inv = np.power(rowsum, -1.0)
    r_inv[np.isinf(r_inv)] = 0.0
    return sp.diags(r_inv).dot(mx)


def normalize_adj(adj):
    """
    Add self-loops to an adjacency matrix and row-normalize the result.

    The identity is added to ``adj`` (so every node links to itself) before
    ``normalize`` divides each row by its sum.
    """
    with_self_loops = adj + sp.eye(adj.shape[0])
    return normalize(with_self_loops)


def get_dgl_graph(adj, feats, coeff_adj):
    """
    Build a DGL graph from an adjacency matrix, node features and edge coefficients.

    Called by generate_mip_graph().

    Args:
        adj: square scipy sparse matrix whose stored entries define the edges
            (row -> col). ``False`` is accepted as a "no graph" sentinel.
        feats: array-like node feature matrix with one row per node
            (``adj.shape[0]`` rows).
        coeff_adj: scipy sparse matrix with the same shape as ``adj``. Its
            entries at the edge positions become the edge weights, min-max
            scaled to roughly [0, 1] and shifted by 1e-4 so no weight is zero.

    Returns:
        A DGL graph with ``adj.shape[0]`` nodes, ``ndata["feat"]`` (float32,
        one row per node) and ``edata["weight"]`` (float32, shape
        ``(num_edges, 1)``), or ``False`` if ``adj`` is ``False``.

    Raises:
        ValueError: if the number of feature rows differs from the number of
            nodes in ``adj``.
    """
    if adj is False:
        return False

    adj_sp = adj.tocoo()
    num_nodes = adj_sp.shape[0]

    features = torch.as_tensor(np.array(feats), dtype=torch.float32)
    if features.shape[0] != num_nodes:
        raise ValueError(
            f"feature matrix has {features.shape[0]} rows but the adjacency has {num_nodes} nodes"
        )

    # Read the coefficients at the graph's own edge positions so weight i is
    # guaranteed to belong to edge i, even if coeff_adj stores a different
    # sparsity pattern (e.g. explicit zeros).
    edge_weights = np.asarray(coeff_adj[adj_sp.row, adj_sp.col], dtype=np.float64).reshape(-1, 1)
    if edge_weights.size:  # min()/max() are undefined on an edgeless graph
        w_min, w_max = edge_weights.min(), edge_weights.max()
        edge_weights = (edge_weights - w_min) / (w_max - w_min + 1e-6)
    edge_weights += 1e-4

    # Pass num_nodes explicitly: otherwise DGL infers it from the largest edge
    # endpoint, dropping isolated trailing nodes and breaking the ndata assignment.
    g = dgl.graph((adj_sp.row, adj_sp.col), num_nodes=num_nodes)
    g.ndata["feat"] = features
    g.edata["weight"] = torch.as_tensor(edge_weights, dtype=torch.float32)
    return g


def _mip_constraint_features(constraints):
    """Return a (len(constraints), 4) float array: [is '=', is '<', is '>', RHS] per constraint."""
    rows = []
    for c in constraints:
        sense = c.Sense
        rows.append([float(sense == '='), float(sense == '<'), float(sense == '>'), float(c.RHS)])
    # reshape keeps the 2-D shape even when there are no constraints
    return np.array(rows, dtype=float).reshape(len(constraints), 4)


def _mip_variable_features(variables):
    """Return a (len(variables), 6) float array per variable:
    [continuous, binary, integer, objective coefficient, finite lower bound, finite upper bound]."""
    rows = []
    for v in variables:
        rows.append([
            float(v.VType == gp.GRB.CONTINUOUS),
            float(v.VType == gp.GRB.BINARY),
            float(v.VType == gp.GRB.INTEGER),
            float(v.Obj),
            float(v.LB > -gp.GRB.INFINITY),
            # Strict '<' mirrors the LB test; '<=' was true for every variable.
            float(v.UB < gp.GRB.INFINITY),
        ])
    return np.array(rows, dtype=float).reshape(len(variables), 6)


def _extract_mip_data(mip_model):
    """Remove variables that appear in no constraint (mutates mip_model), then return
    (dense coefficient matrix, constraint feature array, variable feature array)."""
    to_remove = [v for v in mip_model.getVars() if not mip_model.getCol(v).size()]
    mip_model.remove(to_remove)
    mip_model.update()
    coefficient_matrix = mip_model.getA().todense()
    return (
        coefficient_matrix,
        _mip_constraint_features(mip_model.getConstrs()),
        _mip_variable_features(mip_model.getVars()),
    )


def generate_mip_graph(path, debug=False, graph_features=True, gurobi_object=False):
    """
    Build a bipartite constraint/variable DGL graph from a MIP.

    Nodes ``0 .. num_cons-1`` are the constraints and nodes
    ``num_cons .. num_cons+num_vars-1`` are the variables. Constraint i and
    variable j are linked (in both directions) whenever j has a non-zero
    coefficient in i; get_dgl_graph derives the edge weights from those
    coefficients.

    Args:
        path: an MPS/LP file path read with ``gp.read``, or, when
            ``gurobi_object is not False``, an already-built ``gurobipy.Model``.
        debug: print progress messages and timings.
        graph_features: also append the structural node features returned by
            ``get_sense_features`` (computed on an igraph copy of the graph).
        gurobi_object: anything other than ``False`` means ``path`` is a model
            object rather than a file path.

    Returns:
        ``(g, features, num_cons, num_vars)``. ``features`` is ``g.ndata["feat"]``:
        per node [4 constraint columns | 6 variable columns | structural
        features if requested], min-max normalised per column. If graph
        construction returns ``False``, ``(False, None, None, None)`` is returned.

    Side effects:
        Variables that appear in no constraint are removed from the model; when a
        model object is passed in, this modifies the caller's model. A model read
        from a file, and the Gurobi environment created for it, are disposed
        before returning.
    """
    if gurobi_object is False:
        env = gp.Env(empty=True)
        env.setParam("OutputFlag", 0)
        env.start()
        try:
            mip_model = gp.read(path, env)
            try:
                coefficient_matrix, constraints_feature_vector, var_feature_vector = _extract_mip_data(mip_model)
            finally:
                # Everything needed has been copied out; release the model and env (and license).
                mip_model.dispose()
        finally:
            env.dispose()
    else:
        coefficient_matrix, constraints_feature_vector, var_feature_vector = _extract_mip_data(path)

    num_cons, num_vars = coefficient_matrix.shape
    A = (coefficient_matrix != 0).astype(int)

    if debug:
        print("Creating Adjacency")
        s = time.time()
    # Create graph
    adj = np.block([[np.zeros((num_cons, num_cons)), A],
                    [A.T, np.zeros((num_vars, num_vars))]])
    coeff_adj = np.block([[np.zeros((num_cons, num_cons)), coefficient_matrix],
                          [coefficient_matrix.T, np.zeros((num_vars, num_vars))]])
    if debug:
        print("Graph Created in ", time.time() - s, "seconds")

    if graph_features:
        # NOTE(review): `igraph` is not imported and `get_sense_features` is not
        # defined in the visible part of this module; with the default
        # graph_features=True this raises NameError unless they are provided elsewhere.
        if debug:
            print("Creating IG")
            s = time.time()
        graph = igraph.Graph.Adjacency((adj > 0), mode='undirected')

        if debug:
            print("IG Created in ", time.time() - s, "seconds")

        # Graph based structural features
        if debug:
            print("Computing Sense Features")
            s = time.time()
        sf_names, sf = get_sense_features(graph)
        if debug:
            print("Sense Features Computed in ", time.time() - s, "seconds")

    if debug:
        print("Computing Feature Vectors")
        s = time.time()

    # Pad with zeros so constraint and variable rows share one column layout
    constraints_feature_matrix = np.hstack(
        [constraints_feature_vector, np.zeros((num_cons, var_feature_vector.shape[1]))])
    var_feature_matrix = np.hstack([np.zeros((num_vars, constraints_feature_vector.shape[1])), var_feature_vector])

    # Stack up into one feature matrix (constraints first, matching the node order)
    features = np.vstack([constraints_feature_matrix, var_feature_matrix])
    if graph_features:
        features = np.hstack([features, sf])

    # Column-wise min-max normalisation
    features = (features - np.min(features, axis=0)) / (np.max(features, axis=0) - np.min(features, axis=0) + 1e-9)
    features[np.isnan(features)] = 0
    if debug:
        print("Feature Vectors Computed in ", time.time() - s, "seconds")

    # Create DGL graph
    if debug:
        print("Creating DGL Graph")
        s = time.time()
    g = get_dgl_graph(sp.csr_matrix(adj), features, sp.csr_matrix(coeff_adj))
    if g is False:
        return g, None, None, None
    features = g.ndata["feat"]
    if debug:
        print("DGL Graph Created in ", time.time() - s, "seconds")

    return g, features, num_cons, num_vars


class Callback():
    """
    Gurobi callback that samples the relative gap between best bound and incumbent.

    Whenever Gurobi invokes it with ``where == GRB.Callback.MIP`` and at least
    0.1 s of wall-clock time (rounded to 0.1 s) has passed since the previous
    sample, it stores ``100 * |best_bound - best_obj| / |best_obj|`` in
    ``primal_gap``.

    Attributes:
        primal_gap: dict mapping elapsed seconds since construction (rounded to
            0.1 s) to the gap in percent.
        last_call: rounded wall-clock time of the last recorded sample.
        start_time: rounded wall-clock time at construction.
    """

    def __init__(self):
        self.primal_gap = {}
        now = np.round(time.time(), 1)
        self.last_call = now
        self.start_time = now

    @staticmethod
    def _mip_gap(best_bound, best_obj):
        """
        Return the gap in percent: ``100 * |best_bound - best_obj| / |best_obj|``.

        The ratio is undefined when ``best_obj`` is 0. In that case, return 0
        when the bound equals the incumbent and ``inf`` otherwise, instead of
        nan/inf with a division warning. This is intended to follow Gurobi's
        MIPGap convention.
        """
        diff = np.abs(best_bound - best_obj)
        denom = np.abs(best_obj)
        if denom == 0:
            return 0.0 if diff == 0 else np.inf
        return 100 * (diff / denom)

    def __call__(self, model, where):
        if where != GRB.Callback.MIP:
            return
        current_time = np.round(time.time(), 1)
        if current_time - self.last_call < 0.1:
            return
        best_obj = model.cbGet(GRB.Callback.MIP_OBJBST)
        best_bound = model.cbGet(GRB.Callback.MIP_OBJBND)
        time_key = np.round(current_time - self.start_time, 1)
        self.primal_gap[time_key] = self._mip_gap(best_bound, best_obj)
        self.last_call = current_time


def clean_pg_dict(in_dict, timesteps):
    """
    Resample gap recordings (such as ``Callback.primal_gap``) onto a shared integer time grid.

    Args:
        in_dict: iterable of dicts, one per run, mapping elapsed time (a
            non-negative number, truncated with ``int()`` into one-unit buckets)
            to a gap value.
        timesteps: number of buckets in the output grid (``0 .. timesteps-1``).

    Returns:
        A list with one dict per run, mapping every bucket ``0 .. timesteps-1``
        to the mean of the values recorded in it. Empty buckets are filled as
        follows:
          * after the last recorded bucket: 0 (the run is assumed finished);
          * before the first recorded bucket: that first bucket's mean (back-fill);
          * between recordings: the previous bucket's value (forward-fill).
        A run with no recordings gives 0 everywhere. Recordings at or beyond
        ``timesteps`` get no bucket of their own. They still count as evidence
        that the run was ongoing, so the grid is not zero-filled before them.

    Note:
        Values are not forced to be monotonically decreasing; the output follows
        whatever was recorded.
    """
    cleaned = []
    for run in in_dict:
        buckets = {}
        for t, gap in run.items():
            buckets.setdefault(int(t), []).append(gap)

        # Use min/max rather than first/last insertion so unordered dicts work too.
        first = min(buckets) if buckets else None
        last = max(buckets) if buckets else None

        grid = {}
        previous = None
        for step in range(timesteps):
            if step in buckets:
                values = buckets[step]
            elif not buckets or step > last:
                values = [0]
            elif step < first:
                values = buckets[first]
            else:
                values = previous
            grid[step] = np.mean(values)
            previous = values
        cleaned.append(grid)

    return cleaned


def purity_score(y_true, y_pred):
    """
    Clustering purity of the predicted labels against the ground truth (used as clustering accuracy).

    For every predicted cluster, take the count of its most frequent true label.
    Purity is the sum of those counts divided by the total number of samples.
    """
    # table[i, j] = number of samples with true label i assigned to predicted cluster j
    table = contingency_matrix(y_true, y_pred)
    majority_counts = table.max(axis=0)
    return majority_counts.sum() / table.sum()


def copy_params(old_model, new_model):
    """
    Copy every state-dict entry of ``old_model`` into the same-named entry of ``new_model``.

    Entries present only in ``old_model`` are skipped. The copy is done in
    place with ``copy_`` on the tensors from ``new_model.state_dict()``, which
    share storage with ``new_model``'s parameters and buffers. Same-named
    entries must therefore have compatible shapes. Returns None.
    """
    target = new_model.state_dict()
    for key, tensor in old_model.state_dict().items():
        if key not in target:
            continue
        target[key].copy_(tensor)
