"""Note: This file is more a less a single hack. No guarantees on probability distributions of generated expressions, but it should at least reaspect the resulting size"""

# pylint: disable = line-too-long
import random
import numpy as np

from deepltl.data import ltl_parser
from deepltl.data.ltl_parser import Token, LTLFormulaBinaryOp, LTLFormulaUnaryOp, LTLFormulaLeaf
from deepltl.utils import utils

def split_ops_and_dist(ops_and_dist, ts_budget, relax_lower=False):
    """Split (operator, weight) pairs into operators and a probability vector.

    Operators whose size range in ``op_budgets`` cannot meet the requested
    tree-size budget get probability zero; the remaining weights are
    normalised so that they sum to one.

    Args:
        ops_and_dist: Non-empty sequence of ``(operator, weight)`` pairs.
            Every operator must be a key of ``op_budgets``.
        ts_budget: ``(b_min, b_max)`` bounds on the size of the tree to build.
        relax_lower: If true, an operator is not excluded merely because its
            largest possible size is below ``b_min``.

    Returns:
        ``(ops, dist)``: a NumPy array with the operator strings and a float
        array of the same length holding the selection probabilities.

    Raises:
        RuntimeError: If no pairs are given or every operator is excluded.
    """
    if len(ops_and_dist) == 0:
        raise RuntimeError('No ops to choose from!')
    ops_and_dist = np.array(ops_and_dist)
    ops = ops_and_dist[:, 0]
    # Builtin float: the np.float alias was removed in NumPy 1.24.
    dist = ops_and_dist[:, 1].astype(float)
    b_min, b_max = ts_budget
    for i, op in enumerate(ops):
        op_min, op_max = op_budgets[op]
        too_small = not relax_lower and b_min > op_max
        too_large = b_max < op_min
        if too_small or too_large:
            dist[i] = 0  # operator cannot be used within this budget
    sum_ = dist.sum()
    if sum_ == 0:
        raise RuntimeError('No ops to choose from!')
    dist /= sum_
    return ops, dist


# (min, max) size of a tree rooted at each operator: leaves are exactly one
# token, unary operators need at least two, binary operators at least three.
op_budgets = {
    'ap': (1, 1), '0': (1, 1), '1': (1, 1),
    '!': (2, np.inf), '&': (3, np.inf), '|': (3, np.inf),
    'X': (2, np.inf), 'G': (2, np.inf), 'F': (2, np.inf),
    'R': (3, np.inf), 'U': (3, np.inf),
}
# Maps an operator string to a (num_children, token) pair; FormulaTree.gen unpacks entries in exactly that order.
# NOTE(review): this binds op_dict to the very same object as ltl_parser.token_dict_spot (no copy is made),
# so the 'ap' entry added below is also visible through ltl_parser.token_dict_spot — confirm this is intended.
op_dict = ltl_parser.token_dict_spot
op_dict['ap'] = (0, Token.AP)

# When True, the generator classes in this module print their budgets and choices to stdout.
DEBUG = False

