import logging
import time
from contextlib import nullcontext
from typing import List

import sympy as sp
import timeout_decorator
import torch

from alpha_integrate.train.timeout_error import TimeoutErr
from alpha_integrate.synthetic_data.expression_tokenizer import ExpressionTokenizer
from alpha_integrate.synthetic_data.process_action import apply_step
from alpha_integrate.synthetic_data.tokenization import WORD2ID, START, PAD, END, SUBEXPR, RULE, PARAM1, PARAM2
from alpha_integrate.synthetic_data.tokenization import INTEGRAL
from alpha_integrate.synthetic_data.tokenization import tokenize, detokenize
from alpha_integrate.synthetic_data.params.step_params import RULE_TOKENS, PARAM2_RULES

# Shared tokenizer instance; this module uses its ``seq_to_sp`` and
# ``sp_to_seq`` methods to move between token sequences and SymPy expressions.
e = ExpressionTokenizer()

# Size of the token vocabulary and the vocabulary ids (looked up in WORD2ID)
# of the special tokens used when building prompts and reading model output.
VOCAB_SIZE = len(WORD2ID)
START_TOKEN = WORD2ID[START]
END_TOKEN = WORD2ID[END]
PAD_TOKEN = WORD2ID[PAD]
SUBEXPR_TOKEN = WORD2ID[SUBEXPR]
RULE_TOKEN = WORD2ID[RULE]
PARAM1_TOKEN = WORD2ID[PARAM1]
PARAM2_TOKEN = WORD2ID[PARAM2]
# NOTE(review): INTEGRAL must be in scope for this lookup to succeed; it is
# not referenced anywhere else in this module.
INTEGRAL_TOKEN = WORD2ID[INTEGRAL]

# Wrapped with a 1-second timeout so a pathological token sequence cannot
# stall the search; callers in this module treat any exception raised here
# (including the timeout) as "could not parse".
@timeout_decorator.timeout(1)
def seq_to_sp_timeout(expression):
    """Convert a token sequence to a SymPy expression via the shared tokenizer.

    Delegates to ``e.seq_to_sp(expression)`` and returns its result unchanged.
    """
    return e.seq_to_sp(expression)

def integration_variable_names(sp_expr: sp.Expr) -> dict:
    """Collect the names of integration variables occurring in ``sp_expr``.

    The expression tree is walked recursively:

    * atoms contribute nothing;
    * an ``Integral`` contributes the names of its integration variables plus
      whatever its integrand contributes;
    * any other compound node contributes what its arguments contribute and,
      additionally, the name ``'x'`` is always included for such nodes.

    Args:
        sp_expr: SymPy expression to inspect.

    Returns:
        A dict used as an ordered set: keys are variable names, every value
        is ``True``.
    """
    if sp_expr.is_Atom:
        return {}

    names = {}
    if isinstance(sp_expr, sp.Integral):
        # ``variables`` yields only the integration symbol of each limit
        # tuple, so integration bounds (which may be numbers without a
        # ``name``) are never mistaken for variables, and every limit tuple
        # of a multiple integral is taken into account.
        for variable in sp_expr.variables:
            names[variable.name] = True
        names.update(integration_variable_names(sp_expr.args[0]))
        return names

    for arg in sp_expr.args:
        names.update(integration_variable_names(arg))

    # 'x' is unconditionally reported for non-atomic, non-Integral nodes.
    # NOTE(review): presumably 'x' is the base variable that must never be
    # back-substituted -- confirm against the data generation code.
    names['x'] = True

    return names

