

# This file was *autogenerated* from the file src/data/dataset.sage
from sage.all_cmdline import *   # import sage library

_sage_const_123456789 = Integer(123456789); _sage_const_0 = Integer(0); _sage_const_1 = Integer(1); _sage_const_1p0 = RealNumber('1.0'); _sage_const_2 = Integer(2); _sage_const_100 = Integer(100); _sage_const_10 = Integer(10); _sage_const_20 = Integer(20); _sage_const_4 = Integer(4); _sage_const_3 = Integer(3)
import itertools as it
from time import time

# import matplotlib.pyplot as plt
import pickle
import argparse
import os
import numpy as np
from joblib import Parallel, delayed
from joblib import wrap_non_picklable_objects
import sys
import sage.misc.randstate as randstate
import joblib
from pprint import pprint

# randstate.set_random_seed(os.getpid())
# Execute the companion Sage script in this module's namespace. Presumably it
# defines helpers used further down (e.g. poly_to_sequence) -- verify against
# that file. The path is relative, so it is resolved against the working directory.
load("src/data/symbolic_utils.sage")

# Seed NumPy's global RNG from the process id and the current wall-clock second,
# reduced modulo a fixed constant. This only touches NumPy's generator; Sage's
# random state is set separately in Dataset_Builder.get_sample.
np.random.seed((os.getpid() * int(time())) % _sage_const_123456789 )


def max_num(p):
    """Return the largest coefficient magnitude of the polynomial ``p``.

    Over ``RR`` the magnitude of a coefficient is its absolute value.  Over
    any other base ring each coefficient contributes the larger of
    ``|numer|`` and ``|denom|``.  A polynomial without coefficients yields 0.
    """
    coefficients = p.coefficients()
    base = p.base_ring()

    # Nothing to measure: mirror the "no coefficients" fallback.
    if not len(coefficients):
        return _sage_const_0

    if base == RR:
        magnitudes = [abs(c) for c in coefficients]
    else:
        magnitudes = [max(abs(c.numer()), abs(c.denom())) for c in coefficients]
    return max(magnitudes)


def max_num_of_matrix(M):
    """Return the largest ``max_num`` value over the entries of ``M``.

    ``M`` must provide ``list()`` returning its polynomial entries; an empty
    entry list makes ``max`` raise ``ValueError``.
    """
    entries = M.list()
    return max(map(max_num, entries))


