"""
Sample from a trained model
"""
import os
from contextlib import nullcontext
import torch
import numpy as np
import io
import cloudpickle as pickle
from typing import List
import logging
import sympy as sp
import timeout_decorator
import time
import random
import argparse

date_time_idx = time.strftime("%Y%m%d-%H%M%S")  # run timestamp, reused in log and result file names
# logging.basicConfig(filename=...) opens the file right away, so the log
# directory must exist before configuring logging.
_LOG_DIR = 'alpha_integrate/train/evaluatelogs'
os.makedirs(_LOG_DIR, exist_ok=True)
logging.basicConfig(filename = f'{_LOG_DIR}/evaluate_{date_time_idx}.txt', filemode='w', level=logging.INFO)

from alpha_integrate.train.model import GPTConfig, GPT
from alpha_integrate.train.config.train_integrate import block_size
from alpha_integrate.synthetic_data.tokenization import tokenize, detokenize
from alpha_integrate.synthetic_data.tokenization import WORD2ID, START, PAD, END, SUBEXPR, RULE, PARAM1, PARAM2
from alpha_integrate.synthetic_data.expression_tokenizer import ExpressionTokenizer
from alpha_integrate.synthetic_data.params.step_params import RULE_TOKENS, PARAM2_RULES
from alpha_integrate.synthetic_data.params.tokenizer_params import INTEGRAL, SYMBOLS
from alpha_integrate.synthetic_data.process_action import apply_step
from alpha_integrate.train.timeout_error import TimeoutErr

# -----------------------------------------------------------------------------
# Shared tokenizer converting between token sequences and SymPy expressions.
e = ExpressionTokenizer()

@timeout_decorator.timeout(1)
def seq_to_sp_timeout(seq):
    """Convert the token sequence ``seq`` to a SymPy expression via ``e.seq_to_sp``.

    The conversion is wrapped in ``timeout_decorator.timeout(1)``, so a call that
    takes longer than one second is aborted with an exception instead of stalling
    the search; callers treat any exception as "unparseable".
    """
    converted = e.seq_to_sp(seq)
    return converted

# Checkpoint name; loaded from out_dir as '<model>.pt' and later rebound to the model object.
model = 'prim_fwdibp'
save_results = False  # when True, pickle the per-integral results at the end of the run

# Integrals to evaluate, kept both as SymPy objects and as token sequences.
x = sp.Symbol('x')
expr = 1 + 2*sp.cos(2*x)/sp.sqrt(sp.sin(2*x)**2 + 1)
sp_exprs = [sp.Integral(expr, x)]
exprs = list(map(e.sp_to_seq, sp_exprs))
print(exprs)
print(sp_exprs)

# -----------------------------------------------------------------------------
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
out_dir = 'alpha_integrate/train/out-integrate' # ignored if init_from is not 'resume'
start = "" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt" (currently unused here)
num_samples = 1000 # number of samples to draw (currently unused here)
temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 15 # retain only the top_k most likely tokens, clamp others to have 0 probability
beam_size = 5
seed = 1337
DEBUG = True
# Default to CUDA when present so the checkpoint can still be loaded on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster
# Apply overrides from the command line or config file by executing the configurator
# in this module's namespace; the file handle is closed afterwards.
with open('alpha_integrate/train/configurator.py', encoding='utf-8') as _configurator_file:
    exec(_configurator_file.read())
# -----------------------------------------------------------------------------

# Seed the CPU and CUDA random generators.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Allow TF32 for matmul and cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Device family for torch.autocast: any device string containing 'cuda' maps to 'cuda'.
device_type = 'cpu' if 'cuda' not in device else 'cuda'
ptdtype = dict(float32=torch.float32, bfloat16=torch.bfloat16, float16=torch.float16)[dtype]
# Autocast to ptdtype on GPU; a no-op context on CPU.
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# model
if init_from == 'resume':
    # Load '<model>.pt' from out_dir; at this point `model` still holds the checkpoint name.
    ckpt_path = os.path.join(out_dir, f'{model}.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # Strip the '_orig_mod.' key prefix (left by saving a torch.compile-wrapped
    # module) so the keys match the plain GPT module.
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
elif init_from.startswith('gpt2'):
    # init from a given GPT-2 model
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))
else:
    # Without this, `model` would still be the checkpoint-name string and fail below
    # with an unhelpful AttributeError.
    raise ValueError(f"Unknown init_from={init_from!r}; expected 'resume' or a 'gpt2*' variant")

