
import os 
import re 
import numpy as np
from IPython.display import Math
import sympy as sp 
from transformers import PreTrainedTokenizerFast
from time import time 
from tqdm import tqdm
import torch
from joblib import Parallel, delayed
from torch import FloatTensor, LongTensor

def cmp_generation_target(pred, target):
    """Print *pred* and *target* between separator lines and report whether they are equal.

    Returns:
        The result of ``pred == target``.
    """
    # One print call emits the same five lines (the last one empty) and flushes once.
    print('============', pred, target, '------------', '', sep='\n', flush=True)
    return pred == target


# def eval_prediction(model, dataloader, tokenizer=None, num_beams=1, use_tqdm=False, one_step=False):
#     # os.environ['TOKENIZERS_PARALLELISM'] = 'false'
    
#     hits_list = []
#     acc, tot = 0, 0
#     start_t = time()
#     with torch.no_grad():
#         iterator = tqdm(dataloader) if use_tqdm else dataloader
#         for i, batch in enumerate(iterator):
#             for k in batch: batch[k] = batch[k].cuda()
#             x, y = batch['input_ids'], batch['decoder_input_ids']
            
#             # output_ids = model.generate(x, max_length=y.shape[-1], num_beams=num_beams, do_sample=False)
#             quantize_fn = None
#             output = model.generate(input_ids               = batch['input_ids'], 
#                                         input_continuous_labels = batch['input_continuous_labels'],
#                                         attention_mask          = batch['attention_mask'], 
#                                         continuous_token_ids    = [],
#                                         max_length              = y.shape[-1],
#                                         quantize_fn             = quantize_fn,)
#             # z_text = model.decode_prediction(prediction, tokenizer, skip_special_tokens=True)
#             output_ids = output['prediction']

#             l = min(y.shape[-1], output_ids.shape[-1])
#             hits = torch.all(y[:, :l] == output_ids[:, :l], dim=1).cpu()
#             hits_list.append(hits)
#             acc += sum(hits).item()
#             tot += len(hits)
            
#             if one_step: break

#         acc /= tot 
#     # os.environ['TOKENIZERS_PARALLELISM'] = 'true'
#     runtime = time() - start_t
#     return {'acc': acc, 'hits': hits_list, 'num_beams': num_beams, 'num_samples': tot, 'runtime': runtime}



def batch_eval_prediction(model, batch, tokenizer, quantize_fn=None, continuous_coefficient=False, coefficient_token_id=None):
    """Generate predictions for one batch and return them as decoded texts.

    Side effect: every value in *batch* is replaced in place by its ``.cuda()`` copy.
    ``continuous_coefficient`` is accepted but not used here.

    Args:
        model: object exposing ``generate(...)`` and ``postprocess_prediction(...)``.
        batch: mapping with at least ``input_ids``, ``input_continuous_labels``,
            ``attention_mask`` and ``decoder_input_ids``.
        tokenizer: forwarded to ``model.postprocess_prediction``.
        quantize_fn: forwarded to ``model.generate``.
        coefficient_token_id: passed to ``model.generate`` as the only entry of
            ``continuous_token_ids``.

    Returns:
        The ``'prediction_texts'`` entry of ``model.postprocess_prediction``.
    """
    batch.update({name: value.cuda() for name, value in batch.items()})

    # Generation is capped one position past the length of the decoder targets.
    generation_limit = batch['decoder_input_ids'].shape[-1] + 1
    generated = model.generate(
        input_ids=batch['input_ids'],
        input_continuous_labels=batch['input_continuous_labels'],
        attention_mask=batch['attention_mask'],
        continuous_token_ids=[coefficient_token_id],
        max_length=generation_limit,
        quantize_fn=quantize_fn,
    )
    decoded = model.postprocess_prediction(generated, tokenizer, skip_special_tokens=True)
    return decoded['prediction_texts']

