import sympy as sp
from sympy import *
from sympy.abc import a, x, u
from sympy import Integral, cos, sqrt, Poly, erfi
from sympy import Symbol, exp, log, tan, sin, pi, diff
from sympy.parsing.sympy_parser import parse_expr

import sys
import pandas as pd
import numpy as np
from tqdm import tqdm
import re
import random
import csv
import numpy
import torch
import gc
import sys
from joblib import Parallel, delayed
import signal

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from sklearn import metrics
from sklearn.preprocessing import LabelEncoder

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import TensorDataset
import torch.nn as nn
import math
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
import torch.optim as optim
from transformers import AdamW, get_linear_schedule_with_warmup

from sympy.integrals.manualintegrate import (
    _manualintegrate, integral_steps, evaluates,
    ConstantRule, ConstantTimesRule, PowerRule, AddRule, URule,
    PartsRule, CyclicPartsRule, TrigRule, ExpRule, ReciprocalRule, ArctanRule,
    AlternativeRule, DontKnowRule, RewriteRule
)

import sys
sys.path.insert(0, 'SymbolicMathematics')

from src.envs.char_sp import CharSPEnvironment
import os
import numpy as np
import sympy as sp
import torch

from src.envs.sympy_utils import remove_root_constant_terms, reduce_coefficients, reindex_coefficients
from src.utils import AttrDict
from src.envs import build_env
from src.model import build_modules

from src.utils import to_cuda
from src.envs.sympy_utils import simplify

# Settings handed to `build_env` (imported from the SymbolicMathematics
# sources put on sys.path above). Keys and values are kept exactly as the
# environment receives them; only the long operator list is wrapped.
params = AttrDict(dict(
    env_name='char_sp',
    int_base=10,
    balanced=False,
    positive=True,
    precision=10,
    n_variables=1,
    n_coefficients=0,
    leaf_probs='0.75,0,0.25,0',
    max_len=384,
    max_int=11,
    max_ops=15,
    max_ops_G=15,
    clean_prefix_expr=True,
    reload_data="prim_fwd.test",
    rewrite_functions='',
    tasks='prim_fwd',
    # "name:weight" pairs, comma separated (adjacent literals are concatenated).
    operators=(
        'add:10,sub:3,mul:10,div:5,sqrt:4,pow2:4,pow3:2,pow4:1,pow5:1,ln:4,exp:4,'
        'erfi:1,erf:1,erfc:1,erfinv:1,erfcinv:1,expint:1,ei:1,li:1,si:1,ci:1,shi:1,chi:1,'
        'fresnelc:1,fresnels:1,'
        'sin:4,cos:4,tan:4,asin:1,acos:1,atan:1,'
        'sinh:1,cosh:1,tanh:1,asinh:1,acosh:1,atanh:1'
    ),
))

env = build_env(params)
# Rebinds the module-level `x` (previously imported from sympy.abc) to the
# symbol the environment keeps under 'x' in its local_dict.
x = env.local_dict['x']

class TimeoutException(Exception):
    """Raised by `timeout_handler` when a SIGALRM alarm goes off."""

def timeout_handler(signum, frame):
    """Signal handler that aborts the running code with TimeoutException.

    Args:
        signum: signal number passed in by the `signal` module (unused).
        frame: interrupted stack frame passed in by the `signal` module (unused).

    Raises:
        TimeoutException: always.
    """
    raise TimeoutException()

# Route SIGALRM to `timeout_handler`, so an expired `signal.alarm(...)` raises
# TimeoutException in this process.
signal.signal(signal.SIGALRM, timeout_handler)

def call_with_multiprocessing(target_function, parameters, n_jobs=-2, backend='loky'):
    """Call `target_function` once per keyword dict in `parameters` via joblib.

    Args:
        target_function: callable invoked as `target_function(**param_dict)`.
        parameters: iterable of dicts, one per call; wrapped in `tqdm` for a
            progress bar.
        n_jobs: forwarded unchanged to `joblib.Parallel`.
        backend: forwarded unchanged to `joblib.Parallel`.

    Returns:
        Whatever the `Parallel` call returns for the submitted tasks.
    """
    runner = Parallel(n_jobs=n_jobs, backend=backend)
    tasks = (delayed(target_function)(**kwargs) for kwargs in tqdm(parameters))
    return runner(tasks)