class Polynomial_Sampler:
    """Random sampler of polynomials, or matrices of polynomials, over a ring.

    Sampling limits are read from the ``conditions`` mapping when
    :meth:`sample` is called:

    * ``max_degree`` and ``max_num_terms`` are required;
    * ``min_degree`` (default 0), ``max_coeff`` and ``num_bound`` (default
      ``None``) and ``nonzero_instance`` (default ``False``) are optional.
    """

    def __init__(
        self, ring, degree_sampling="uniform", term_sampling="uniform", strictly_conditioned=True, conditions=None
    ):
        """Store the ring and the sampling policy.

        ring: polynomial ring to sample from (its ``base_ring()`` is kept as ``field``).
        degree_sampling: ``"uniform"`` draws the degree from [min_degree, max_degree];
            any other value always uses ``max_degree``.
        term_sampling: ``"uniform"`` draws the number of terms from [1, max_num_terms];
            any other value always uses ``max_num_terms``.
        strictly_conditioned: if true, resample until the polynomial has exactly the
            drawn degree and number of terms.
        conditions: mapping of sampling limits (see the class docstring).
        """
        self.ring = ring
        self.field = ring.base_ring()
        self.degree_sampling = degree_sampling
        self.term_sampling = term_sampling
        self.strictly_conditioned = strictly_conditioned
        # A fresh dict per instance: a literal ``{}`` default would be one object
        # shared by every sampler created without explicit conditions.
        self.conditions = {} if conditions is None else conditions

    def sample(self, num_samples=_sage_const_1 , size=None, density=_sage_const_1p0 , matrix_type=None):
        """Return a list of ``num_samples`` random instances.

        If ``size`` is a 2-tuple each instance is a matrix of that shape built by
        :meth:`_sample_matrix` (``density`` and ``matrix_type`` are forwarded);
        otherwise each instance is a single polynomial from :meth:`_sample`.

        Raises ``KeyError`` if ``max_degree`` or ``max_num_terms`` is missing
        from ``self.conditions``.
        """
        max_degree = self.conditions["max_degree"]
        max_num_terms = self.conditions["max_num_terms"]

        # Options understood by both the polynomial and the matrix sampler.
        shared = dict(
            min_degree=self.conditions.get("min_degree", _sage_const_0 ),
            max_coeff=self.conditions.get("max_coeff"),
            num_bound=self.conditions.get("num_bound"),
            strictly_conditioned=self.strictly_conditioned,
            degree_sampling=self.degree_sampling,
            term_sampling=self.term_sampling,
            nonzero_instance=self.conditions.get("nonzero_instance", False),
        )

        if isinstance(size, tuple):
            assert len(size) == _sage_const_2 
            return [
                self._sample_matrix(
                    self.ring,
                    size,
                    max_degree,
                    max_num_terms,
                    matrix_type=matrix_type,
                    density=density,
                    **shared,
                )
                for _ in range(num_samples)
            ]

        return [self._sample(self.ring, max_degree, max_num_terms, **shared) for _ in range(num_samples)]

    def _sample_matrix(
        self,
        ring,
        size,
        max_degree,
        max_num_terms,
        min_degree=_sage_const_0 ,
        max_coeff=None,
        num_bound=None,
        strictly_conditioned=True,
        degree_sampling=None,
        term_sampling=None,
        matrix_type=None,
        density=_sage_const_1p0 ,
        nonzero_instance=False,
        max_iters=_sage_const_100 ,
    ):
        """Sample a ``size[0] x size[1]`` matrix of random polynomials over ``ring``.

        Every entry is drawn with :meth:`_sample` and then kept with probability
        ``density`` (otherwise replaced by zero).  When ``nonzero_instance`` is
        true the whole draw is repeated, up to ``max_iters`` times, until at
        least one entry is nonzero.

        Raises ``RuntimeError`` if no acceptable matrix is found in ``max_iters`` attempts.
        """
        num_polys = prod(size)
        for _ in range(max_iters):
            # Entries are drawn without ``nonzero_instance``; nonzero-ness is
            # enforced on the matrix as a whole just below.
            P = [
                self._sample(
                    ring,
                    max_degree,
                    max_num_terms,
                    min_degree=min_degree,
                    max_coeff=max_coeff,
                    num_bound=num_bound,
                    strictly_conditioned=strictly_conditioned,
                    degree_sampling=degree_sampling,
                    term_sampling=term_sampling,
                    max_iters=max_iters,
                )
                for _ in range(num_polys)
            ]
            P = [p if random() < density else p * _sage_const_0  for p in P]
            if not nonzero_instance or not all(p == _sage_const_0  for p in P):
                break
        else:
            # Also reached when max_iters is 0, instead of failing on an unbound name.
            raise RuntimeError(f"Failed to find a nonzero polynomial with {max_iters} iterations")

        A = matrix(ring, *size, P)
        if matrix_type == "unimoduler_upper_triangular":
            # NOTE(review): this puts 1 on the diagonal and 0 *above* it (i < j),
            # i.e. the result is lower triangular despite the option's name.
            # Confirm which orientation downstream code expects.
            for i, j in it.product(range(size[_sage_const_0 ]), range(size[_sage_const_1 ])):
                if i == j:
                    A[i, j] = _sage_const_1 
                if i < j:
                    A[i, j] = _sage_const_0 

        return A

    def _sample(
        self,
        ring,
        max_degree,
        max_num_terms,
        min_degree=_sage_const_0 ,
        max_coeff=None,
        num_bound=None,
        strictly_conditioned=True,
        degree_sampling="uniform",
        term_sampling="uniform",
        nonzero_instance=False,
        max_iters=_sage_const_100 ,
    ):
        """Sample one polynomial from ``ring``.

        A target degree and number of terms are chosen first (see ``__init__``
        for the meaning of the sampling modes); the number of terms is capped by
        ``binomial(degree + ngens, degree)``.  Candidates are then drawn with
        ``ring.random_element`` until one is accepted:

        * a zero candidate is rejected when ``nonzero_instance`` is true;
        * with ``strictly_conditioned`` the candidate must have exactly the
          target total degree and number of monomials.

        ``max_coeff`` bounds the coefficients over ``RR`` and ``ZZ`` (and must
        not be ``None`` there); ``num_bound`` is forwarded for ``QQ``.

        Raises ``RuntimeError`` when ``max_iters`` candidates were all rejected.
        The returned polynomial is asserted to be nonzero.
        """
        degree = randint(min_degree, max_degree) if degree_sampling == "uniform" else max_degree
        max_num_terms = min(max_num_terms, binomial(degree + ring.ngens(), degree))
        num_terms = randint(_sage_const_1 , max_num_terms) if term_sampling == "uniform" else max_num_terms

        # The coefficient-specific arguments do not change between retries.
        coeff_ring = ring.base_ring()
        if coeff_ring == QQ:
            coeff_kwargs = {"num_bound": num_bound}
        elif coeff_ring == RR:
            coeff_kwargs = {"min": -max_coeff, "max": max_coeff}
        elif coeff_ring == ZZ:
            coeff_kwargs = {"x": -max_coeff, "y": max_coeff + _sage_const_1 }
        else:
            # NOTE(review): presumably meant as a "coefficient ring is finite"
            # check -- confirm that ``field()`` exists on the rings used here.
            assert coeff_ring.field() < Infinity
            coeff_kwargs = {}

        # NOTE: in the univariate case the ``degree`` option of random_element
        # seemed to behave as an exclusive bound (possibly a SageMath quirk).
        for _ in range(max_iters):
            p = ring.random_element(degree=degree, terms=num_terms, choose_degree=True, **coeff_kwargs)

            if p == _sage_const_0  and nonzero_instance:
                continue
            if not strictly_conditioned:
                break
            if p.total_degree() == degree and len(p.monomials()) == num_terms:
                break
        else:
            # Every attempt was rejected, including the case where the last
            # candidate was zero (which used to skip this error and trip the
            # assertion below instead).
            print(f"conditions: degree={degree}, num_terms={num_terms}")
            raise RuntimeError(f"Failed to find a polynomial satisfying the conditions with {max_iters} iterations")

        assert p != _sage_const_0 
        return p