@torch.no_grad()
def eval_prediction(model, dataloader, tokenizer, use_tqdm=False, steps=None, continuous_coefficient=False, quantize_fn=None):
    """Measure exact-match accuracy of generated texts against decoded targets.

    For each batch, predictions come from ``batch_eval_prediction``. Targets are
    ``batch['decoder_input_ids']`` decoded with special tokens skipped. A sample is
    a hit when the two strings are equal.

    Args:
        model, tokenizer, quantize_fn, continuous_coefficient: forwarded to
            ``batch_eval_prediction``. ``tokenizer.vocab['[C]']`` is passed as the
            coefficient token id.
        dataloader: iterable of batch mappings.
        use_tqdm: wrap the dataloader in a tqdm progress bar.
        steps: if not None, evaluate at most this many batches.

    Returns:
        dict with ``acc`` (fraction of hits; ``nan`` when no sample was evaluated),
        ``hits`` (one list of bools per batch), ``num_samples`` and ``runtime``
        (wall-clock seconds).
    """
    hits_list = []
    n_correct, tot = 0, 0
    start_t = time()

    coefficient_token_id = tokenizer.vocab['[C]']

    iterator = tqdm(dataloader) if use_tqdm else dataloader
    for i, batch in enumerate(iterator):
        # Check the budget before evaluating a batch so exactly `steps` batches are used.
        if steps is not None and i >= steps:
            break

        pred_texts = batch_eval_prediction(model, batch, tokenizer, quantize_fn=quantize_fn, continuous_coefficient=continuous_coefficient, coefficient_token_id=coefficient_token_id)
        y_texts = tokenizer.batch_decode(batch['decoder_input_ids'], skip_special_tokens=True)

        hits = [p == y for p, y in zip(pred_texts, y_texts)]
        hits_list.append(hits)
        n_correct += sum(hits)
        tot += len(hits)

    # Accuracy is undefined (not 0) when nothing was evaluated.
    acc = n_correct / tot if tot else float('nan')
    runtime = time() - start_t
    return {'acc': acc, 'hits': hits_list, 'num_samples': tot, 'runtime': runtime}


def get_checkpoint_id(save_dir):
    """Return the step number of the newest ``checkpoint-<N>`` entry in *save_dir*.

    ``os.listdir`` order is arbitrary, so when several checkpoints exist the
    largest step is chosen explicitly. Entries that are not named exactly
    ``checkpoint-<N>`` are ignored.

    Raises:
        FileNotFoundError: if *save_dir* contains no ``checkpoint-<N>`` entry.
    """
    pattern = re.compile(r'checkpoint-(\d+)')
    steps = [int(m.group(1)) for m in map(pattern.fullmatch, os.listdir(save_dir)) if m]
    if not steps:
        raise FileNotFoundError(f'no checkpoint-<N> entry found in {save_dir!r}')
    return max(steps)

def load_items(save_dir, data_path):
    """Load a trained model/tokenizer and/or a dataset plus its encoding settings.

    Args:
        save_dir: training output directory containing a ``checkpoint-<N>``
            subdirectory. Falsy to skip loading the model and tokenizer.
        data_path: dataset path whose last component encodes the dataset
            parameters, e.g. ``smallgroebner_1000000_10000_n=3_m=4_d=2_s=3_prefix_none``.
            Falsy to skip loading data.

    Returns:
        ``(model, trainset, testset, tokenizer, representation, integer_encoding, num_vars)``,
        with None for every item that was not loaded.
    """
    model = trainset = testset = tokenizer = None
    representation = integer_encoding = num_vars = None

    if save_dir:
        from tokenizers.processors import TemplateProcessing
        # The Bart classes are not imported at module level, so import them here.
        from transformers import BartConfig, BartForConditionalGeneration

        # set up tokenizer
        checkpoint_dir = os.path.join(save_dir, f'checkpoint-{get_checkpoint_id(save_dir)}')
        tokenizer = PreTrainedTokenizerFast.from_pretrained(checkpoint_dir)
        tokenizer.add_special_tokens({'pad_token': '[PAD]', 'cls_token': '[CLS]', 'bos_token': '<s>', 'eos_token': '</s>'})
        # NOTE(review): this assigns a plain attribute on the Python wrapper. The
        # backend tokenizer (``tokenizer.backend_tokenizer``) keeps its own
        # post-processor, so ``encode`` probably does not add <s>/</s> because of
        # this line. Left as-is because changing it would change the model inputs;
        # confirm against the training script.
        tokenizer.post_processor = TemplateProcessing(
            single="<s> $A </s>",
            special_tokens=[("<s>", tokenizer.bos_token_id), ("</s>", tokenizer.eos_token_id)],
        )

        # load model
        config = BartConfig.from_pretrained(os.path.join(checkpoint_dir, 'config.json'))
        model = BartForConditionalGeneration.from_pretrained(os.path.join(checkpoint_dir, 'pytorch_model.bin'), config=config)
        model.eval().cuda()

    if data_path:
        # NOTE(review): load_dataset is not imported in this module's visible imports;
        # presumably a project helper returning (train, test). Confirm.
        trainset, testset = load_dataset(data_path)
        # normpath/basename tolerate a trailing slash, unlike split('/')[-1].
        dataset_name = os.path.basename(os.path.normpath(data_path))
        # Six integers expected: sample counts, then n (variables) and three more parameters.
        _, _, n_vars_text, _num_terms, _degree, _rsf = re.findall(r'\d+', dataset_name)
        num_vars = int(n_vars_text)
        representation = re.findall(r'_(infix|prefix)', dataset_name)[0]
        integer_encoding = re.findall(r'_(complete|none|signed)', dataset_name)[0]

    return model, trainset, testset, tokenizer, representation, integer_encoding, num_vars