def filter_data(index):
    """Decide whether row `index` of the global `out_data` has to be dropped.

    A row is flagged when any of the following holds:
      * its `function` cell cannot be parsed / converted to prefix notation
        within 300 seconds, or the prefix sequence has more than 383 tokens
        (SymDataset pads inputs to 384 and reserves one slot for '</s>');
      * its `rules` cell does not match the expected pattern, cannot be
        processed within 300 seconds, or (for two-word rule strings) the rule
        name plus the prefix form of the embedded expression exceeds 29 tokens
        (SymDataset's target length is 30 including '<s>' / '</s>').

    Reads the module-level `out_data` and `env`.

    Args:
        index: positional row index into `out_data`.

    Returns:
        `index` if the row should be dropped, otherwise `None`.
    """
    # Installed again on every call -- presumably because this function runs
    # inside joblib worker processes; TODO confirm it is still required.
    signal.signal(signal.SIGALRM, timeout_handler)
    try:
        signal.alarm(300)
        try:
            in_seq = env.sympy_to_prefix(sp.sympify(out_data.function.values[index]))
        finally:
            # Always disarm: a leftover alarm would fire later in unrelated code.
            signal.alarm(0)
        if len(in_seq) > 383:
            return index

        signal.alarm(300)
        try:
            rule_text = out_data.rules.values[index]
            # Everything before the first single quote (after at least one char).
            rule_name = re.search("(.+?)'", rule_text).group(1)
            if len(rule_text.split()) == 2:
                # Two-word rule strings carry an expression between ", " and ",".
                rexpr = re.search(", (.+?),", rule_text).group(1)
                rawout_seq = [rule_name] + env.sympy_to_prefix(sp.sympify(rexpr))
                if len(rawout_seq) > 29:
                    return index
        finally:
            signal.alarm(0)
    except Exception:
        # Covers TimeoutException, parse errors, a failed regex match
        # (AttributeError on None) and non-string cells. KeyboardInterrupt and
        # SystemExit are deliberately NOT swallowed.
        return index
    return None

out_data = pd.read_csv('expr_27465168.csv')

# One `filter_data` call per row; each result is either the row position to
# drop or None.
ind = Parallel(n_jobs=-2, backend='multiprocessing')(
    delayed(filter_data)(index=i) for i in tqdm(range(len(out_data)))
)
print(len(ind))
# The first/last diagnostics are skipped for empty lists: indexing an empty
# list (empty CSV, or no row flagged) used to abort the script with IndexError.
if ind:
    print(ind[0], ind[-1])
ind = [i for i in ind if i is not None]
if ind:
    print(ind[0], ind[-1])
ind.sort()
if ind:
    print(ind[0], ind[-1])

# Drop the flagged rows (by position) and renumber the remaining ones from 0.
if ind:
    out_data = out_data.drop(out_data.index[ind])

out_data.reset_index(drop=True, inplace=True)

# Vocabulary tables. Token ids are positions in `words`, so the order of every
# list below (and of their concatenation) must not change.
SPECIAL_WORDS = ['<pad>', '<s>', '</s>']
constants = ['pi', 'E']
variables = ['x', 'y', 'z', 't', 'u', '_u', '_theta']
symbols = ['I', 'INT+', 'INT-', 'INT', 'FLOAT', '-', '.', '10^']

# Integer tokens: the strings '-30' ... '29' for int_base = 60.
int_base = 60
max_digit = (int_base + 1) // 2
elements = list(map(str, range(max_digit - abs(int_base), max_digit)))

# Operator name -> arity.
OPERATORS = {
    # Elementary binary functions
    **dict.fromkeys(['add', 'sub', 'mul', 'div', 'pow', 'rac'], 2),
    # Elementary unary functions
    **dict.fromkeys(['inv', 'pow2', 'pow3', 'pow4', 'pow5',
                     'sqrt', 'exp', 'ln', 'abs', 'sign'], 1),
    # Trigonometric functions and their inverses
    **dict.fromkeys(['sin', 'cos', 'tan', 'cot', 'sec', 'csc'], 1),
    **dict.fromkeys(['asin', 'acos', 'atan', 'acot', 'asec', 'acsc'], 1),
    # Hyperbolic functions and their inverses
    **dict.fromkeys(['sinh', 'cosh', 'tanh', 'coth', 'sech', 'csch'], 1),
    **dict.fromkeys(['asinh', 'acosh', 'atanh', 'acoth', 'asech', 'acsch'], 1),
    # Further unary functions
    **dict.fromkeys(['erfi', 'erf', 'erfc', 'erfinv', 'erfcinv', 'expint',
                     'ei', 'li', 'si', 'ci', 'shi', 'chi',
                     'fresnelc', 'fresnels'], 1),
    # Derivative
    'derivative': 2,
    # custom functions
    'f': 1,
    'g': 2,
    'h': 3,
}