class Dataset_Builder:
    """Build a dataset of polynomial problems ``(f1, f2, g1, g2) -> (f1 + f2) * (g1 + g2)``."""

    def __init__(
        self,
        ring,
        max_coeff=_sage_const_10 ,
        max_degree=_sage_const_10 ,
        max_num_terms=_sage_const_20 ,
    ):
        """Store the polynomial ring and the limits handed to the sampler in :meth:`run`."""
        self.ring = ring
        self.max_coeff = max_coeff  # used both as "max_coeff" and as "num_bound" in run()
        self.max_degree = max_degree
        self.max_num_terms = max_num_terms
        self.num_tasks = _sage_const_4 

    def get_sample(self, psampler, seed=-_sage_const_1 ):
        """Draw a single instance from ``psampler``.

        When ``seed`` is greater than -1 Sage's random state is reseeded with it
        first.  This is needed because joblib multiprocessing workers would
        otherwise start from identical random states.
        """
        if seed > -_sage_const_1 :
            randstate.set_random_seed(seed)
        return psampler.sample(num_samples=_sage_const_1 )[_sage_const_0 ]

    def run(self, num_samples, n_jobs=-_sage_const_1 ):
        """Generate ``num_samples`` problems and their statistics.

        Four polynomials are sampled per problem, in parallel, using the seeds
        ``0 .. 4 * num_samples - 1``.  Because the seeds always start at 0, two
        calls draw from the same seed sequence; callers that need disjoint
        datasets should generate them in one call and split the result.

        Sets ``self.dataset`` to a list of ``([f1, f2, g1, g2], [(f1 + f2) * (g1 + g2)])``
        pairs and ``self.stats`` to ``(input_stats, output_stats)``.  Returns ``self``.
        """
        self.n_jobs = n_jobs

        conditions = {
            "max_degree": self.max_degree,
            "min_degree": _sage_const_1 ,
            "max_num_terms": self.max_num_terms,
            "max_coeff": self.max_coeff,
            "num_bound": self.max_coeff,
            "nonzero_instance": True,
        }

        psampler = Polynomial_Sampler(
            self.ring,
            degree_sampling="uniform",
            term_sampling="uniform",
            strictly_conditioned=True,
            conditions=conditions,
        )

        per_problem = _sage_const_4 
        samples = Parallel(n_jobs=n_jobs, backend="multiprocessing", verbose=True)(
            delayed(self.get_sample)(psampler, seed=i) for i in range(per_problem * num_samples)
        )
        # The former second pass that sampled auxiliary polynomials (used by older
        # gcd/remainder task variants) was never consumed and has been dropped.

        self.dataset = []
        for i in range(num_samples):
            f1, f2, g1, g2 = samples[per_problem * i : per_problem * (i + _sage_const_1 )]
            self.dataset.append(([f1, f2, g1, g2], [(f1 + f2) * (g1 + g2)]))

        self.stats = (
            [self.get_system_stat(F) for F, _ in self.dataset],
            [self.get_system_stat(G) for _, G in self.dataset],
        )

        return self

    def get_system_stat(self, P):
        """Return summary statistics for a system of polynomials.

        P: list of polynomials, or an object whose ``list()`` returns them
           (e.g. a column vector with polynomial entries).  Must be non-empty.

        The result holds the system size, per-polynomial degrees and monomial
        counts, their min/max, the total number of monomials and ``max_coeff``.
        ``max_coeff`` is taken over all coefficients (numerators and
        denominators over ``QQ``; absolute values in characteristic 0) and is
        -1 when there are no coefficients at all.
        """
        if not isinstance(P, list):
            P = P.list()
        size = len(P)
        degrees = [int(p.total_degree()) for p in P]
        num_monoms = [int(len(p.monomials())) for p in P]

        field = P[_sage_const_0 ].base_ring()
        # Materialise the coefficients: a bare chain iterator would be exhausted
        # after the numerator pass, silently dropping every denominator.
        coeffs = list(it.chain.from_iterable(p.coefficients() for p in P))
        if field == QQ:
            coeffs = [c.numer() for c in coeffs] + [c.denom() for c in coeffs]
        if field.characteristic() == _sage_const_0 :
            coeffs = [abs(c) for c in coeffs]

        max_coeff = max(coeffs, default=-_sage_const_1 )

        stat = {
            "size": int(size),
            "degrees": degrees,
            "num_monoms": num_monoms,
            "min_degree": float(np.min(degrees)),
            "min_num_monoms": float(np.min(num_monoms)),
            "max_degree": float(np.max(degrees)),
            "max_num_monoms": float(np.max(num_monoms)),
            # Sum of monomial counts (this used to sum the degrees by mistake).
            "total_num_monoms": float(np.sum(num_monoms)),
            "max_coeff": int(max_coeff),
        }

        return stat