def parse_sample(sample: List[str]):
    """Parse one detokenized model output into a sub-expression and a rule.

    Two layouts are accepted (all markers are the token constants imported
    by this module)::

        <subexpr tokens> RULE <rule> END
        <subexpr tokens> RULE <rule> PARAM1 <param1 tokens> PARAM2 <param2 tokens> END

    The second layout is required when ``rule`` is in ``PARAM2_RULES``; the
    first one is required otherwise.

    Args:
        sample: Detokenized output, expected to end with ``END``.

    Returns:
        ``(subexpr_sp, (rule,))`` or ``(subexpr_sp, (rule, param1_sp,
        param2_sp))`` where the ``*_sp`` values are what
        ``seq_to_sp_timeout`` returned, or ``None`` if the sample is
        malformed or any part fails to parse (including a parse timeout).
    """
    # The emptiness check must come first so the END comparison is only
    # evaluated on a non-empty sample.
    if not sample or sample[-1] != END:
        return None

    try:
        rule_idx = sample.index(RULE)
    except ValueError:
        return None

    # ``except Exception`` (rather than a bare except) keeps parse failures
    # and timeouts best-effort while letting KeyboardInterrupt/SystemExit
    # propagate.
    try:
        subexpr_sp = seq_to_sp_timeout(sample[:rule_idx])
    except Exception:
        return None

    remaining = sample[rule_idx + 1:]
    if len(remaining) < 2:
        return None

    rule = remaining[0]
    if rule not in RULE_TOKENS:
        return None

    # Parameterless rule: the rule token must be followed directly by END.
    if rule not in PARAM2_RULES:
        if remaining[1] == END:
            return subexpr_sp, (rule,)
        return None

    if remaining[1] != PARAM1:
        return None

    # Drop "<rule> PARAM1"; what is left is "<param1> PARAM2 <param2> END".
    remaining = remaining[2:]

    try:
        param2_idx = remaining.index(PARAM2)
    except ValueError:
        return None

    try:
        param1_sp = seq_to_sp_timeout(remaining[:param2_idx])
    except Exception:
        return None

    # Everything after PARAM2, without the trailing END token.
    param2_tokens = remaining[param2_idx + 1:-1]
    try:
        param2_sp = seq_to_sp_timeout(param2_tokens)
    except Exception:
        return None

    return subexpr_sp, (rule, param1_sp, param2_sp)