rules = [
    'power_rule', 'mul_rule', 'add_rule', 'partial_fractions_rule',
    'exp_rule', 'trig_rule', 'substitution_rule', 'parts_rule',
    'constant_rule', 'quadratic_denom_rule', 'cancel_rule',
    'distribute_expand_rule', 'sqrt_linear_rule',
    'trig_sindouble_rule', 'trig_expand_rule',
    'sqrt_quadratic_rule', 'hyperbolic_rule', 'trig_tansec_rule',
    'trig_sincos_rule', 'inverse_trig_rule', 'trig_cotcsc_rule',
    'trig_substitution_rule', 'special_function_rule',
    'trig_product_rule',
]

# Alternative rule list that is currently disabled:
# rules = ['add_rule', 'constant_rule', 'dont_know_rule', 'exp_rule',
#          'inverse_trig_rule', 'mul_rule', 'parts_rule', 'power_rule',
#          'quadratic_denom_rule', 'root_mul_rule', 'special_function_rule',
#          'steps_rule', 'substitution_rule', 'trig_expand_rule',
#          'trig_powers_products_rule', 'trig_product_rule', 'trig_rule',
#          'trig_substitution_rule']
# n_vocab = 160

# 166 entries in total. 'f', 'g' and 'h' occur twice (as operators and in the
# trailing list), so `id2word` has 166 keys while `word2id` has 163 and maps
# those three names to their later positions.
words = [
    *SPECIAL_WORDS,
    *constants,
    *variables,
    *sorted(OPERATORS),
    *symbols,
    *elements,
    *rules,
    'e', 'f', 'g', 'h',
]
id2word = dict(enumerate(words))
word2id = {word: idx for idx, word in id2word.items()}

class SymDataset(Dataset):
    """Dataset turning (function, rule) rows into padded token-id tensors.

    Each item is built from the `function` and `rules` columns of
    `out_data_train`. The function is parsed with sympy and converted to
    prefix tokens by `env.sympy_to_prefix`; the rule string supplies the rule
    name and, for two-word rule strings, an expression that is appended in
    prefix form.

    Args:
        out_data_train: table exposing `.function.values` and `.rules.values`
            and supporting `len()` (a pandas DataFrame in this script).
        env: object providing `sympy_to_prefix(expr)`.
        word2id: mapping from token string to integer id; must contain
            '<pad>', '<s>' and '</s>'.
        max_length: length of the encoder input sequence.
        max_out_length: length of the decoder input / target sequences.
            Defaults to 30, the value that used to be hard-coded.
    """

    def __init__(self, out_data_train, env, word2id, max_length, max_out_length=30):
        self.out_data_train = out_data_train
        self.env = env
        self.word2id = word2id
        self.max_length = max_length
        self.max_out_length = max_out_length
        # Square mask with 0 on/below the diagonal and -inf strictly above it.
        upper = torch.triu(torch.ones(max_length, max_length), 1)
        self.mask = upper.masked_fill(upper == 1, float('-inf'))

    def __len__(self):
        """Return the number of rows in the underlying table."""
        return len(self.out_data_train)

    def __getitem__(self, index):
        """Encode row `index`.

        Returns:
            `(in_seq, dec_seq, out_seq, mask)`: three 1-D long tensors plus the
            shared `(max_length, max_length)` mask, or `None` when sympy or the
            prefix conversion fails (the offending text and the error are
            printed). `collate_not_none` removes such entries from a batch.

        Raises:
            KeyError: if a produced token is missing from `word2id`.
            AttributeError: if the rule string does not match the regexes.
        """
        function_text = self.out_data_train.function.values[index]
        try:
            raw_seq = self.env.sympy_to_prefix(sp.sympify(function_text))
        except Exception as e:
            print(function_text)
            print(e)
            return None

        pad_id = self.word2id['<pad>']
        # Encoder input: tokens, then padding, with '</s>' in the very last slot.
        in_seq = [self.word2id[tok] for tok in raw_seq]
        in_seq += [pad_id] * (self.max_length - len(in_seq) - 1)
        in_seq.append(self.word2id['</s>'])

        rule_text = self.out_data_train.rules.values[index]
        # Rule name: everything before the first single quote.
        rule_name = re.search("(.+?)'", rule_text).group(1)
        rawout_seq = [rule_name]
        if len(rule_text.split()) == 2:
            # Two-word rule strings carry an expression between ", " and ",".
            rexpr = re.search(", (.+?),", rule_text).group(1)
            try:
                rawout_seq += self.env.sympy_to_prefix(sp.sympify(rexpr))
            except Exception as e:
                print(rexpr)
                print(e)
                return None
        rawout_seq = [self.word2id[tok] for tok in rawout_seq]

        # One slot is reserved for '<s>' (decoder input) / '</s>' (target).
        padding = [pad_id] * (self.max_out_length - len(rawout_seq) - 1)
        dec_seq = [self.word2id['<s>']] + rawout_seq + padding
        out_seq = rawout_seq + [self.word2id['</s>']] + padding

        return (torch.tensor(in_seq, dtype=torch.long),
                torch.tensor(dec_seq, dtype=torch.long),
                torch.tensor(out_seq, dtype=torch.long),
                self.mask)