def summarize_stats(stats, metric=("mean", "std", "max", "min", "median")):
    """Aggregate a sequence of per-sample stat dicts into one flat summary dict.

    stats: sequence of dicts sharing the same keys.  Keys whose value in the
        first dict is a list are skipped.
    metric: names of the aggregates to compute; any subset of
        ``"mean"``, ``"median"``, ``"std"``, ``"max"``, ``"min"``.

    Returns a dict mapping ``"<key>_<metric>"`` to a float.  An empty ``stats``
    yields an empty dict instead of raising ``IndexError``.
    """
    summary = {}
    if len(stats) == 0:
        return summary

    # Order matters only for the insertion order of the resulting keys.
    reducers = (
        ("mean", np.mean),
        ("median", np.median),
        ("std", np.std),
        ("max", np.max),
        ("min", np.min),
    )

    first = stats[0]
    for k in first:
        if isinstance(first[k], list):
            continue
        # Collect the column once instead of once per metric.
        values = [item[k] for item in stats]
        for name, reduce_fn in reducers:
            if name in metric:
                summary[f"{k}_{name}"] = float(reduce_fn(values))
    return summary


class Writer:
    """Serialise (F, G) polynomial problems to text files and dump summary statistics."""

    def __init__(self, local_separator=" [SEP] ", global_separator=" : "):
        """
        local_separator: string placed between polynomials of the same side.
        global_separator: string placed between the F side and the G side.
        """
        self.local_separator = local_separator
        self.global_separator = global_separator

    def _preprocess(self, F, G, encoding=None, ring=None, n_jobs=-1, split_rational=True):
        """Convert one problem to its text line and token counts.

        Each polynomial of ``F`` and ``G`` is converted with ``poly_to_sequence``
        (not defined in this file; presumably provided by the loaded
        symbolic_utils script).  ``ring``, ``n_jobs`` and ``split_rational`` are
        accepted for interface compatibility but not used here.

        Returns ``(line, stats)`` where ``stats`` counts the whitespace-separated
        tokens of the F side, the G side and their sum.
        """
        Fstr = self.local_separator.join(poly_to_sequence(f, encoding=encoding) for f in F)
        Gstr = self.local_separator.join(poly_to_sequence(g, encoding=encoding) for g in G)
        s = Fstr + self.global_separator + Gstr

        num_tokens_F = int(len(Fstr.split()))
        num_tokens_G = int(len(Gstr.split()))
        stats = {"num_tokens_F": num_tokens_F, "num_tokens_G": num_tokens_G, "num_tokens": num_tokens_F + num_tokens_G}

        return s, stats

    def write(self, filename, dataset, n_jobs=-1, encoding="raw", ring=None, split_rational=True):
        """Write ``dataset`` to ``<filename>.<encoding>``, one problem per line.

        A token-count summary is written next to it as
        ``<filename>.<encoding>_token_stats.yaml``.

        dataset: iterable of ``(F, G)`` pairs.
        n_jobs: number of joblib workers used for the conversion.
        """
        # Imported here so the class also works when this module is imported
        # rather than run as a script (the file-level import lives in __main__).
        import yaml

        filename = f"{filename}.{encoding}"

        # Honour the caller's n_jobs (it used to be ignored in favour of -1).
        ret = Parallel(n_jobs=n_jobs, backend="multiprocessing", verbose=True)(
            delayed(self._preprocess)(F, G, encoding=encoding, ring=ring, split_rational=split_rational)
            for F, G in dataset
        )

        if ret:
            dataset_str, stats = zip(*ret)
            summary = summarize_stats(stats)
        else:
            # Empty dataset: emit an empty file and an empty summary.
            dataset_str, summary = (), {}

        with open(filename, "w") as f:
            f.write("\n".join(dataset_str))

        with open(filename + "_token_stats.yaml", "w") as f:
            yaml.dump(summary, f)

    def save_stats(self, summaries, save_path="", postfix=""):
        """Print the F/G summaries and, if ``save_path`` is given, dump them as YAML.

        summaries: pair ``(summary_F, summary_G)``.
        save_path: directory for ``dataset_stats_{F,G}_summary<_postfix>.yaml``;
            nothing is written when it is empty.
        postfix: optional suffix appended (after an underscore) to the file names.
        """
        if postfix:
            postfix = "_" + postfix
        stats_F_sum, stats_G_sum = summaries

        print("---- dataset statistics -----------")
        print("# F")
        pprint(stats_F_sum)
        print("# G")
        pprint(stats_G_sum)
        print("")

        if save_path:
            import yaml  # see the note in write()

            with open(os.path.join(save_path, f"dataset_stats_F_summary{postfix}.yaml"), "w") as f:
                yaml.dump(stats_F_sum, f)

            with open(os.path.join(save_path, f"dataset_stats_G_summary{postfix}.yaml"), "w") as f:
                yaml.dump(stats_G_sum, f)