class FormulaClassPart:
    """Base class for one layer of a stacked formula generator."""

    def gen(self, ts_budget, children_stack):
        """Generate a formula within ``ts_budget``; subclasses must override.

        Args:
            ts_budget: ``(min, max)`` bounds on the size of the result.
            children_stack: The remaining parts; the first one produces the
                children of this part.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

class Nexts(FormulaClassPart):
    """Put a randomly chosen number of NEXT operators in front of a child formula.

    The number of operators is drawn uniformly from ``d_budget`` (upper end
    capped at ``b_max - 1``). With ``unique=True`` every count is handed out
    at most once over the lifetime of the instance.
    """

    def __init__(self, d_budget, unique=False):
        self.d_budget = d_budget
        self.unique = unique
        if unique:
            # Counts already used; only the keys matter.
            self.memory = {}

    def gen(self, ts_budget, children_stack, pre=True):
        """Generate ``X ... X child`` and return ``(formula, size)``.

        Args:
            ts_budget: ``(b_min, b_max)`` size bounds for the whole result.
            children_stack: Parts below this one; the first generates the child.
            pre: Only ``True`` is implemented.

        Raises:
            RuntimeError: In unique mode, if every admissible count was used.
            NotImplementedError: If ``pre`` is false.
        """
        depth_lo, depth_hi = self.d_budget
        size_lo, size_hi = ts_budget
        depth_hi = min(depth_hi, size_hi - 1)  # leave at least one token for the child
        if DEBUG:
            print('Nexts, ts in', ts_budget, 'd in', self.d_budget, end=' ')
        assert size_lo <= depth_hi + 1
        if not pre:
            raise NotImplementedError()

        candidates = list(range(depth_lo, depth_hi + 1))
        weights = np.ones(len(candidates)) / len(candidates)
        if self.unique:
            # Rule out every count that was chosen by an earlier call.
            for idx, depth in enumerate(candidates):
                if depth in self.memory:
                    weights[idx] = 0
            remaining = weights.sum()
            if remaining == 0:
                raise RuntimeError('Cannot create new number for Nexts')
            weights /= remaining
        d = np.random.choice(candidates, p=weights)
        if DEBUG:
            print(f'chose d={d}')
        if self.unique:
            self.memory[d] = None

        # Whatever the NEXT operators consume is no longer available to the child.
        child_budget = size_lo - d, size_hi - d
        child_part = children_stack[0]
        formula, inner_size = child_part.gen(child_budget, children_stack[1:])
        for _ in range(d):
            formula = LTLFormulaUnaryOp(Token.NEXT, formula)
        return formula, inner_size + d

class Ands(FormulaClassPart):
    """Conjunction of a random number of children produced by the next part."""

    def __init__(self, d_budget, limit_children_size=1):
        self.d_budget = d_budget
        self.limit_children_size = limit_children_size

    def gen(self, ts_budget, children_stack, pre=True):
        """Generate ``c_1 & ... & c_d`` and return ``(formula, size)``.

        The size budget left after the ``d - 1`` AND operators is divided
        evenly among the children; the first children receive the remainder.
        Children are shuffled before being combined.
        """
        count_lo, count_hi = self.d_budget
        size_lo, size_hi = ts_budget
        if DEBUG:
            print('Ands , ts in', ts_budget, 'd in', self.d_budget, end=' ')
        assert size_hi > 0
        # Cap the number of children by the available size, but allow at least one.
        count_hi = max(min(count_hi, (size_hi - count_hi + 1) / self.limit_children_size), 1)
        assert count_lo >= 1
        assert count_lo <= count_hi
        d = np.random.randint(count_lo, count_hi + 1)
        if DEBUG:
            print(f'chose d={d}')

        min_share, min_extra = divmod(size_lo - d + 1, d)
        max_share, max_extra = divmod(size_hi - d + 1, d)
        child_part = children_stack[0]
        generated = [
            child_part.gen(
                (
                    min_share + (1 if k + 1 <= min_extra else 0),
                    max_share + (1 if k + 1 <= max_extra else 0),
                ),
                children_stack[1:],
            )
            for k in range(d)
        ]
        conjuncts = [formula for formula, _ in generated]
        sizes = [size for _, size in generated]

        order = np.random.permutation(d)
        shuffled = np.array(conjuncts)[order]
        children_total = np.array(sizes)[order].sum()
        formula = shuffled[0]
        for conjunct in shuffled[1:]:
            formula = LTLFormulaBinaryOp(Token.AND, formula, conjunct)
        return formula, children_total + d - 1

class AndGradually(FormulaClassPart):
    """Conjunction that keeps adding children until the size budget is met.

    Unlike a fixed-arity conjunction, the number of children is not chosen up
    front: children are generated one after another by the next part on the
    stack until the accumulated size reaches the lower bound of the budget.
    """

    def __init__(self, min_splits=None, max_splits=1, child_min=None, child_max=None, debug=False):
        # min_splits / max_splits: used to derive the per-child minimum / maximum
        # budget when child_min / child_max are not given (see gen).
        self.min_splits = min_splits
        self.max_splits = max_splits
        # child_min / child_max: explicit per-child budget bounds; take precedence when not None.
        self.child_min = child_min
        self.child_max = child_max
        # debug: enables the extra per-child prints, but only if the module-level DEBUG is set as well.
        self.debug = debug

    def gen(self, ts_budget, children_stack):
        """Generate a conjunction and return ``(formula, total_size)``.

        Args:
            ts_budget: ``(b_min, b_max)`` size bounds for the whole conjunction.
            children_stack: Parts below this one; the first generates every child.

        Raises:
            RuntimeError: If, for an exact budget (``b_min == b_max``), 100
                retries still yield a child that leaves the total one short.
        """
        b_min, b_max = ts_budget
        assert b_min <= b_max
        children_list = []
        # Per-child lower bound: explicit value, else derived from min_splits, else 1.
        if self.child_min is not None:
            b_min_per_child = self.child_min
        else:
            b_min_per_child = (b_min-self.min_splits+1) // self.min_splits if self.min_splits is not None else 1
        # Per-child upper bound: explicit value, else derived from max_splits, else 1.
        if self.child_max is not None:
            b_max_per_child = self.child_max
        else:
            b_max_per_child = (b_max-self.max_splits+1) // self.max_splits if self.max_splits is not None else 1
        if DEBUG:
            print('AndG , ts in', ts_budget, 'child_b in', (b_min_per_child, b_max_per_child))
        # TODO: remainder?
        # total_sizes counts every child plus one AND per child; the surplus AND is removed after the loop.
        total_sizes = 0
        next_child = children_stack[0]
        while total_sizes-1 < max(b_min, 1): # -1 for end if finished
            # Size still available for the next child.
            next_size = b_max-total_sizes
            if next_size <= 0:
                print('fuck, too short')
                break
            if DEBUG and self.debug:
                print(f'- current size is {total_sizes} of {b_max}, next can be {next_size}')
            b_max_child = min(b_max_per_child, next_size)
            child, child_size = next_child.gen((b_min_per_child, b_max_child), children_stack[1:])
            retrials = 0
            # With an exact target, a child of size next_size - 1 plus its AND uses up all of b_max,
            # leaving the conjunction one short and no room for another child, so draw again.
            while b_min == b_max and child_size == next_size - 1: # won't fulfil criterion
                child, child_size = next_child.gen((b_min_per_child, b_max_child), children_stack[1:]) # retry
                retrials += 1
                if retrials == 100:
                    raise RuntimeError('failed to fill AndGradually with matching children')
            children_list.append(child)
            if DEBUG and self.debug:
                print('- and was', child_size)
            total_sizes += child_size + 1 # 1 for and
        total_sizes -= 1 # for and cause finished
        d = len(children_list)
        if DEBUG:
            print(' ... AndG generated', d, 'children')
        # Shuffle the children before chaining them with AND operators.
        children_perm = np.random.permutation(d)
        children_list = np.array(children_list)[children_perm]
        # NOTE(review): if the loop above broke out before producing any child, this index fails.
        current = children_list[0]
        for i in range(d-1):
            current = LTLFormulaBinaryOp(Token.AND, current, children_list[i+1])
        return current, total_sizes

class SingleOp(FormulaClassPart):
    """Apply one fixed operator on top of the formula(s) from the next part."""

    def __init__(self, token, num_children):
        self.token = token
        self.num_children = num_children

    def gen(self, ts_budget, children_stack):
        """Generate ``op child`` or ``child op child`` and return ``(formula, size)``.

        One token of the budget is spent on the operator itself. For a
        binary operator the remainder is halved (floor for one operand, ceil
        for the other) and the operands are swapped with probability 1/2.
        """
        b_min, b_max = ts_budget
        assert self.num_children > 0
        operand_part = children_stack[0]
        rest_lo, rest_hi = b_min - 1, b_max - 1

        if self.num_children == 1:
            operand, operand_size = operand_part.gen((rest_lo, rest_hi), children_stack[1:])
            return LTLFormulaUnaryOp(self.token, operand), operand_size + 1

        # Any other arity is handled as binary.
        smaller_half = (int(np.floor(rest_lo / 2)), int(np.floor(rest_hi / 2)))
        larger_half = (int(np.ceil(rest_lo / 2)), int(np.ceil(rest_hi / 2)))
        first, first_size = operand_part.gen(smaller_half, children_stack[1:])
        second, second_size = operand_part.gen(larger_half, children_stack[1:])
        if np.random.choice(2):  # coin flip: which operand goes left
            first, second = second, first
        return LTLFormulaBinaryOp(self.token, first, second), first_size + second_size + 1


class FormulaTree(FormulaClassPart):
    """Random formula tree over a weighted operator set; terminal part of a stack.

    Args:
        ops_and_dist: ``(operator, weight)`` pairs to draw operators from.
        ap_list: Atomic propositions available for ``'ap'`` leaves.
        d_budget: Optional ``(min, max)`` pair; every size budget handed to
            ``gen`` (including recursive calls) is clamped to it.
        relax_lower: Stored on the instance.
            NOTE(review): ``gen`` currently does not use this flag — confirm
            whether it should be forwarded to ``split_ops_and_dist``.
    """

    def __init__(self, ops_and_dist, ap_list, d_budget=None, relax_lower=False):
        self.ops_and_dist = ops_and_dist
        self.ap_list = ap_list
        self.relax_lower = relax_lower
        self.d_budget = d_budget

    def gen(self, ts_budget, children_stack=None):
        """Generate a random tree and return ``(formula, size)``.

        Args:
            ts_budget: ``(b_min, b_max)`` size bounds for the tree.
            children_stack: Must be empty or ``None``; this part has no parts below it.

        Raises:
            ValueError: If the budget is empty after clamping to ``d_budget``.
            RuntimeError: If no operator fits the budget, or a leaf is drawn
                while more than one token is still required.
        """
        assert not children_stack
        b_min, b_max = ts_budget
        if self.d_budget is not None:
            b_min = max(self.d_budget[0], b_min)
            b_max = min(self.d_budget[1], b_max)
        if b_max < b_min:
            raise ValueError('BOOM exploded, max < min')
        # Filter operators against the clamped budget: filtering against the
        # raw ts_budget could select an operator that no longer fits once
        # d_budget is applied, which only fails later in the recursion.
        ops, dist = split_ops_and_dist(self.ops_and_dist, (b_min, b_max))
        op = np.random.choice(ops, p=dist)
        num_children, token = op_dict[op]

        if num_children == 0:
            if b_min > 1:
                raise RuntimeError('should generate more than I can')
            ap = np.random.choice(self.ap_list) if op == 'ap' else None
            return LTLFormulaLeaf(token, ap=ap), 1

        if num_children == 1:
            child, child_size = self.gen((b_min - 1, b_max - 1))
            return LTLFormulaUnaryOp(token, child), child_size + 1

        # Binary operator: one token for the operator, the rest is halved
        # (floor for one operand, ceil for the other).
        lchild, lchild_size = self.gen((int(np.floor((b_min - 1) / 2)), int(np.floor((b_max - 1) / 2))))
        rchild, rchild_size = self.gen((int(np.ceil((b_min - 1) / 2)), int(np.ceil((b_max - 1) / 2))))
        if np.random.choice(2):  # coin flip: which operand goes left
            lchild, rchild = rchild, lchild
        return LTLFormulaBinaryOp(token, lchild, rchild), lchild_size + rchild_size + 1

def placebo_ands_nexts_assignments(size_min, size_max, aps):
    """Generate a conjunction of NEXT-prefixed conjunctions of literals.

    Every top-level conjunct consists of between 0 and ``ceil(size_max / 4)``
    NEXT operators in front of a conjunction of at most ``len(aps)`` literals,
    where a literal is an atomic proposition or its negation.

    Returns:
        ``(formula, size)`` as produced by ``AndGradually.gen``.
    """
    max_nexts = int(np.ceil(size_max / 4))
    literal_part = FormulaTree([('!', 1), ('ap', 1)], aps, d_budget=(1, 2))
    parts_below = [
        Nexts((0, max_nexts), unique=False),
        Ands((1, len(aps))),
        literal_part,
    ]
    top = AndGradually(min_splits=None, max_splits=1)
    return top.gen((size_min, size_max), parts_below)


import spot

def spot_get_trace(formula_str):
    """Translate a formula with spot and look for an accepting run.

    Returns:
        ``(False, None)`` if the translated automaton has no accepting run,
        otherwise ``(True, trace)`` with the run rendered as a string.
    """
    automaton = spot.formula(formula_str).translate()
    automaton.merge_edges()
    run = automaton.accepting_run()
    if run is None:
        return False, None
    return True, str(spot.twa_word(run))

def gen_formula_and_trace(gen_function, ts_interval, count_per_size, num_aps, timeout, calc_trace, require_trace, splits=None):
    """Generate distinct formulas for every size in a range, optionally with traces.

    Args:
        gen_function: Called as ``gen_function(size_min, size_max, aps)``;
            must return ``(formula, size)``.
        ts_interval: ``(ts_min, ts_max)``, inclusive range of formula sizes.
        count_per_size: Number of formulas to aim for at each size.
        num_aps: Number of atomic propositions (``a``, ``b``, ...); at most 26.
        timeout: Timeout handed to the worker for each trace computation.
        calc_trace: Whether to compute a trace for every formula.
        require_trace: If true, a formula whose trace computation timed out
            is dropped; otherwise it is kept with the trace ``'-'``.
        splits: Optional sequence of relative split weights. When given, the
            samples of every size are shuffled and distributed over the splits.

    Returns:
        ``(formulas, traces)``. Flat lists of strings if ``splits`` is None,
        otherwise one list per split. ``traces`` stays empty when
        ``calc_trace`` is false.

    Raises:
        ValueError: If ``num_aps`` exceeds 26.
    """
    if num_aps > 26:
        raise ValueError("Cannot generate more than 26 APs")
    aps = list(map(chr, range(97, 97 + num_aps)))

    if splits is None:
        fractions = None
        formulas, traces = [], []
    else:
        # Normalise once, as floats, so that integer weights work as well.
        fractions = np.array(splits, dtype=float)
        fractions = fractions / fractions.sum()
        formulas = [[] for _ in fractions]
        traces = [[] for _ in fractions]

    total_count_gen = 0
    ts_min, ts_max = ts_interval
    max_cal_retrials = 5
    spot_worker = utils.PersistentWorker()
    try:
        for ts in range(ts_min, ts_max + 1):
            formulas_step = []
            traces_step = []
            seen = set()  # spot strings already generated for this size
            count_gen = 0
            gen_retrials = 0  # consecutive duplicates
            calc_retrials = 0  # consecutive trace timeouts
            while count_gen < count_per_size and gen_retrials < 100 and calc_retrials < max_cal_retrials:
                formula, _ = gen_function(ts, ts, aps)
                res_str = formula.to_str('spot')
                if res_str in seen:
                    gen_retrials += 1
                    continue
                seen.add(res_str)
                gen_retrials = 0

                if calc_trace:
                    finished, res = spot_worker.call(spot_get_trace, (res_str,), timeout)
                    if not finished:
                        print('TIMEOUT!!')
                        if require_trace:
                            calc_retrials += 1
                            continue
                        trace = '-'
                    else:
                        calc_retrials = 0
                        sat, trace = res
                        assert sat is not None
                        if not sat:  # currently, disallow unsat
                            continue
                        assert trace is not None
                    traces_step.append(ltl_parser.ltl_trace(trace, 'spot').to_str('network-polish') if trace != '-' else '-')
                formulas_step.append(formula.to_str('network-polish'))
                count_gen += 1
            if gen_retrials > 0:
                print(f'failed to find (syntactically) distinct formulas for ts {ts}. Consider increasing ts or decreasing desired count per size')
            if calc_retrials > 0:
                print(f'aborted trace calculation after {max_cal_retrials} consecutive timeouts on different formulas of size {ts}. Consider decreasing formula size or difficulty')

            if fractions is None:
                formulas.extend(formulas_step)
                traces.extend(traces_step)
            else:
                # One shared permutation keeps every formula next to its own
                # trace; shuffling the two lists separately would mix them up.
                order = list(range(count_gen))
                random.shuffle(order)
                formulas_step = [formulas_step[i] for i in order]
                if traces_step:
                    traces_step = [traces_step[i] for i in order]
                start = 0
                accu = 0.0
                last = len(fractions) - 1
                for i, fraction in enumerate(fractions):
                    accu += fraction
                    # The last split takes everything that is left, so float
                    # rounding in the running sum cannot drop a sample.
                    end = count_gen if i == last else int(accu * count_gen)
                    formulas[i].extend(formulas_step[start:end])
                    traces[i].extend(traces_step[start:end])
                    start = end
            total_count_gen += count_gen

        print(f'Generated a total of {total_count_gen} formulas from size {ts_min} to {ts_max}')
    finally:
        # Stop the worker even if generation failed part-way.
        spot_worker.terminate()
    return formulas, traces

if __name__ == '__main__':
    # Set e.g. splits = [5/6, 1/12, 1/12] to write train/val/test files instead of one file.
    splits = None
    formulas, traces = gen_formula_and_trace(
        placebo_ands_nexts_assignments,
        (1, 15),
        500,
        num_aps=5,
        timeout=10,
        calc_trace=True,
        require_trace=True,
        splits=splits,
    )
    if splits is None:
        outputs = [('gen_classes.txt', formulas, traces)]
    else:
        # Lazy on purpose: each split is looked up right before its file is written.
        outputs = (
            (name, formulas[i], traces[i])
            for i, name in enumerate(['train.txt', 'val.txt', 'test.txt'])
        )
    # Every sample takes two lines: the formula, then its trace.
    for path, sample_formulas, sample_traces in outputs:
        with open(path, 'w') as outfile:
            for f, t in zip(sample_formulas, sample_traces):
                outfile.write(f + '\n' + t + '\n')