def collate_not_none(batch):
    """Collate a batch after discarding samples that are `None`.

    Args:
        batch: list of samples as produced by the dataset; entries may be None.

    Returns:
        The result of torch's `default_collate` on the remaining samples.
    """
    kept = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(kept)

# 80 / 10 / 10 train / validation / test split with a fixed seed.
# The frame is passed once: the split indices depend only on the number of
# rows and `random_state`, so handing the same frame in a second time (and
# throwing that second pair away) only produced extra copies.
out_data_train, out_data_eval = train_test_split(out_data,
                                                 test_size=0.2,
                                                 random_state=42)

out_data_val, out_data_test = train_test_split(out_data_eval,
                                               test_size=0.5,
                                               random_state=42)

# Encoder inputs are padded to 384 tokens in every split.
train_ds = SymDataset(out_data_train, env, word2id, 384)
val_ds = SymDataset(out_data_val, env, word2id, 384)
test_ds = SymDataset(out_data_test, env, word2id, 384)

batch_size = 256

# Training batches are drawn in random order, validation batches in order;
# `collate_not_none` drops samples the dataset could not encode.
train_dl = DataLoader(train_ds,
                      sampler=RandomSampler(train_ds),
                      batch_size=batch_size,
                      num_workers=20,
                      collate_fn=collate_not_none)

val_dl = DataLoader(val_ds,
                    sampler=SequentialSampler(val_ds),
                    batch_size=batch_size,
                    num_workers=20,
                    collate_fn=collate_not_none)

