import sys
sys.path.append('../')
sys.path.append('../../')

from typing import List, Any, Dict
import numpy as np
import os
from pathlib import Path
from tqdm import tqdm
import pickle as pkl

import gurobipy as gp

try:
    from gurobi_onboarder import init_gurobi
    gurobi_venv, GUROBI_FOUND = init_gurobi.initialize_gurobi()
except Exception:
    # Fallback when the optional onboarding helper is missing or fails: build a
    # local Gurobi environment instead. An environment created with empty=True
    # must be started before models can use it; OutputFlag is set before start()
    # so the start-up log stays quiet as well.
    # NOTE(review): GUROBI_FOUND is only defined on the onboarder path above.
    gurobi_venv = gp.Env(empty=True)
    gurobi_venv.setParam("OutputFlag", 0)
    gurobi_venv.start()
# Silence solver logging for every model created with this environment.
gurobi_venv.setParam("OutputFlag", 0)



from utils import generate_mip_graph


class DataProcessor:
    """
    Data processor: converts MIP instance files (MPS / LP) into graph
    representations via ``generate_mip_graph`` and (de)serialises the
    resulting dictionaries with pickle.
    """

    # File-name suffixes (compared case-insensitively) recognised as MIP
    # instances. Any suffix of the name counts, so e.g. "x.mps.gz" matches.
    _MIP_SUFFIXES = ('.mps', '.lp')

    def __init__(self):

        super().__init__()

    @staticmethod
    def _is_mip_file(path) -> bool:
        """Return True if the file name of *path* carries an MPS or LP suffix."""
        suffixes = {s.lower() for s in Path(path).suffixes}
        return any(s in suffixes for s in DataProcessor._MIP_SUFFIXES)

    def get_instance_dict(self, input_path, output_file, perturb=None, return_dict=False) -> Dict[str, List[Any]]:

        """
        Build a graph for every MIP file in a directory and pickle the result.

            input_path : path of directory with MPS / LP files (only its top level is scanned;
                         files are processed in ascending order of file size)
            output_file : path and file name to save processed dictionary of instances
            perturb : iterable of ratios in [0, 1]; for each ratio an extra instance is built
                      by randomly dropping that fraction of the constraints (default: none)
            return_dict : Flag - returns the dictionary if true else only writes to file

        Returns the dictionary mapping instance name -> [g, features, num_cons, num_vars]
        when return_dict is true, otherwise None. Perturbed instances are keyed as
        "<file path>_perturbed_<ratio>".

        Raises ValueError if any perturbation ratio lies outside [0, 1]; this is checked
        up front, before any instance is processed.
        """

        perturb = [] if perturb is None else list(perturb)
        for ratio in perturb:
            if not 0 <= ratio <= 1:
                raise ValueError(f"perturbation ratio must be in [0, 1], got {ratio!r}")

        candidates = (os.path.join(input_path, name) for name in os.listdir(input_path))
        # Match on the file's own suffix rather than a substring of the whole path, so a
        # directory name such as "help/" or "mps_data/" cannot pull in unrelated files.
        sorted_instances = sorted(
            (path for path in candidates if os.path.isfile(path) and self._is_mip_file(path)),
            key=os.path.getsize,
        )

        mip_to_dgl = {}

        for instance in tqdm(sorted_instances):

            # Generate DGL object for the unmodified instance
            g, features, num_cons, num_vars = generate_mip_graph(instance, graph_features=False)
            mip_to_dgl[instance] = [g, features, num_cons, num_vars]

            for ratio in perturb:
                # Re-read the file for each ratio so every perturbation starts from the
                # full, original constraint set.
                m = gp.read(instance, env=gurobi_venv)
                cons = m.getConstrs()

                # Randomly remove a given ratio of constraints to create a new MIP instance.
                # Sample indices instead of handing the Constr list to numpy, which would
                # otherwise wrap the Gurobi handles in an object array.
                n_remove = int(len(cons) * ratio)
                for i in np.random.choice(len(cons), n_remove, replace=False):
                    m.remove(cons[i])
                m.update()
                g, features, num_cons, num_vars = generate_mip_graph(m, graph_features=False, gurobi_object=True)
                mip_to_dgl[instance + '_perturbed_' + str(ratio)] = [g, features, num_cons, num_vars]

        with open(output_file, 'wb') as file:
            pkl.dump(mip_to_dgl, file)

        if return_dict:
            return mip_to_dgl
        return None

    def get_train_list(self, path_list) -> List[List[Any]]:

        """
        Load pickled instance dictionaries and concatenate their entries.

        path_list : List of paths of mip_to_dgl dictionaries generated by calling get_instance_dict()

        Returns a flat list of the dictionaries' values (for files written by
        get_instance_dict, each value is [g, features, num_cons, num_vars]), ordered by
        file, then by dictionary insertion order. An empty path_list yields [].
        """

        train_list = []

        for p in path_list:
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            with open(p, 'rb') as file:
                mip_to_dgl = pkl.load(file)
            train_list.extend(mip_to_dgl.values())

        # Return only after every file in path_list has been read.
        return train_list
