"""
Sample from a trained model
"""
import os
from contextlib import nullcontext
import torch
import numpy as np
import io
import cloudpickle as pickle
from typing import List
import logging
import sympy as sp
import timeout_decorator
import time
import random
import argparse

from alpha_integrate.train.model import GPTConfig, GPT
from alpha_integrate.train.config.train_integrate import block_size
from alpha_integrate.synthetic_data.tokenization import tokenize, detokenize
from alpha_integrate.synthetic_data.tokenization import WORD2ID, START, PAD, END, SUBEXPR, RULE, PARAM1, PARAM2
from alpha_integrate.synthetic_data.expression_tokenizer import ExpressionTokenizer
from alpha_integrate.synthetic_data.params.step_params import RULE_TOKENS, PARAM2_RULES
from alpha_integrate.synthetic_data.params.tokenizer_params import INTEGRAL, SYMBOLS
from alpha_integrate.synthetic_data.process_action import apply_step
from alpha_integrate.train.timeout_error import TimeoutErr

node_count = 0

def main(args):
    """Try to integrate the expression stored in ``args.expr_file`` with a trained GPT model.

    The first line of ``args.expr_file`` is read as a whitespace-separated token
    sequence, wrapped as an integral in ``x``, and handed to the beam-search
    procedure ``try_to_integrate``. The model weights are loaded from
    ``<out_dir>/<args.model>.pt`` and ``args.beam_size`` sets the beam width.
    If a solution is found, the number of ``try_to_integrate`` calls is appended
    to ``alpha_integrate/train/transformer_tree.txt``.
    """
    # -----------------------------------------------------------------------------
    e = ExpressionTokenizer()

    @timeout_decorator.timeout(1)
    def seq_to_sp_timeout(seq):
        """Convert the token sequence ``seq`` with ``ExpressionTokenizer.seq_to_sp``.

        The call runs under ``timeout_decorator.timeout(1)`` (presumably a
        one-second limit), so a pathological parse is aborted with that
        library's timeout exception instead of hanging. Callers in this file
        treat the result as a SymPy expression.
        """
        return e.seq_to_sp(seq)

    model = args.model
    expr_file = args.expr_file

    # get the expression from the file: only the first line is used, split on
    # whitespace into tokens
    with io.open(expr_file, mode='r', encoding='utf-8') as f:
        first_line = f.readline()
    if not first_line:
        # readline() returns '' only at end of file, i.e. the file has no lines.
        # Fail with a clear message instead of an unbound-variable error below.
        raise ValueError(f"Expression file {expr_file!r} is empty")
    line = first_line.split()

    # wrap the integrand tokens as "integral of <line> with respect to x"
    expr = [INTEGRAL] + line + [SYMBOLS['x']]
    # parsed only for validation here: raises if the sequence is malformed or too slow to parse
    sp_expr = seq_to_sp_timeout(expr)

    # -----------------------------------------------------------------------------
    init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
    out_dir = 'alpha_integrate/train/out-integrate' # ignored if init_from is not 'resume'
    # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
    temperature = 1.0
    beam_size = int(args.beam_size)
    top_k = 50 # retain only the top_k most likely tokens, clamp others to have 0 probability
    seed = 1337
    DEBUG = False
    device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
    dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
    compile = False # use PyTorch 2.0 to compile the model to be faster
    # Read the configurator through a context manager so the file handle is
    # closed deterministically before its source is executed.
    # NOTE(review): exec() runs a project file here, never external input. Also,
    # because this exec() is inside a function, Python does not let the executed
    # code rebind the local variables above - confirm the overrides are meant
    # to (and actually do) take effect.
    with open('alpha_integrate/train/configurator.py') as configurator_file:
        configurator_source = configurator_file.read()
    exec(configurator_source) # overrides from command line or config file
    # -----------------------------------------------------------------------------

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
    device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
    ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
    # autocast is only used on GPU; on CPU generation runs in a no-op context
    ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

    # model
    if init_from == 'resume':
        # rebuild the network from the checkpoint "<out_dir>/<model>.pt"
        ckpt_path = os.path.join(out_dir, f'{model}.pt')
        checkpoint = torch.load(ckpt_path, map_location=device)
        model = GPT(GPTConfig(**checkpoint['model_args']))
        state_dict = checkpoint['model']
        # Some checkpoints carry an '_orig_mod.' prefix on their parameter names;
        # rename those entries in place so they match the plain module's keys.
        unwanted_prefix = '_orig_mod.'
        prefix_len = len(unwanted_prefix)
        prefixed_keys = [key for key in state_dict if key.startswith(unwanted_prefix)]
        for key in prefixed_keys:
            state_dict[key[prefix_len:]] = state_dict.pop(key)
        model.load_state_dict(state_dict)
    elif init_from.startswith('gpt2'):
        # init from a given GPT-2 model
        model = GPT.from_pretrained(init_from, dict(dropout=0.0))

    # inference only: switch to eval mode and move the weights to the target device
    model.eval()
    model.to(device)
    if compile:
        model = torch.compile(model) # requires PyTorch 2.0 (optional)

    # Vocabulary size and the integer ids of the special tokens, looked up once
    # so the nested helpers can use them without going through WORD2ID again.
    VOCAB_SIZE = len(WORD2ID)
    (START_TOKEN, END_TOKEN, PAD_TOKEN, SUBEXPR_TOKEN,
     RULE_TOKEN, PARAM1_TOKEN, PARAM2_TOKEN, INTEGRAL_TOKEN) = (
        WORD2ID[word]
        for word in (START, END, PAD, SUBEXPR, RULE, PARAM1, PARAM2, INTEGRAL)
    )
    X_TOKEN = WORD2ID[SYMBOLS['x']]

    def integration_variable_names(sp_expr: sp.Expr) -> dict:
        """Collect the names of all integration variables occurring in ``sp_expr``.

        The expression tree is walked recursively. For every ``sp.Integral``
        node the variable of each of its limits is recorded, and the search
        continues into the integrand. For any other non-atomic node the search
        continues into all arguments.

        Args:
            sp_expr: SymPy expression to inspect.

        Returns:
            A dict used as a set: each key is the ``name`` of an integration
            variable and each value is ``True``. Empty for atoms.
        """
        if sp_expr.is_Atom:
            return {}

        names = {}
        if isinstance(sp_expr, sp.Integral):
            # args is (integrand, limit_1, limit_2, ...). Each limit is a tuple
            # whose FIRST entry is the integration variable; any further entries
            # are bounds. Iterating over the whole first limit (as done before)
            # breaks on numeric bounds and ignores additional limits.
            for limit in sp_expr.args[1:]:
                names[limit[0].name] = True
            names.update(integration_variable_names(sp_expr.args[0]))
            return names

        for arg in sp_expr.args:
            names.update(integration_variable_names(arg))

        return names

    def parse_sample(sample: List[str]):
        """Parse one detokenized model output into ``(subexpression, rule)``.

        Accepted layouts (both must end with ``END``):

        * ``<subexpr tokens> RULE <rule> END`` for rules without parameters,
        * ``<subexpr tokens> RULE <rule> PARAM1 <param1 tokens> PARAM2
          <param2 tokens> END`` for rules listed in ``PARAM2_RULES``.

        Args:
            sample: Detokenized tokens generated after the SUBEXPR marker.

        Returns:
            ``(subexpr, (rule,))`` or ``(subexpr, (rule, param1, param2))`` where
            the expressions come from ``seq_to_sp_timeout``, or ``None`` when the
            sample is malformed or any part of it cannot be parsed.
        """
        if len(sample) < 1 or sample[-1] != END:
            return None

        try:
            rule_idx = sample.index(RULE)
        except ValueError:
            return None

        # `except Exception` (not a bare except) so that Ctrl-C / SystemExit
        # still propagate instead of being treated as an unparsable sample.
        try:
            subexpr_sp = seq_to_sp_timeout(sample[:rule_idx])
        except Exception:
            return None

        remaining = sample[rule_idx + 1:]
        if len(remaining) < 2:
            return None

        rule = remaining[0]
        if rule not in RULE_TOKENS:
            return None

        if rule not in PARAM2_RULES:
            # parameterless rule: the rule token must be followed directly by END
            if remaining[1] == END:
                return subexpr_sp, (rule,)
            return None

        if remaining[1] != PARAM1:
            return None

        remaining = remaining[2:]
        try:
            param2_idx = remaining.index(PARAM2)
        except ValueError:
            return None

        try:
            param1_sp = seq_to_sp_timeout(remaining[:param2_idx])
            # the final token is END, which is not part of the second parameter
            param2_sp = seq_to_sp_timeout(remaining[param2_idx + 1:-1])
        except Exception:
            return None

        return subexpr_sp, (rule, param1_sp, param2_sp)


    def try_to_integrate(expression: List[str], model, block_size, device, beam_size,
        temperature, top_k, iters = 40, expr_dict = None, variable_list = None, timeout_time = None):
        """Recursively search for rule applications that eliminate every integral.

        One call asks the model (beam search) for candidate ``(subexpression,
        rule)`` steps for ``expression``, applies them with ``apply_step`` and,
        if none of the results is integral-free, recurses into each unseen
        result. Every call increments the module-level ``node_count``.

        Args:
            expression: Token sequence of the expression to integrate.
            model: Model providing ``generate_beam``.
            block_size: Context length; bounds the expression length and the
                number of generated tokens.
            device: Device on which the prompt tensor is created.
            beam_size: Beam width passed to ``generate_beam``.
            temperature: Sampling temperature passed to ``generate_beam``.
            top_k: Top-k cutoff passed to ``generate_beam``.
            iters: Remaining recursion depth; the search stops at 0.
            expr_dict: Dict (used as a set) of expressions already visited.
                A fresh dict is created when omitted.
            variable_list: List of ``(name, replacement)`` pairs, oldest first,
                describing the variable changes made so far. When empty it is
                filled (in place) with the free symbols of the expression.
                A fresh list is created when omitted.
            timeout_time: Optional ``time.time()`` deadline for the whole search.

        Returns:
            ``(solutions, solved, expr_dict)``. ``solutions`` is a list of
            solutions, each a list of steps ``[expression, subexpression, rule,
            next_expression]``; ``solved`` tells whether any was found.

        Raises:
            TimeoutErr: If ``timeout_time`` has passed when this call starts.
                A timeout raised by a recursive call is caught and reported as
                ``([], False, expr_dict)`` instead.
        """
        global node_count
        node_count += 1

        # Fresh containers per call: mutable defaults would be shared (and
        # mutated) across independent top-level calls.
        if expr_dict is None:
            expr_dict = {}
        if variable_list is None:
            variable_list = []

        if timeout_time is not None and time.time() > timeout_time:
            raise TimeoutErr
        if DEBUG:
            logging.info("Called")
        # the prompt must leave room in the context window for the generated step
        if len(expression) > (block_size // 2) - 12:
            return [], False, expr_dict

        if iters == 0:
            return [], False, expr_dict

        try:
            sp_expr = seq_to_sp_timeout(expression)
        except Exception:
            return [], False, expr_dict

        if len(variable_list) == 0:
            variable_list.extend((s.name, s) for s in sp_expr.free_symbols)

        expr_dict[sp_expr] = True
        if DEBUG:
            logging.info(f"Iter: {iters}, Expression: {sp_expr}")
        tokenized_expression = [START_TOKEN] + [WORD2ID[w] for w in expression] + [SUBEXPR_TOKEN]
        x = torch.tensor(tokenized_expression, dtype=torch.long, device=device)

        # Names a new substitution variable must not collide with; they do not
        # change while the samples are processed, so compute them once.
        free_symbol_names = {phi.name for phi in sp_expr.free_symbols}
        variable_list_names = {varchange[0] for varchange in variable_list}

        # parallel lists: candidate i is described by entry i of each list
        next_expressions = []
        subexpressions = []
        rules = []
        variable_lists = []

        with torch.no_grad():
            with ctx:
                samples = model.generate_beam(x, block_size - len(tokenized_expression) - 1, beam_size, END_TOKEN, PAD_TOKEN, temperature=temperature, top_k=top_k).cpu()
                for sample_idx, y in enumerate(samples):
                    # drop the prompt, keep the generated part up to and including END
                    sample = y.tolist()[len(tokenized_expression):]
                    try:
                        endtoken_idx = sample.index(END_TOKEN)
                    except ValueError:
                        continue
                    sample = detokenize(sample[:endtoken_idx + 1])
                    logging.info(f"Sample {sample_idx}: {sample}")
                    parsed = parse_sample(sample)
                    if parsed is None:
                        continue
                    subexpr, rule = parsed
                    try:
                        next_exprs = apply_step((sp_expr, subexpr, rule))
                    except Exception:
                        continue
                    if next_exprs is None:
                        continue

                    if rule[0] == 'URule':
                        u_var = rule[1]
                        u_func = rule[2]
                        if not isinstance(u_var, sp.Symbol):
                            continue
                        # the substitution variable must be a genuinely new name
                        if u_var.name in variable_list_names or u_var.name in free_symbol_names:
                            continue
                        new_variable_list = [*variable_list, (u_var.name, u_func)]
                    else:
                        new_variable_list = variable_list

                    if DEBUG:
                        logging.info(f"Rule: {rule}")
                        logging.info(f"Subexpr: {subexpr}")

                    for expr_idx, nexpr in enumerate(next_exprs):
                        if nexpr not in next_expressions:
                            if DEBUG:
                                logging.info(f"Next expr {expr_idx}: {nexpr}")
                            next_expressions.append(nexpr)
                            subexpressions.append(subexpr)
                            rules.append(rule)
                            variable_lists.append(new_variable_list)

        if DEBUG:
            logging.info(f"Next Expressions: {next_expressions}")
            logging.info('-----------------')

        # check if any of the next_expressions is solved, i.e. does not have integral in it
        solutions = []
        solved_arr = [not next_expr.has(sp.Integral) for next_expr in next_expressions]
        solved = any(solved_arr)
        for i, next_expr in enumerate(next_expressions):
            if DEBUG:
                logging.info(f"Next expr {i}: {next_expr}")
                logging.info(f"Before Substitution: {next_expr}")

            # in each next_expr, we backsubstitute all variables that are not integration variables anymore
            next_expr, variable_lists[i] = _back_substitute(next_expr, variable_lists[i])

            if DEBUG:
                logging.info(f"After Substitution: {next_expr}")
                logging.info(f"Variable List: {variable_lists[i]}")
                logging.info('-----------------')

            this_step = [sp_expr, subexpressions[i], rules[i], next_expr]

            if solved_arr[i]:
                solutions.append([this_step])

            elif not solved:
                # only explore deeper while no solution has been found at this level
                try:
                    next_expr_ls = e.sp_to_seq(next_expr)
                except Exception:
                    continue

                if next_expr in expr_dict:
                    continue
                expr_dict[next_expr] = True
                try:
                    # forward the deadline so it is enforced at every depth
                    next_steps, next_solved, expr_dict = try_to_integrate(next_expr_ls, model, block_size, device,
                    beam_size, temperature, top_k, iters-1, expr_dict, variable_lists[i], timeout_time)
                    if DEBUG:
                        logging.info(f"back to {sp_expr}")
                except TimeoutErr:
                    return [], False, expr_dict
                if next_solved:
                    solved = True
                    for solution in next_steps:
                        solutions.append([this_step] + solution)

        return solutions, solved, expr_dict

    def _back_substitute(expr, var_changes):
        """Undo the most recent variable changes that ``expr`` no longer integrates over.

        Walks ``var_changes`` from newest to oldest and substitutes a variable by
        its replacement as long as it is a free symbol of ``expr`` and not an
        integration variable; stops at the first variable that does not qualify.

        Returns:
            ``(expr, remaining)`` where ``remaining`` is a new list holding the
            changes that were not undone.
        """
        num_removed = 0
        for varchange in reversed(var_changes):
            int_symbol_names = integration_variable_names(expr)
            free_names = {phi.name for phi in expr.free_symbols}
            if varchange[0] not in free_names or varchange[0] in int_symbol_names:
                break
            expr = expr.subs(varchange[0], varchange[1])
            num_removed += 1
        return expr, var_changes[:len(var_changes) - num_removed]

    # -----------------------------------------------------------------------------
    # Run the evaluation

    # NOTE(review): `r` is never read afterwards in the visible code.
    r = dict()
    # node_count is the module-level counter that try_to_integrate increments
    # once per call; reset it so it counts only this search.
    global node_count
    node_count = 0
    # No timeout_time is passed, so the search runs without a deadline; only
    # `solved` from the result is used below.
    solutions, solved, expr_dict = try_to_integrate(expr, model, block_size, device, beam_size, temperature, top_k)
    if solved:
        # append f'{node_count}\n' to the file 'alpha_integrate/train/transformer_tree.txt'
        with open('alpha_integrate/train/transformer_tree.txt', 'a') as f:
            f.write(f'{node_count}\n')

        



if __name__ == '__main__':
    # Command-line entry point: parse the arguments, log to a per-model file and
    # run one integration attempt.
    parser = argparse.ArgumentParser()

    parser.add_argument('--expr_file', type=str, default=None)
    parser.add_argument('--model', type=str, default=None)
    parser.add_argument('--beam_size', type=int, default=None)

    args = parser.parse_args()

    logging.basicConfig(filename = f'alpha_integrate/train/evaluatelogs/transformer/ttreecount_{args.model}.txt', filemode='w', level=logging.INFO)

    try:
        main(args)
    except TimeoutErr:
        # a search that ran out of time is an expected outcome, not an error
        pass
    except Exception:
        # `Exception` rather than a bare except, so Ctrl-C / SystemExit still
        # terminate the script instead of being logged as a failure.
        # get expression from the file; the failure may itself be caused by the
        # file (missing path, unreadable, empty), so never let this lookup raise
        try:
            with io.open(args.expr_file, mode='r', encoding='utf-8') as f:
                line = f.readline().split()
        except (OSError, TypeError, ValueError):
            line = None
        # exc_info keeps the traceback so the failure can be diagnosed from the log
        logging.info(f"Had different error for {line}", exc_info=True)