class vanillatransformer(nn.Module):
    """Encoder-decoder Transformer with learned positional embeddings.

    Args:
        d_model: embedding / hidden size.
        num_layers: number of encoder layers and of decoder layers.
        n_vocab: vocabulary size (embedding rows and output logits).
        dim_feedforward: feed-forward width inside each layer.
        n_head: number of attention heads.
        max_len: number of positions in each positional-embedding table.
        device: stored as `self.device`; the module is not moved by this class.
        id2word: mapping from token id to token string used by `toSOP`.
            `None` (default) selects the module-level `id2word` table at
            construction time.
    """

    def __init__(self, d_model = 512, num_layers = 6, n_vocab = 166, dim_feedforward = 512, n_head = 8, max_len = 128, device = "cpu", id2word = None):
        super().__init__()
        if id2word is None:
            # The parameter shadows the module-level table, hence globals().
            id2word = globals()['id2word']

        self.tok_emb_enc = nn.Embedding(n_vocab, d_model)
        self.pos_emb_enc = nn.Parameter(torch.zeros(1, max_len, d_model))
        encoder_layer = nn.TransformerEncoderLayer(d_model, n_head, dim_feedforward=dim_feedforward)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, norm=nn.LayerNorm(d_model))

        self.tok_emb_dec = nn.Embedding(n_vocab, d_model)
        self.pos_emb_dec = nn.Parameter(torch.zeros(1, max_len, d_model))
        decoder_layer = nn.TransformerDecoderLayer(d_model, n_head, dim_feedforward=dim_feedforward)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers, norm=nn.LayerNorm(d_model))
        self.head = nn.Linear(d_model, n_vocab)
        self.id2word = id2word
        self.device = device

    def forward(self, inseq, outseq, mask):
        """Return logits for every decoder position (teacher forcing).

        Args:
            inseq: 2-D tensor of encoder token ids, (batch, source length).
            outseq: 2-D tensor of decoder input token ids, (batch, target length).
            mask: stack of square additive attention masks; only `mask[0]`,
                cropped to the target length, is used as `tgt_mask`.

        Returns:
            Tensor of shape (batch, target length, n_vocab).
        """
        B, T_e = inseq.size()
        inseq = self.tok_emb_enc(inseq)
        inseq = self.pos_emb_enc[:, :T_e, :] + inseq
        # The Transformer layers are sequence-first, hence the transposes.
        memory = self.encoder(inseq.transpose(0, 1))

        B, T_d = outseq.size()
        outseq = self.tok_emb_dec(outseq)
        outseq = self.pos_emb_dec[:, :T_d, :] + outseq

        mask = mask[0][:T_d, :T_d]

        output = self.decoder(tgt=outseq.transpose(0, 1), memory=memory, tgt_mask=mask).transpose(0, 1)
        return self.head(output)

    def toSOP(self, eq1, mask, gen_len = 63, T = 0.02, return_all = False):
        """Autoregressively generate token strings for `eq1`.

        Generation starts from token id 1 and runs for `gen_len` steps; each
        generated row is then cut at the first id 0 or 2 (ids of '<pad>' and
        '</s>' in this file's vocabulary). The start token itself is part of
        the returned list.

        Args:
            eq1: 2-D tensor of encoder token ids, (batch, source length).
            mask: 2-D additive attention mask (note: unlike `forward`, no
                leading batch axis); at least `gen_len` rows and columns.
            gen_len: number of tokens to generate; must not exceed `max_len`.
            T: sampling temperature. Logits are divided by `T` before the
                softmax; `T <= 0` selects greedy (argmax) decoding. The value
                used to be overwritten by the sequence length and was
                therefore ignored.
            return_all: when True, return one token list per batch row.

        Returns:
            The token strings of the LAST batch row (historic behaviour), or
            a list with one token list per row when `return_all` is True.
        """
        B, T_e = eq1.size()
        with torch.no_grad():  # inference only; nothing differentiable is returned
            memory = self.tok_emb_enc(eq1)
            memory = self.pos_emb_enc[:, :T_e, :] + memory
            memory = self.encoder(memory.transpose(0, 1))
            # Every row starts with token id 1, created on the input's device.
            gens = torch.ones(B, 1, dtype=torch.long, device=eq1.device)

            for i in range(1, gen_len + 1):
                tgt = self.tok_emb_dec(gens)
                tgt = self.pos_emb_dec[:, :i, :] + tgt
                out = self.decoder(tgt=tgt.transpose(0, 1), memory=memory, tgt_mask=mask[:i, :i]).transpose(0, 1)
                logits = self.head(out)[:, -1, :]

                if T > 0:
                    # float() keeps the division safe under reduced precision.
                    probs = F.softmax(logits.float() / T, dim=-1)
                    next_tok = torch.multinomial(probs, num_samples=1)
                else:
                    next_tok = logits.argmax(dim=-1, keepdim=True)
                gens = torch.cat([gens, next_tok.reshape(B, 1)], dim=1)

        decoded = []
        for row in gens.tolist():
            tokens = []
            for tok in row:
                if tok == 0 or tok == 2:
                    break
                tokens.append(self.id2word[tok])
            decoded.append(tokens)

        if return_all:
            return decoded
        return decoded[-1] if decoded else []

# Training schedule: epochs are numbered from 1.
start_epoch = 1
epochs = 15

device = torch.device("cuda")

# n_vocab equals the 166 entries of the `words` list; max_len equals the 384
# tokens the datasets pad their inputs to.
model = vanillatransformer(
    d_model=512,
    num_layers=6,
    n_vocab=166,
    dim_feedforward=1024,
    n_head=8,
    max_len=384,
    device=device,
    id2word=id2word,
).to(device)

optimizer = optim.AdamW(model.parameters(), lr=4e-5, eps=1e-8)

# Linear decay without warm-up, sized for len(train_dl) * epochs steps.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=len(train_dl) * epochs,
)
criterion = nn.CrossEntropyLoss()

scaler = torch.cuda.amp.GradScaler()

