import sys
sys.path.append('../')
sys.path.append('../../')

from typing import List, Any, Dict, Iterable, Optional
import numpy as np
import os
from pathlib import Path
from tqdm import tqdm
import pickle as pkl
import itertools

import gurobipy as gp

try:
    from gurobi_onboarder import init_gurobi
    gurobi_venv, GUROBI_FOUND = init_gurobi.initialize_gurobi()
except Exception:
    # Fall back to a plain local Gurobi environment. An ``Env(empty=True)`` must be
    # started before any model can be created on it; OutputFlag is set first so the
    # start-up banner is suppressed as well.
    # NOTE(review): GUROBI_FOUND is only defined when the onboarder succeeds.
    gurobi_venv = gp.Env(empty=True)
    gurobi_venv.setParam("OutputFlag", 0)
    gurobi_venv.start()
gurobi_venv.setParam("OutputFlag", 0)



from forge import Forge

from utils import generate_mip_graph
from sklearn.decomposition import PCA
import torch
import torch.nn.functional as F

# Torch device used for GNN inference: the first CUDA GPU when one is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


class Collector:
    """
    Data processor that builds training data from a directory of MIP instance files.

    * ``get_triplets`` runs a trained ``Forge`` GNN and Gurobi (with a solution pool)
      on every instance and derives (anchor, positive, negative) variable-index
      triplets plus per-variable 0/1 targets.
    * ``get_cut_ratios`` compares each MIP's objective with that of its LP relaxation.
    """

    def __init__(self):
        super().__init__()

    def get_triplets(self, input_path, gnn_model_path, output_file, return_dict=False) -> Optional[Dict[
        str, Dict[str, Iterable]]]:
        """
        Build triplet-loss training data for every MIP instance in ``input_path``.

        For each instance file:

        1. Run the trained GNN (loaded from ``gnn_model_path``) on the instance graph.
           Instances whose graph has more than 30000 nodes are skipped.
        2. Solve the instance with Gurobi (300 s limit, pool of up to 5 solutions) and
           count in how many pool solutions each variable takes a value > 0.
        3. Variables appearing in the same non-zero number of solutions form
           (anchor, positive) pairs. Negatives are the (up to) 3 nearest variables that
           appear in no solution, measured both in [PCA embedding | input feature]
           space and in GNN-logit space.

        The collected dict is pickled to ``output_file``.

        :param input_path: directory containing the MIP instance files.
        :param gnn_model_path: path passed to ``Forge.load_model``.
        :param output_file: path of the pickle file to write.
        :param return_dict: whether to also return the collected dict.
        :return: ``{instance path: {'triplets': np.ndarray of [anchor, positive, negative]
                 rows (shape (0,) when there are none), 'y_true': (num_vars, 1) tensor on
                 ``device``}}`` when ``return_dict`` is true, else ``None``.
        """
        model = Forge()
        model.load_model(gnn_model_path)

        mips_to_triplet = {}
        # Sorted so that runs are reproducible (os.listdir order is arbitrary).
        for inst in tqdm(sorted(os.listdir(input_path))):
            inst_path = os.path.join(input_path, inst)
            g, features, num_cons, num_vars = generate_mip_graph(inst_path, graph_features=False)

            # Some instances are too big to fit in the GPU, skip
            if g.num_nodes() > 30000:
                continue

            # Forward pass through the trained GNN. Nothing is back-propagated here, so
            # don't keep an autograd graph alive (saves GPU memory).
            with torch.no_grad():
                _, logits, _, distances, _ = model(g.to(device), features.to(device), num_cons, num_vars)

            # Input features of every node. Rows [num_cons:] are taken as the variable
            # nodes below (presumably constraints are numbered first — TODO confirm).
            in_feats = g.ndata['feat']

            # Per-node embedding: softmax over the GNN's distance matrix, reduced with PCA
            # to (at most) 128 dimensions. PCA cannot produce more components than
            # min(n_samples, n_features), so the count is clamped for small inputs.
            probs = F.softmax(distances, dim=1).detach().cpu().numpy()
            mipe = PCA(n_components=min(128, *probs.shape)).fit_transform(probs)

            # Only used for distance computations when picking negatives.
            feats = torch.hstack([torch.Tensor(mipe), in_feats])[num_cons:, :].detach().cpu().numpy()

            # Also consider distances in the GNN's logit space.
            feats_log = logits[num_cons:].detach().cpu().numpy()

            # Solve the instance and collect a pool of up to 5 solutions. Parameters are set
            # on the model, not on the shared module-level env, so they don't leak into
            # other users of ``gurobi_venv``.
            m = gp.read(inst_path, env=gurobi_venv)
            try:
                m.Params.TimeLimit = 300  # 5 minutes
                m.Params.Threads = 36
                m.Params.PoolSearchMode = 1
                m.Params.PoolSolutions = 5
                m.optimize()

                num_vars = m.NumVars
                sol_count = m.SolCount
                # counts[v] = number of pool solutions in which variable v is > 0
                counts = self._solution_appearance_counts(m, num_vars)
            finally:
                # Release the Gurobi model now rather than waiting for garbage collection.
                m.dispose()

            # {number of solutions : variables appearing in exactly that many solutions}
            vars_by_count = {k: np.flatnonzero(counts == k).tolist() for k in range(1, sol_count + 1)}

            # Variables that appear in at least one solution
            positives = np.flatnonzero(counts > 0)

            # Distance matrices with +inf on the diagonal and on every positive/positive pair.
            var_distance = self._masked_distances(feats, positives, num_vars)
            var_distance_log = self._masked_distances(feats_log, positives, num_vars)

            # Anchor/positive pairs: variables that appear in the same non-zero number of solutions.
            train_set = [pair
                         for s in range(1, sol_count + 1)
                         for pair in itertools.combinations(vars_by_count[s], 2)]

            # Up to 3 closest negatives per variable, in both distance spaces.
            closest_negatives = self._nearest_negatives(var_distance, 3)
            closest_negatives_log = self._nearest_negatives(var_distance_log, 3)

            triplets = []
            for anchor, positive in train_set:
                triplets.extend([anchor, positive, neg] for neg in closest_negatives[anchor])
                triplets.extend([positive, anchor, neg] for neg in closest_negatives[positive])
                # NOTE(review): logit-space negatives are only added with the first element of
                # the pair as anchor (unlike the feature-space ones) — confirm this is intended.
                triplets.extend([anchor, positive, neg] for neg in closest_negatives_log[anchor])
            triplets = np.array(triplets)

            # Target: 1 for variable nodes appearing in ANY solution, 0 otherwise; shape (num_vars, 1).
            y_true = torch.zeros(num_vars)
            y_true[torch.as_tensor(positives, dtype=torch.long)] = 1
            y_true = y_true.reshape(-1, 1).to(device)

            mips_to_triplet[inst_path] = {'triplets': triplets,
                                          'y_true': y_true, }

        with open(output_file, 'wb') as file:
            pkl.dump(mips_to_triplet, file)

        if return_dict:
            return mips_to_triplet

    def get_cut_ratios(self, input_path, output_file, return_dict=False) -> Dict[str, Dict[str, Any]]:
        """
        Compare each MIP instance in ``input_path`` with its LP relaxation.

        The MIP and its relaxation are each solved with a 120 s time limit. For every
        instance the dict stores the objective ratio (see ``_objective_ratio``), the
        MIP's best solution values, and both objective values. The dict is re-pickled to
        ``output_file`` after every instance, so partial progress survives an interrupted
        run. Instances for which either solve finds no solution are skipped with a message.

        :param input_path: directory containing the MIP instance files.
        :param output_file: path of the pickle file to (re)write.
        :param return_dict: whether to also return the collected dict.
        :return: ``{instance path: {'ratio', 'mip_sol', 'mip_obj', 'lp_obj'}}`` when
                 ``return_dict`` is true, else ``None``.
        """
        mips_to_gaps = {}
        for i_num, inst in enumerate(sorted(os.listdir(input_path))):
            inst_path = os.path.join(input_path, inst)
            m = gp.read(inst_path, env=gurobi_venv)
            # relax() returns a new continuous model and leaves m untouched, so no copy is needed.
            relax = m.relax()
            try:
                for mdl in (m, relax):
                    mdl.Params.OutputFlag = 0
                    mdl.Params.TimeLimit = 120
                    mdl.optimize()

                if m.SolCount == 0 or relax.SolCount == 0:
                    # No objective value is available (infeasible, or time limit hit first).
                    print("\nSkipping instance", i_num, ":", inst, "(no solution found)")
                    continue

                ratio = self._objective_ratio(m.ObjVal, relax.ObjVal)
                mips_to_gaps[inst_path] = {'ratio': ratio,
                                           'mip_sol': m.Xn,
                                           'mip_obj': m.ObjVal,
                                           'lp_obj': relax.ObjVal,
                                           }
            finally:
                m.dispose()
                relax.dispose()

            # Checkpoint after every instance.
            with open(output_file, 'wb') as file:
                pkl.dump(mips_to_gaps, file)

            print("\rInstance : ", i_num, "| Ratio : ", ratio, end='')

        if return_dict:
            return mips_to_gaps

    @staticmethod
    def _solution_appearance_counts(m, num_vars):
        """Count, per variable, the pool solutions of solved model ``m`` in which it is > 0."""
        counts = np.zeros(num_vars, dtype=np.int64)
        for i in range(m.SolCount):
            # Select pool solution i; m.Xn then returns its variable values.
            m.Params.SolutionNumber = i
            counts += np.asarray(m.Xn) > 0
        return counts

    @staticmethod
    def _masked_distances(feats, positives, num_vars):
        """
        Compute pairwise Euclidean distances between the first ``num_vars`` rows of ``feats``.

        The diagonal and every (positive, positive) entry are set to +inf, so the smallest
        entries of a positive row only point at variables that appear in no solution.
        The matrix is filled one row at a time, vectorized across columns, which avoids a
        per-pair Python loop.
        """
        feats = np.asarray(feats)[:num_vars]
        dist = np.empty((num_vars, num_vars))
        for r in range(num_vars):
            dist[r] = np.linalg.norm(feats - feats[r], axis=1)
        dist[np.ix_(positives, positives)] = np.inf
        np.fill_diagonal(dist, np.inf)
        return dist

    @staticmethod
    def _nearest_negatives(dist, k=3):
        """
        Return, for every row of ``dist``, the column indices of its ``k`` smallest entries.

        Indices are in ``np.argpartition`` order, which is not sorted. Entries whose
        distance is +inf are dropped. Together with the masking in ``_masked_distances``,
        this keeps a positive anchor from being paired with itself or with another
        positive when fewer than ``k`` never-selected variables exist.
        """
        k = min(k, dist.shape[1])
        if k == 0:
            return [[] for _ in range(dist.shape[0])]
        nearest = np.argpartition(dist, kth=k - 1, axis=1)[:, :k]
        return [[c for c in cols if np.isfinite(dist[r, c])] for r, cols in enumerate(nearest)]

    @staticmethod
    def _objective_ratio(a, b):
        """
        Return the smaller objective magnitude divided by the larger one, in [0, 1].

        Equals ``min(a, b) / max(a, b)`` whenever both objectives are non-negative. Unlike
        that formula, it stays in [0, 1] for negative objectives and returns 1.0 (instead of
        dividing by zero) when both objectives are 0.
        """
        lo, hi = sorted((abs(a), abs(b)))
        return lo / hi if hi else 1.0

    def _util_method(self):
        pass