def get_parser():
    """
    Build the command-line argument parser for this script.

    Options: --data_path, --data_encoding, --save_path, --config_path (all
    strings with defaults) and --tasks (zero or more strings, default None).
    """
    parser = argparse.ArgumentParser(description="Language transfer")

    # (flag, default, help) for every plain string option, in declaration order
    string_options = (
        ("--data_path", "./data/diff_dataset", "Experiment dump path"),
        ("--data_encoding", "prefix", None),
        ("--save_path", "./dumped", "Experiment dump path"),
        ("--config_path", "./config", None),
    )
    for flag, default, help_text in string_options:
        parser.add_argument(flag, type=str, default=default, help=help_text)

    # list-valued option
    parser.add_argument("--tasks", type=str, nargs="*")

    return parser


if __name__ == "__main__":
    # Script entry point: read the YAML config, generate the problems once and
    # write a test split and a train split (text file + statistics) per task.
    import yaml

    parser = get_parser()
    params = parser.parse_args()

    # tasks = ['sum', 'prod', 'gcd', 'remainder']
    tasks = params.tasks
    if not tasks:
        # --tasks omitted (None) or given without values: nothing could be written.
        parser.error("at least one task name must be given via --tasks")
    save_dir = params.save_path

    print(f"## dataset geneartion with config {params.config_path} ##########")
    print(params)
    print(tasks)

    for task in tasks:
        os.makedirs(os.path.join(save_dir, task), exist_ok=True)

    with open(params.config_path, "r") as f:
        config = yaml.safe_load(f)

    # Map the configured field name to a Sage coefficient ring.
    field_name = config["field"]
    if field_name == "QQ":
        field = QQ
    elif field_name == "RR":
        field = RR
    elif field_name == "ZZ":
        field = ZZ
    elif field_name[:2] == "GF":
        field = GF(int(field_name[2:]))
    else:
        raise ValueError(f"Unsupported field {field_name!r}; expected 'QQ', 'RR', 'ZZ' or 'GF<order>'")

    num_samples_test = config["num_samples_test"]
    num_samples_train = config["num_samples_train"]

    num_variables = int(config["num_var"])
    ring = PolynomialRing(field, num_variables, "x", order="lex")

    dbuilder = Dataset_Builder(
        ring,
        max_coeff=config.get("max_coeff", _sage_const_0 ),
        max_degree=config["max_degree"],
        max_num_terms=config["max_num_terms"],
    )

    # Generate both splits in ONE run.  Dataset_Builder.run seeds its samples
    # 0, 1, 2, ... on every call, so two separate runs would reproduce the test
    # problems as the first training problems.  Slicing a single run keeps the
    # test split unchanged and makes the train split disjoint from it.
    n_jobs = -_sage_const_1 
    dbuilder.run(num_samples_test + num_samples_train, n_jobs=n_jobs)
    stats_F_all, stats_G_all = dbuilder.stats

    splits = (
        ("test", "Test", num_samples_test, slice(None, num_samples_test)),
        ("train", "Train", num_samples_train, slice(num_samples_test, None)),
    )
    for split_name, label, split_size, part in splits:
        print(f"-- {label} set ({split_size} samples)")
        split_dataset = dbuilder.dataset[part]
        split_stats_F = stats_F_all[part]
        split_stats_G = stats_G_all[part]

        # Problems are dealt out to the tasks round-robin.
        for i, task in enumerate(tasks):
            print(f"## Task: {task} ##############################")
            subset = split_dataset[i :: len(tasks)]
            summaries = (
                summarize_stats(split_stats_F[i :: len(tasks)]),
                summarize_stats(split_stats_G[i :: len(tasks)]),
            )

            writer = Writer()
            save_dir_task = os.path.join(save_dir, task)
            writer.write(
                os.path.join(save_dir_task, f"data_{field_name}_n={num_variables}.{split_name}"),
                subset,
                encoding="lex.infix",
                n_jobs=n_jobs,
                ring=ring,
                split_rational=True,
            )
            writer.save_stats(
                summaries, save_path=save_dir_task, postfix=f"data_{field_name}_n={num_variables}_{split_name}"
            )
            print("")

    print("done!")