model.eval()
model.to(device)
if compile:
    model = torch.compile(model) # requires PyTorch 2.0 (optional)

VOCAB_SIZE = len(WORD2ID)
# Integer ids of the special tokens, looked up once from the vocabulary.
(START_TOKEN, END_TOKEN, PAD_TOKEN, SUBEXPR_TOKEN,
 RULE_TOKEN, PARAM1_TOKEN, PARAM2_TOKEN, INTEGRAL_TOKEN) = (
    WORD2ID[word] for word in (START, END, PAD, SUBEXPR, RULE, PARAM1, PARAM2, INTEGRAL)
)
# Id of the token for the symbol 'x'.
X_TOKEN = WORD2ID[SYMBOLS['x']]

def integration_variable_names(sp_expr: sp.Expr) -> dict:
    """Collect the names of all variables bound by an ``sp.Integral`` inside ``sp_expr``.

    Walks the expression tree. For every Integral found, the names of its
    integration variables are recorded and its integrand is searched recursively.

    Args:
        sp_expr: SymPy expression to inspect.

    Returns:
        dict mapping each integration-variable name to ``True`` (used as a set).
    """
    if sp_expr.is_Atom:
        return dict()

    names = dict()
    if isinstance(sp_expr, sp.Integral):
        # Use only the bound variables of *every* limit. Limit bounds such as the
        # 0 and u in (x, 0, u) are not integration variables, and numeric bounds
        # have no ``.name``.
        for var in sp_expr.variables:
            names[var.name] = True
        names.update(integration_variable_names(sp_expr.function))
        return names

    for arg in sp_expr.args:
        names.update(integration_variable_names(arg))

    return names

def parse_sample(sample: List[str]):
    """Parse one decoded model output into a sub-expression and a rule.

    Accepted layouts (tokens are vocabulary words)::

        <subexpr tokens> RULE <rule> END ...
        <subexpr tokens> RULE <rule> PARAM1 <param1 tokens> PARAM2 <param2 tokens> END

    The first form applies to rules not in ``PARAM2_RULES``: only the token right
    after ``<rule>`` is checked to be END. The second form applies to rules in
    ``PARAM2_RULES``. In both forms the sample must end with END.

    Returns:
        ``(subexpr_sp, (rule,))`` or ``(subexpr_sp, (rule, param1_sp, param2_sp))``.
        Returns ``None`` if the sample is malformed or any part fails to convert
        to SymPy (including conversion timeouts).
    """
    if not sample or sample[-1] != END:
        return None

    try:
        rule_idx = sample.index(RULE)
    except ValueError:
        return None

    subexpr = sample[:rule_idx]
    try:
        subexpr_sp = seq_to_sp_timeout(subexpr)
    except Exception:  # unparseable or timed out; keep KeyboardInterrupt working
        return None

    remaining = sample[rule_idx+1:]

    if len(remaining) < 2:
        return None

    rule = remaining[0]

    if rule not in RULE_TOKENS:
        return None

    if rule not in PARAM2_RULES:
        if remaining[1] == END:
            return subexpr_sp, (rule,)
        return None

    if remaining[1] != PARAM1:
        return None

    remaining = remaining[2:]

    try:
        param2_idx = remaining.index(PARAM2)
    except ValueError:
        return None

    param1 = remaining[:param2_idx]
    try:
        param1_sp = seq_to_sp_timeout(param1)
    except Exception:
        return None

    # Everything after PARAM2 except the trailing END token.
    remaining = remaining[param2_idx+1:-1]

    try:
        param2_sp = seq_to_sp_timeout(remaining)
    except Exception:
        return None

    return subexpr_sp, (rule, param1_sp, param2_sp)