def evaluate(model, dataset, tokenizer, representation, integer_encoding, num_vars, n_test=100, n_show=10, max_length=50):
    """Greedy-decode samples of *dataset* with *model* and print exact-match accuracy.

    Each sample is a mapping with ``'data'`` (input text) and ``'target'`` (expected
    output text). The target and the decoded prediction are ``'[SEP] '``-separated
    polynomials parsed with ``text_to_sympy``. A prediction is a hit when it has the
    same number of polynomials as the target and they compare equal pairwise.
    Predictions that fail to parse count as invalid and are excluded from the
    accuracy denominator.

    Args:
        model: model exposing ``generate``. Encoded inputs are moved with ``.cuda()``.
        dataset: indexable, sized collection of samples.
        tokenizer: used to encode inputs and decode outputs.
        representation, integer_encoding: forwarded to ``text_to_sympy``.
        num_vars: number of variables ``x0 .. x{num_vars-1}``.
        n_test: number of leading samples to evaluate, capped at ``len(dataset)``.
        n_show: unused.
        max_length: ``max_length`` passed to ``model.generate``.

    Returns:
        ``(acc, runtime)``: accuracy in percent over valid predictions (``nan`` if
        there are none) and total seconds spent in ``model.generate``.
    """
    gens = sp.symbols(f'x:{num_vars}')
    xyz = list(zip(gens, sp.symbols('x y z')))

    def parse(text):
        return text_to_sympy(text, representation=representation, integer_encoding=integer_encoding)

    def normalize(exprs):
        # With fewer than 4 variables, pass each expression through sp.Poly and
        # rename x0, x1, x2 -> x, y, z before comparing.
        if num_vars >= 4:
            return list(exprs)
        return [sp.Poly(e, gens=gens).subs(xyz).as_expr() for e in exprs]

    n_test = min(n_test, len(dataset))
    n_hit, n_invalid, runtime = 0, 0, 0
    for i in range(n_test):
        input_text, target_text = dataset[i]['data'], dataset[i]['target']
        input_ids = tokenizer.encode(input_text, return_tensors='pt').cuda()

        start = time()
        output_ids = model.generate(input_ids, max_length=max_length, num_beams=1, do_sample=False)
        runtime += time() - start
        output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        Q, invalid = [], False
        for q_text in output_text.split('[SEP] '):
            if q_text == '':
                continue
            try:
                Q.append(parse(q_text))
            except Exception:  # narrow from bare except: let KeyboardInterrupt through
                invalid = True
                print('--- invalid expr --------------')
                print(q_text)
                print(output_text)
                print(output_text.split('[SEP]'))
                print('-------------------------------')
                break
        if invalid:
            n_invalid += 1
            continue

        G = normalize(parse(g_text) for g_text in target_text.split('[SEP] '))
        Q = normalize(Q)
        hit = len(Q) == len(G) and all(g == q for g, q in zip(G, Q))
        n_hit += int(hit)

    N = n_test - n_invalid
    acc = 100 * n_hit / N if N > 0 else float('nan')
    print(f'acc = {acc:>5.1f}% ({n_hit:04} / {N})  [{runtime:.4f} sec] --- ({n_invalid} invalids)')

    return acc, runtime