def eval_loop_fn(data_loader, model, device):
    """Compute the mean per-batch loss of `model` over `data_loader`.

    Uses the module-level `criterion` and `batch_size`. The model is switched
    to eval mode and is NOT switched back by this function.

    Args:
        data_loader: iterable yielding `(x, yin, yout, mask)` batches.
        model: callable as `model(x, yin, mask)` returning logits whose last
            dimension is the vocabulary size.
        device: device every batch tensor is moved to.

    Returns:
        The sum of the batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: if `data_loader` yields no batch.
    """
    model.eval()
    pbar = tqdm(data_loader)
    tot_loss = 0
    cnt = 0
    with torch.no_grad():
        for x, yin, yout, mask in pbar:
            # Fires when samples were dropped by the collate function, and
            # also for a final batch that is simply shorter.
            if len(x) != batch_size:
                print(f"====================Eval None Observed {len(x)}======================")
            x = x.to(device)
            yin = yin.to(device)
            yout = yout.to(device)
            mask = mask.to(device)
            y_pred = model(x, yin, mask)
            # Flatten to (positions, vocab); the vocabulary size is read from
            # the logits instead of being hard-coded to 166.
            loss = criterion(y_pred.view(-1, y_pred.size(-1)), yout.view(-1))
            tot_loss += loss.item()
            cnt += 1
            pbar.set_description(f"current validation loss : {tot_loss/cnt:.5f}")
    print(f'Validation Loss: {tot_loss/cnt :.5f}')
    return tot_loss/cnt

scaler = torch.cuda.amp.GradScaler()
# Best validation loss so far. NOTE(review): it is not stored in the
# checkpoint, so after a resume the first epoch always overwrites the file.
val_loss = float('inf')
# Resume only when the first command-line argument is "true" (any case).
# A missing argument now means "start from scratch" instead of an IndexError.
continue_training = len(sys.argv) > 1 and sys.argv[1].lower() == 'true'

if continue_training:
    # Alternative checkpoint file used previously: 'Multi_gpu_384_symmath.pth'
    checkpoint = torch.load('Multi_gpu_384_symmath_fulldata.pth')
    start_epoch = checkpoint['epoch'] + 1
    print(f'Checkpoint loaded till epoch {start_epoch - 1}')
    # Weights are saved from `model.module`, so they are loaded before wrapping.
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
model = nn.DataParallel(model)

print(f'Starting training from Epoch:{start_epoch}')
cntr_early_stopping = 0
for epoch in range(start_epoch, epochs + 1):
    model.train()
    pbar = tqdm(train_dl)
    tot_loss = 0
    cnt = 0
    for (x, yin, yout, mask) in pbar:
        # Fires when samples were dropped by the collate function, and also
        # for a final batch that is simply shorter.
        if len(x) != batch_size:
            print(f"====================Train None Observed {len(x)}======================")
        x = x.to(device)
        yin = yin.to(device)
        yout = yout.to(device)
        mask = mask.to(device)
        model.zero_grad()
        with torch.cuda.amp.autocast():
            y_pred = model(x, yin, mask)
            # Vocabulary size is read from the logits instead of hard-coding 166.
            loss = criterion(y_pred.view(-1, y_pred.size(-1)), yout.view(-1))

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # NOTE(review): `scheduler` is created above but never stepped in this
        # loop, so the learning rate stays at its initial value -- confirm
        # whether that is intended before adding `scheduler.step()`.
        tot_loss += loss.item()
        cnt += 1
        pbar.set_description(f"current loss : {tot_loss/cnt:.5f}")
    print(f'Epoch {epoch} : Loss : {tot_loss/cnt :.5f}')
    currval_loss = eval_loop_fn(val_dl, model, device)
    if currval_loss < val_loss:
        # Improvement: reset the patience counter and save a checkpoint.
        cntr_early_stopping = 0
        val_loss = currval_loss
        checkpoint = {
            'epoch': epoch,
            'state_dict': model.module.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        # Alternative checkpoint file used previously: "Multi_gpu_384_symmath.pth"
        torch.save(checkpoint, "Multi_gpu_384_symmath_fulldata.pth")
        print('Model Saved!')
    else:
        cntr_early_stopping = cntr_early_stopping + 1
    # Stop after 5 consecutive epochs without a better validation loss.
    if cntr_early_stopping == 5:
        print(f'Early Stopping at Epoch {epoch}')
        break
