# -*- coding: utf-8 -*-

import argparse

from pathlib import Path
from random import shuffle

import pyscipopt  # noqa: F401
import ecole
from tqdm import tqdm


class Data:
    """Base class holding the output folder where generated instance files are written."""

    def __init__(self):
        """
        Set the output folder and make sure it exists on disk.

        The folder path is relative to the current working directory, so the
        script is presumably meant to be launched from a sibling directory of
        ``data/`` — TODO confirm the expected launch location.
        """
        self.folder = "../data/instances/"
        # parents=True: the intermediate "../data" directory may not exist yet,
        # in which case a plain mkdir raises FileNotFoundError.
        Path(self.folder).mkdir(parents=True, exist_ok=True)

class RandomData(Data):
    """Generate random benchmark instances with ecole and write them as ``.lp`` files."""

    def __init__(self, n_instances: int):
        """
        Configure the benchmark generator.

        Parameters
        ----------
        n_instances:
            Number of instances generated for every (instance type, level) pair.
        """
        super().__init__()
        self.optimize = False
        self.n_instances = n_instances
        self.instance_types = [
            "SetCovering",
            "CombinatorialAuction",
            "MaximumIndependentSet",
        ]

    @staticmethod
    def _make_generator(instance_type: str, a: int, b: int):
        """Return the ecole instance generator for ``instance_type`` sized by (a, b)."""
        if instance_type == "SetCovering":
            return ecole.instance.SetCoverGenerator(n_rows=a, n_cols=b)
        if instance_type == "CombinatorialAuction":
            return ecole.instance.CombinatorialAuctionGenerator(n_items=a, n_bids=b)
        if instance_type == "MaximumIndependentSet":
            # ``b`` is intentionally unused: only a node count is passed here.
            return ecole.instance.IndependentSetGenerator(n_nodes=a)
        raise ValueError(f"Unknown instance type: {instance_type!r}.")

    def generate_benchmark(self):
        """
        Generate ``n_instances`` instances per (type, level) and write them to disk.

        Files are written to ``<folder><type>/train/<level>/instance_<i>.lp``.

        Raises
        ------
        ValueError
            If ``self.instance_types`` contains a type with no size parameters.
            This is checked before anything is generated or written.
        """
        # (a, b) size parameters per instance type and difficulty level; see
        # _make_generator for how each pair maps onto generator arguments.
        size_params = {
            "SetCovering": {"easy": (500, 1000), "medium": (1000, 1000)},
            "CombinatorialAuction": {"easy": (100, 500), "medium": (200, 1000)},
            "MaximumIndependentSet": {"easy": (500, 0), "medium": (1000, 0)},
        }
        unknown = [t for t in self.instance_types if t not in size_params]
        if unknown:
            raise ValueError(f"Unknown instance type(s): {unknown}.")

        for instance_type in self.instance_types:
            for level in ("easy", "medium"):
                a, b = size_params[instance_type][level]
                generator = self._make_generator(instance_type, a, b)

                instance_list = []
                for _ in tqdm(range(self.n_instances), desc="Get raw data"):
                    # Kept on self for backward compatibility (last generated model).
                    self.model = next(generator)
                    instance_list.append(self.model.as_pyscipopt())

                shuffle(instance_list)
                print("Len raw data :", len(instance_list))
                self.train_folder = self.folder + f"{instance_type}/train/{level}/"
                # parents=True: "<type>/train/" does not exist on a fresh tree, and
                # a plain mkdir would raise FileNotFoundError.
                Path(self.train_folder).mkdir(parents=True, exist_ok=True)
                print(f"\nWrite Lp for {self.train_folder}.")
                for idx_instance, instance in enumerate(instance_list):
                    self.save_file = f"{self.train_folder}instance_{idx_instance}.lp"
                    instance.writeProblem(self.save_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate random benchmark instances and write them as .lp files."
    )
    parser.add_argument(
        "--n_instances",
        default=1000,
        type=int,
        help="number of instances per (instance type, level) pair (default: 1000)",
    )
    args = parser.parse_args()
    # Zero or negative counts would silently produce empty output directories.
    if args.n_instances < 1:
        parser.error("--n_instances must be a positive integer")

    dataset = RandomData(n_instances=args.n_instances)
    dataset.generate_benchmark()