def evaluate_sympy_runtime(task, dataset, representation, integer_encoding, num_vars, n_test=100):
    """Solve each sample directly with SymPy and print accuracy and solver runtime.

    Inputs (``'data'``) and targets (``'target'``) are ``'[SEP]'``-separated
    polynomials parsed with ``text_to_sympy``. The SymPy answer is a hit when it has
    as many polynomials as the target and they compare equal pairwise.

    Args:
        task: one of ``'Factorization'``, ``'GCD'``, ``'Groebner'``.
        dataset: indexable, sized collection of samples.
        representation, integer_encoding: forwarded to ``text_to_sympy``.
        num_vars: number of variables ``x0 .. x{num_vars-1}``.
        n_test: number of leading samples to evaluate, capped at ``len(dataset)``.

    Returns:
        ``(acc, runtime)``: accuracy in percent (``nan`` if no sample was evaluated)
        and total seconds spent inside the SymPy solver calls.

    Raises:
        ValueError: if *task* is not supported.
    """
    gens = sp.symbols(f'x:{num_vars}')
    xyz = list(zip(gens, sp.symbols('x y z')))

    solvers = {
        'Factorization': lambda P: [sp.factor(P[0])],
        # NOTE(review): sp.gcd's third positional argument is treated as a generator,
        # so this assumes exactly two input polynomials per sample. Confirm.
        'GCD': lambda P: [sp.gcd(*P)],
        'Groebner': lambda P: sp.groebner(P, *gens, order='grevlex'),
    }
    if task not in solvers:
        raise ValueError(f'unsupported task {task!r}; expected one of {sorted(solvers)}')
    solve = solvers[task]

    def parse_all(text):
        return [text_to_sympy(t, representation=representation, integer_encoding=integer_encoding)
                for t in text.split('[SEP]')]

    def normalize(exprs):
        # With fewer than 4 variables, pass each expression through sp.Poly and
        # rename x0, x1, x2 -> x, y, z before comparing.
        if num_vars >= 4:
            return list(exprs)
        return [sp.Poly(e, gens=gens).subs(xyz).as_expr() for e in exprs]

    n_test = min(n_test, len(dataset))
    n_hit, runtime = 0, 0
    for i in range(n_test):
        P = parse_all(dataset[i]['data'])
        G = parse_all(dataset[i]['target'])

        start = time()
        Q = solve(P)
        runtime += time() - start

        Q, G = normalize(Q), normalize(G)
        # Pair up target and answer polynomials (was `[g == q in zip(G, Q)]`, a NameError).
        hit = len(Q) == len(G) and all(g == q for g, q in zip(G, Q))
        n_hit += int(hit)

    N = n_test
    acc = 100 * n_hit / N if N > 0 else float('nan')
    print(f'acc = {acc:>5.1f}% ({n_hit:04} / {N})  [{runtime:.4f} sec]')

    return acc, runtime


def _dataset_profile_factorization(target_text, representation, integer_encoding):
    """Count the factors of the first polynomial in *target_text*.

    Every ``'[SEP]'``-separated part of *target_text* is parsed, but only the first
    one is factored with ``sp.factor_list``. The result is built as follows:

    - Each non-atomic factor counts once; multiplicity is ignored.
    - The leading constant (if it is not 1) together with all atomic factors
      contributes one more.
    - The result is never below 1.
    """
    parsed = [text_to_sympy(part, representation=representation, integer_encoding=integer_encoding)
              for part in target_text.split('[SEP]')]

    coeff, factors = sp.factor_list(parsed[0])
    atomic_flags = [base.is_Atom for base, _ in factors]
    n_compound = atomic_flags.count(False)
    has_scalar_part = coeff != 1 or any(atomic_flags)
    return max(int(has_scalar_part) + n_compound, 1)

    
def dataset_profile_factorization(dataset, tokenizer, representation, integer_encoding, num_vars):
    """Print and return a histogram of factor counts over the targets of *dataset*.

    Factor counts come from ``_dataset_profile_factorization`` on each sample's
    ``'target'`` text, computed in parallel with joblib (all cores, loky backend).
    ``tokenizer`` and ``num_vars`` are accepted for interface compatibility but unused.

    Returns:
        ``(profile, N, runtime)``:
        - ``profile`` maps factor count to number of samples.
        - ``N`` is ``len(dataset)``.
        - ``runtime`` is the wall-clock seconds of the parallel profiling.
    """
    from collections import Counter

    N = len(dataset)
    start = time()
    # Parallel/delayed are imported by name at module level (there is no `joblib` name in scope).
    n_factors = Parallel(n_jobs=-1, backend="loky", verbose=True)(
        delayed(_dataset_profile_factorization)(dataset[i]['target'], representation, integer_encoding)
        for i in range(N)
    )
    runtime = time() - start

    profile = dict(Counter(n_factors))
    for n_factor in sorted(profile):
        print(f'[{n_factor}] {profile[n_factor]:04} / {N:04}')

    return profile, N, runtime