def try_to_integrate(expression: List[str], model, block_size, device, beam_size,
    temperature, top_k, iters = 40, expr_dict = None, variable_list = None, timeout_time = None):
    """Recursively search for a step-by-step integration of ``expression``.

    At each level the model proposes ``beam_size`` candidate steps (a sub-expression
    plus a rule and optional parameters), and each valid step is applied with
    ``apply_step``. A resulting expression without an ``sp.Integral`` counts as
    solved. Otherwise, while nothing at this level is solved, the search recurses
    into each new candidate with ``iters - 1`` remaining depth.

    Args:
        expression: token sequence (vocabulary words) of the current expression.
        model: model exposing ``generate_beam``.
        block_size: model context length; also bounds the accepted expression length.
        device: torch device for the prompt tensor.
        beam_size, temperature, top_k: decoding settings forwarded to ``generate_beam``.
        iters: remaining recursion depth; the search gives up at 0.
        expr_dict: expressions already visited (used as a set). A fresh dict is
            created when omitted.
        variable_list: stack of ``(name, replacement)`` substitutions still to be
            undone. It is seeded with the expression's free symbols when empty, and
            a fresh list is created when omitted.
        timeout_time: absolute ``time.time()`` deadline, or None for no deadline.

    Returns:
        ``(solutions, solved, expr_dict)``. Each solution is a list of steps
        ``[expr, subexpr, rule, next_expr]`` whose last ``next_expr`` has no Integral.

    Raises:
        TimeoutErr: if called after ``timeout_time``. Timeouts from recursive
            calls are caught here and reported as "not solved".
    """
    if timeout_time is not None and time.time() > timeout_time:
        raise TimeoutErr
    # Mutable defaults would be created once and then shared (and mutated below)
    # across every top-level call. That leaks visited expressions and substitutions
    # from one integral into the next, so build fresh containers per call instead.
    if expr_dict is None:
        expr_dict = dict()
    if variable_list is None:
        variable_list = []

    if DEBUG:
        logging.info("Called")
    # Expressions longer than this threshold are not attempted.
    if len(expression) > (block_size // 2) - 12:
        return [], False, expr_dict

    if iters == 0:
        return [], False, expr_dict

    try:
        sp_expr = seq_to_sp_timeout(expression)
    except Exception:  # unparseable or timed out
        return [], False, expr_dict

    # Seed the substitution stack with identity entries for the original free symbols.
    if len(variable_list) == 0:
        for s in sp_expr.free_symbols:
            variable_list.append((s.name, s))

    expr_dict[sp_expr] = True
    if DEBUG:
        logging.info(f"Iter: {iters}, Expression: {sp_expr}")
    # Prompt layout: START <expression tokens> SUBEXPR; the model continues with a step.
    tokenized_expression = [START_TOKEN] + [WORD2ID[w] for w in expression] + [SUBEXPR_TOKEN]
    prompt = torch.tensor(tokenized_expression, dtype=torch.long, device=device)

    # Names already in use; a u-substitution must introduce a new variable name.
    free_symbol_names = [phi.name for phi in sp_expr.free_symbols]
    variable_list_names = [varchange[0] for varchange in variable_list]

    with torch.no_grad():
        with ctx:
            samples = model.generate_beam(prompt, block_size - len(tokenized_expression) - 1, beam_size, END_TOKEN, PAD_TOKEN, temperature=temperature, top_k=top_k).cpu()
            next_expressions = []
            subexpressions = []
            rules = []
            variable_lists = []
            for y in samples:
                # Keep only the generated continuation, truncated after the first END.
                sample = y.tolist()[len(tokenized_expression):]
                try:
                    endtoken_idx = sample.index(END_TOKEN)
                except ValueError:
                    continue
                sample = detokenize(sample[:endtoken_idx+1])
                parsed = parse_sample(sample)
                if parsed is None:
                    continue
                subexpr, rule = parsed
                try:
                    next_exprs = apply_step((sp_expr, subexpr, rule))
                except Exception:  # the proposed step is not applicable
                    continue
                if next_exprs is None:
                    continue

                if rule[0] == 'URule':
                    u_var = rule[1]
                    u_func = rule[2]
                    if not isinstance(u_var, sp.Symbol):
                        continue
                    if u_var.name in variable_list_names or u_var.name in free_symbol_names:
                        continue
                    new_variable_list = variable_list.copy()
                    new_variable_list.append((u_var.name, u_func))
                else:
                    new_variable_list = variable_list

                if DEBUG:
                    logging.info(f"Rule: {rule}")
                    logging.info(f"Subexpr: {subexpr}")

                # De-duplicate candidate expressions across all beams.
                for j, nexpr in enumerate(next_exprs):
                    if nexpr not in next_expressions:
                        if DEBUG:
                            logging.info(f"Next expr {j}: {nexpr}")
                        next_expressions.append(nexpr)
                        subexpressions.append(subexpr)
                        rules.append(rule)
                        variable_lists.append(new_variable_list)

    if DEBUG:
        logging.info(f"Next Expressions: {next_expressions}")
        logging.info('-----------------')

    # A candidate with no remaining Integral is a finished solution.
    solutions = []
    solved_arr = [not next_expr.has(sp.Integral) for next_expr in next_expressions]
    solved = any(solved_arr)
    for i, next_expr in enumerate(next_expressions):
        if DEBUG:
            logging.info(f"Next expr {i}: {next_expr}")
            logging.info(f"Before Substitution: {next_expr}")

        # Back-substitute variables that are still free in next_expr but no longer
        # bound by any Integral. Go most recent first, and stop at the first entry
        # that cannot be undone yet.
        num_removed = 0
        for var_name, replacement in reversed(variable_lists[i]):
            int_symbol_names = integration_variable_names(next_expr)
            remaining_free_names = {phi.name for phi in next_expr.free_symbols}
            if var_name not in remaining_free_names or var_name in int_symbol_names:
                break
            next_expr = next_expr.subs(var_name, replacement)
            num_removed += 1
        # Slicing yields a new list, so the recursive call below never aliases a sibling's list.
        variable_lists[i] = variable_lists[i][:len(variable_lists[i]) - num_removed]

        if DEBUG:
            logging.info(f"After Substitution: {next_expr}")
            logging.info(f"Variable List: {variable_lists[i]}")
            logging.info('-----------------')

        this_step = [sp_expr, subexpressions[i], rules[i], next_expr]

        if solved_arr[i]:
            solutions.append([this_step])

        elif not solved:
            # Recurse only while nothing at this level (or an earlier branch) has
            # produced a solution.
            try:
                next_expr_ls = e.sp_to_seq(next_expr)
            except Exception:
                continue

            if next_expr in expr_dict:
                continue
            expr_dict[next_expr] = True
            try:
                next_steps, next_solved, expr_dict = try_to_integrate(next_expr_ls, model, block_size, device,
                    beam_size, temperature, top_k, iters-1, expr_dict, variable_lists[i], timeout_time)
                if DEBUG:
                    logging.info(f"back to {sp_expr}")
            except TimeoutErr:
                # Out of time: abandon the search from this level.
                return [], False, expr_dict
            if next_solved:
                solved = True
                for solution in next_steps:
                    solutions.append([this_step] + solution)

    return solutions, solved, expr_dict

# -----------------------------------------------------------------------------
# Run the evaluation

results_ls = []  # one result record (dict) per evaluated integral

correct = 0
for i, (expr, sp_expr) in enumerate(zip(exprs, sp_exprs)):
    r = dict()
    r['Expression'] = sp_expr.args[0]  # the integrand
    logging.info(f"Integral: {sp_expr}")
    try:
        now = time.time()
        timeout_time = now + 20  # per-integral search budget, in seconds
        solutions, solved, expr_dict = try_to_integrate(expr, model, block_size, device, beam_size, temperature, top_k, timeout_time = timeout_time)
        if solved:
            correct += 1
    except TimeoutErr:
        logging.info('Timed out')
        solved = False
    except Exception:
        # Record the traceback so failures are diagnosable. Unlike a bare
        # except, this still lets KeyboardInterrupt stop the run.
        logging.exception('Error')
        solved = False

    if solved:
        # The last step of the first solution holds the integral-free result.
        logging.info(f"Solution: {solutions[0][-1][-1]}")
        logging.info('-----------------')
        r['Transformer_Success'] = True
        r['Transformer_Result'] = solutions[0][-1][-1]
        r['Transformer_Steps'] = solutions
    else:
        logging.info('Transformer could not solve.')

    results_ls.append(r)

    logging.info(f"Accuracy of transformer: {(correct/(i+1)):.3%}")
    logging.info('-----------------')

if save_results:
    # Pickle (via cloudpickle) all per-integral result records to a timestamped file
    # results/run_integral_{date_time_idx}.pkl.
    results_dir = 'alpha_integrate/train/results'
    os.makedirs(results_dir, exist_ok=True)
    save_path = f'{results_dir}/run_integral_{date_time_idx}.pkl'

    with open(save_path, 'wb') as f:
        pickle.dump(results_ls, f)
        