def try_to_integrate(expression: List[str], model, block_size, device, beam_size, temperature, top_k,
                     iters=40, expr_dict=None, variable_list=None, timeout_time=None, debug=False):
    """Recursively search for a step-by-step integration of ``expression``.

    At each level the model proposes candidate steps by beam search; each
    candidate is parsed with ``parse_sample`` and applied with ``apply_step``.
    If any resulting expression contains no ``Integral`` the search stops at
    this level; otherwise every new, unseen expression is explored
    recursively with one fewer remaining iteration.

    Args:
        expression: Token sequence of the expression to integrate.
        model: Object exposing ``generate_beam(x, max_new, beam_size,
            end_token, pad_token, temperature=..., top_k=...)``.
        block_size: Context size of the model; expressions longer than
            ``block_size // 2 - 12`` tokens are not attempted.
        device: Torch device on which the prompt tensor is created.
        beam_size: Passed through to ``model.generate_beam``.
        temperature: Passed through to ``model.generate_beam``.
        top_k: Passed through to ``model.generate_beam``.
        iters: Remaining recursion depth; ``0`` means give up.
        expr_dict: Dict (used as a set) of expressions already visited. It is
            mutated in place and returned. A fresh dict is created when
            omitted.
        variable_list: List of ``(name, replacement)`` pairs recording
            variable changes, used to back-substitute. When empty it is
            seeded in place with the free symbols of ``expression``. A fresh
            list is created when omitted.
        timeout_time: Absolute ``time.time()`` deadline, or ``None``.
        debug: If true, log progress with ``logging.info``.

    Returns:
        ``(solutions, solved, expr_dict)``. ``solutions`` is a list of
        solutions, each a list of steps
        ``[expr_before, subexpr, rule, expr_after]``; ``solved`` tells whether
        at least one solution was found.

    Raises:
        TimeoutErr: If ``timeout_time`` has already passed on entry. A
            timeout raised by a nested call is caught and reported as
            ``([], False, expr_dict)`` instead.
    """
    # Fresh containers per top-level call: mutable default arguments would be
    # shared (and polluted) across independent searches.
    if expr_dict is None:
        expr_dict = {}
    if variable_list is None:
        variable_list = []

    ctx = nullcontext()

    # manual timeout
    if timeout_time is not None and time.time() > timeout_time:
        raise TimeoutErr

    # if the expression is too long to feed the model, can't solve
    if len(expression) > (block_size // 2) - 12:
        return [], False, expr_dict

    # if we have reached the maximum number of iterations, can't solve
    if iters == 0:
        return [], False, expr_dict

    try:
        sp_expr = seq_to_sp_timeout(expression)
    except Exception:
        return [], False, expr_dict

    # Seed the variable list with identity substitutions for the free symbols.
    if len(variable_list) == 0:
        for symbol in sp_expr.free_symbols:
            variable_list.append((symbol.name, symbol))

    expr_dict[sp_expr] = True

    if debug:
        logging.info(f"Iter: {iters}, Expression: {sp_expr}")

    tokenized_expression = [START_TOKEN] + [WORD2ID[w] for w in expression] + [SUBEXPR_TOKEN]
    x = torch.tensor(tokenized_expression, dtype=torch.long, device=device)

    # Names a new substitution variable must not collide with. Neither
    # ``sp_expr`` nor ``variable_list`` changes while samples are processed,
    # so this is computed once.
    taken_names = {phi.name for phi in sp_expr.free_symbols}
    taken_names.update(varchange[0] for varchange in variable_list)

    # Parallel lists describing every distinct candidate next expression.
    next_expressions = []
    subexpressions = []
    rules = []
    variable_lists = []

    with torch.no_grad():
        with ctx:
            samples = model.generate_beam(x, block_size - len(tokenized_expression) - 1, beam_size, END_TOKEN, PAD_TOKEN, temperature=temperature, top_k=top_k).cpu()
            for y in samples:
                # Keep only the generated continuation, up to and including
                # the first END token.
                sample = y.tolist()[len(tokenized_expression):]
                try:
                    endtoken_idx = sample.index(END_TOKEN)
                except ValueError:
                    continue
                sample = detokenize(sample[:endtoken_idx + 1])

                parsed = parse_sample(sample)
                if parsed is None:
                    continue
                subexpr, rule = parsed

                # A step that cannot be applied is simply discarded.
                try:
                    next_exprs = apply_step((sp_expr, subexpr, rule))
                except Exception:
                    continue
                if next_exprs is None:
                    continue

                if rule[0] == 'URule':
                    u_var = rule[1]
                    u_func = rule[2]
                    if not isinstance(u_var, sp.Symbol):
                        continue
                    # The substitution variable must be a fresh name.
                    if u_var.name in taken_names:
                        continue
                    new_variable_list = variable_list.copy()
                    new_variable_list.append((u_var.name, u_func))
                else:
                    new_variable_list = variable_list

                if debug:
                    logging.info(f"Rule: {rule}")
                    logging.info(f"Subexpr: {subexpr}")

                for j, nexpr in enumerate(next_exprs):
                    if nexpr not in next_expressions:
                        if debug:
                            logging.info(f"Next expr {j}: {nexpr}")
                        next_expressions.append(nexpr)
                        subexpressions.append(subexpr)
                        rules.append(rule)
                        variable_lists.append(new_variable_list)

    if debug:
        logging.info(f"Next Expressions: {next_expressions}")
        logging.info('-----------------')

    # check if any of the next_expressions is solved, i.e. does not have integral in it
    solutions = []
    solved_arr = [not next_expr.has(sp.Integral) for next_expr in next_expressions]
    solved = any(solved_arr)
    for i, next_expr in enumerate(next_expressions):
        if debug:
            logging.info(f"Next expr {i}: {next_expr}")

        # in each next_expr, we backsubstitute all variables that are not integration variables anymore
        int_symbol_names = list(integration_variable_names(next_expr).keys())
        free_symbol_names = [phi.name for phi in next_expr.free_symbols]

        if len(free_symbol_names) != len(int_symbol_names):
            for name in free_symbol_names:
                if name in int_symbol_names:
                    continue

                # Walk the variable changes from newest to oldest, following
                # the chain of substitutions until an integration symbol is
                # reached.
                for varchange in variable_lists[i][::-1]:
                    if varchange[0] == name:
                        next_expr = next_expr.subs(varchange[0], varchange[1])
                        replacement_names = [phi.name for phi in varchange[1].free_symbols]
                        if not replacement_names:
                            # Replacement has no symbols: nothing left to chase.
                            break
                        name = replacement_names[0]
                        if name in int_symbol_names:
                            break

            # Forget variable changes that are no longer integration variables.
            variable_lists[i] = [varchange for varchange in variable_lists[i] if varchange[0] in int_symbol_names]

        this_step = [sp_expr, subexpressions[i], rules[i], next_expr]

        if solved_arr[i]:
            solutions.append([this_step])

        elif not solved:
            # Only recurse while no solution has been found at this level.
            try:
                next_expr_ls = e.sp_to_seq(next_expr)
            except Exception:
                continue

            if next_expr in expr_dict:
                continue
            expr_dict[next_expr] = True
            try:
                next_steps, next_solved, expr_dict = try_to_integrate(
                    next_expr_ls, model, block_size, device, beam_size, temperature, top_k,
                    iters - 1, expr_dict, variable_lists[i], timeout_time, debug)
                if debug:
                    logging.info(f"back to {sp_expr}")
            except TimeoutErr:
                return [], False, expr_dict
            if next_solved:
                solved = True
                for solution in next_steps:
                    solutions.append([this_step] + solution)

    return solutions, solved, expr_dict