def main(task_names=None):
    """Evaluate the trained transformer (and, where available, SymPy) on each task.

    Args:
        task_names: subset of ``'Groebner'``, ``'Factorization'``, ``'GCD'``,
            ``'Hilbert'`` to run. None runs all of them.

    Returns:
        ``(accs, accs_s)``: transformer test accuracy and SymPy accuracy per task.
        SymPy is not run for Hilbert, which is recorded as -1.
    """
    if task_names is None:
        task_names = ['Groebner', 'Factorization', 'GCD', 'Hilbert']
    accs = {}
    accs_s = {}
    n_test = 2000

    # (task, dataset path, model directory, also evaluate train split, run SymPy baseline)
    experiments = [
        ('Groebner', 'data4/smallgroebner_1000000_10000_n=3_m=4_d=2_s=3_prefix_none',
         'results/smallgroebner_prefix_none_fixed/', True, True),
        ('Factorization', 'data4/factorization_1000000_2000_n=3_m=3_d=5_r=3_prefix_none',
         'results/factorization_prefix_none_fixed/', True, True),
        ('GCD', 'data4/gcd_1000000_2000_n=3_m=5_d=5_rsf=1_prefix_none',
         'results/gcd_prefix_none_fixed/', True, True),
        ('Hilbert', 'data5/hilbert_1000000_10000_n=3_m=5_d=5_s=5_prefix_none',
         'results/hilbert_prefix_none_fixed/', False, False),
    ]

    for task_name, data_path, save_dir, eval_train, run_sympy in experiments:
        if task_name not in task_names:
            continue
        print(task_name)
        model, trainset, testset, tokenizer, representation, integer_encoding, num_vars = load_items(save_dir, data_path)

        print('#transformer')
        if eval_train:
            print('##train')
            evaluate(model, trainset, tokenizer, representation, integer_encoding, num_vars, n_test=n_test, max_length=1024)
        print('##test')
        test_acc, _ = evaluate(model, testset, tokenizer, representation, integer_encoding, num_vars, n_test=n_test, max_length=1024)
        accs[task_name] = test_acc

        if run_sympy:
            print('#sympy')
            accs_s[task_name], _ = evaluate_sympy_runtime(task_name, testset, representation, integer_encoding, num_vars, n_test=n_test)
        else:
            accs_s[task_name] = -1

    print(accs)
    print(accs_s)
    return accs, accs_s
    
def main2():
    """Print factor-count profiles of the Factorization train and test sets."""
    task_name = 'Factorization'
    print(task_name)
    data_path = 'data4/factorization_1000000_2000_n=3_m=3_d=5_r=3_prefix_none'
    # load_items takes (save_dir, data_path). Only the datasets are needed here
    # (dataset_profile_factorization does not use the tokenizer), so the model
    # load and its GPU transfer are skipped by passing None for save_dir.
    _, trainset, testset, tokenizer, representation, integer_encoding, num_vars = load_items(None, data_path)

    print(data_path)
    print('train set')
    dataset_profile_factorization(trainset, tokenizer, representation, integer_encoding, num_vars)
    print('test set')
    dataset_profile_factorization(testset, tokenizer, representation, integer_encoding, num_vars)


def main_hilbert(model_path, data_path_list, n_test=2000):
    """Evaluate one trained Hilbert model on the test split of each listed dataset.

    Args:
        model_path: training output directory (see ``load_items``).
        data_path_list: dataset paths to evaluate on.
        n_test: number of test samples per dataset.

    Returns:
        dict mapping each data path to its test accuracy (percent).
    """
    print('Hilbert')
    accs = {}
    if not data_path_list:
        return accs

    # The model and tokenizer do not depend on the dataset, so load them once.
    model, _, _, tokenizer, _, _, _ = load_items(model_path, None)
    for data_path in data_path_list:
        _, _, testset, _, representation, integer_encoding, num_vars = load_items(None, data_path)

        print(f'--- {data_path} ------------')
        print('##test')
        test_acc, _ = evaluate(model, testset, tokenizer, representation, integer_encoding, num_vars, n_test=n_test, max_length=1024)
        accs[data_path] = test_acc

    return accs

            

if __name__ == '__main__':
    # main() needs the list of tasks to evaluate, e.g. ['Hilbert'] for a single task.
    # Other entry points:
    #   main2()  -- factor-count profile of the factorization datasets
    #   main_hilbert('results/hilbert-L_prefix_none_fixed/',
    #                ['data5/hilbert_1000000_10000_n=3_m=5_d=5_s=5_prefix_none',
    #                 'data5/hilbert_1000000_10000_n=3_m=10_d=5_s=5_prefix_none',
    #                 'data5/hilbert_1000000_10000_n=3_m=10_d=5_s=8_prefix_none'],
    #                2000)
    main(['Groebner', 'GCD', 'Factorization', 'Hilbert'])