import sympy as sp
from sympy import *
from sympy.abc import a, x, u
from sympy import Integral, cos, sqrt, Poly, erfi
from sympy import Symbol, exp, log, tan, sin, pi, diff
from sympy.parsing.sympy_parser import parse_expr

import sys
import pandas as pd
import numpy as np
from tqdm import tqdm
import re
import random
import csv
import numpy
import torch
import gc
import sys

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from sklearn import metrics
from sklearn.preprocessing import LabelEncoder

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import TensorDataset
import torch.nn as nn
import math
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
import torch.optim as optim
from transformers import AdamW, get_linear_schedule_with_warmup

import sys
sys.path.insert(0, 'SymbolicMathematics')

from src.envs.char_sp import CharSPEnvironment
import os
import numpy as np
import sympy as sp
import torch

from src.envs.sympy_utils import remove_root_constant_terms, reduce_coefficients, reindex_coefficients
from src.utils import AttrDict
from src.envs import build_env
from src.model import build_modules

from src.utils import to_cuda
from src.envs.sympy_utils import simplify

# Configuration handed to the SymbolicMathematics `build_env` factory.
# The keys and values are consumed by that project; they are reproduced here
# verbatim.
params = AttrDict({
    'env_name': 'char_sp',
    'int_base': 10,
    'balanced': False,
    'positive': True,
    'precision': 10,
    'n_variables': 1,
    'n_coefficients': 0,
    'leaf_probs': '0.75,0,0.25,0',
    'max_len': 384,
    'max_int': 11,
    'max_ops': 15,
    'max_ops_G': 15,
    'clean_prefix_expr': True,
    'reload_data': "prim_fwd.test",
    'rewrite_functions': '',
    'tasks': 'prim_fwd',
    # One "name:weight" entry per operator, comma separated. The literal is
    # split across lines for readability only; adjacent string literals are
    # concatenated into a single string.
    'operators': (
        'add:10,sub:3,mul:10,div:5,sqrt:4,pow2:4,pow3:2,pow4:1,pow5:1,ln:4,exp:4,'
        'erfi:1,erf:1,erfc:1,erfinv:1,erfcinv:1,expint:1,ei:1,li:1,si:1,ci:1,shi:1,chi:1,'
        'fresnelc:1,fresnels:1,'
        'sin:4,cos:4,tan:4,asin:1,acos:1,atan:1,'
        'sinh:1,cosh:1,tanh:1,asinh:1,acosh:1,atanh:1'
    ),
})

env = build_env(params)

# NOTE: this rebinds `x`, replacing the symbol imported from `sympy.abc` at the
# top of the file with the one registered in the environment's local dict.
x = env.local_dict['x']

# ---------------------------------------------------------------------------
# Token vocabulary.
#
# `words` is the ordered list of every token; a token's position in that list
# is its integer id. The order of the concatenation below therefore defines
# the ids and must not be changed.
# ---------------------------------------------------------------------------

# Ids 0, 1 and 2 respectively.
SPECIAL_WORDS = ['<pad>', '<s>', '</s>']
constants = ['pi', 'E']
variables = ['x', 'y', 'z', 't', 'u', '_u', '_theta']
symbols = ['I', 'INT+', 'INT-', 'INT', 'FLOAT', '-', '.', '10^']

# Digit tokens: the decimal strings of every integer in
# [max_digit - |int_base|, max_digit), i.e. '-30' .. '29' for int_base = 60.
int_base = 60
max_digit = (int_base + 1) // 2
elements = list(map(str, range(max_digit - abs(int_base), max_digit)))

# Operator name -> arity. Insertion order is kept identical to the original
# literal; only the sorted key list is used to build the vocabulary below.
OPERATORS = {
    # Elementary functions (binary)
    **dict.fromkeys(('add', 'sub', 'mul', 'div', 'pow', 'rac'), 2),
    # Elementary functions (unary)
    **dict.fromkeys(
        ('inv', 'pow2', 'pow3', 'pow4', 'pow5', 'sqrt', 'exp', 'ln', 'abs', 'sign'), 1
    ),
    # Trigonometric Functions
    **dict.fromkeys(('sin', 'cos', 'tan', 'cot', 'sec', 'csc'), 1),
    # Trigonometric Inverses
    **dict.fromkeys(('asin', 'acos', 'atan', 'acot', 'asec', 'acsc'), 1),
    # Hyperbolic Functions
    **dict.fromkeys(('sinh', 'cosh', 'tanh', 'coth', 'sech', 'csch'), 1),
    # Hyperbolic Inverses
    **dict.fromkeys(('asinh', 'acosh', 'atanh', 'acoth', 'asech', 'acsch'), 1),
    # Special functions
    **dict.fromkeys(
        ('erfi', 'erf', 'erfc', 'erfinv', 'erfcinv', 'expint', 'ei', 'li',
         'si', 'ci', 'shi', 'chi', 'fresnelc', 'fresnels'),
        1,
    ),
    # Derivative
    'derivative': 2,
    # custom functions
    'f': 1,
    'g': 2,
    'h': 3,
}

# Names of the integration rules; each one is a vocabulary token.
rules = [
    'power_rule',
    'mul_rule',
    'add_rule',
    'partial_fractions_rule',
    'exp_rule',
    'trig_rule',
    'substitution_rule',
    'parts_rule',
    'constant_rule',
    'quadratic_denom_rule',
    'cancel_rule',
    'distribute_expand_rule',
    'sqrt_linear_rule',
    'trig_sindouble_rule',
    'trig_expand_rule',
    'sqrt_quadratic_rule',
    'hyperbolic_rule',
    'trig_tansec_rule',
    'trig_sincos_rule',
    'inverse_trig_rule',
    'trig_cotcsc_rule',
    'trig_substitution_rule',
    'special_function_rule',
    'trig_product_rule',
]

# Full vocabulary (166 entries). 'f', 'g' and 'h' occur twice: once among the
# sorted operator names and once in the trailing list.
words = (
    SPECIAL_WORDS
    + constants
    + variables
    + sorted(OPERATORS)
    + symbols
    + elements
    + rules
    + ['e', 'f', 'g', 'h']
)

# id -> token covers every position in `words`.
id2word = dict(enumerate(words))
# token -> id; for the duplicated tokens the later (larger) id wins.
word2id = {token: idx for idx, token in id2word.items()}

class vanillatransformer(nn.Module):
    """Encoder-decoder Transformer over token ids with learned positions.

    The encoder and decoder each own a token embedding and a learned
    positional embedding of shape ``(1, max_len, d_model)``. A linear head
    maps decoder states to ``n_vocab`` logits.

    NOTE: the sub-module attribute names (``tok_emb_enc``, ``encoder``,
    ``head``, ...) determine the ``state_dict`` keys, so they must stay as
    they are for existing checkpoints to load.

    Args:
        d_model: Embedding / hidden size.
        num_layers: Number of layers in both the encoder and the decoder.
        n_vocab: Vocabulary size (embedding rows and output logits).
        dim_feedforward: Hidden size of the feed-forward sub-layers.
        n_head: Number of attention heads.
        max_len: Number of learned positions; longer sequences cannot be
            embedded.
        device: Device on which ``toSOP`` creates its start-token tensor.
        id2word: Mapping from token id to token string, used by ``toSOP``.
    """

    def __init__(self, d_model = 512, num_layers = 6, n_vocab = 166, dim_feedforward = 512, n_head = 8, max_len = 128, device = "cpu", id2word = id2word):
        super().__init__()

        # Encoder: token + learned positional embeddings feeding a
        # sequence-first nn.TransformerEncoder with a final LayerNorm.
        self.tok_emb_enc = nn.Embedding(n_vocab, d_model)
        self.pos_emb_enc = nn.Parameter(torch.zeros(1, max_len, d_model))
        encoder_layer = nn.TransformerEncoderLayer(d_model, n_head, dim_feedforward=dim_feedforward)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, norm=nn.LayerNorm(d_model))

        # Decoder: same layout as the encoder, with its own embeddings.
        self.tok_emb_dec = nn.Embedding(n_vocab, d_model)
        self.pos_emb_dec = nn.Parameter(torch.zeros(1, max_len, d_model))
        decoder_layer = nn.TransformerDecoderLayer(d_model, n_head, dim_feedforward=dim_feedforward)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers, norm=nn.LayerNorm(d_model))

        # Projection from decoder states to vocabulary logits.
        self.head = nn.Linear(d_model, n_vocab)
        self.id2word = id2word
        self.device = device

    def forward(self, inseq, outseq, mask):
        """Return next-token logits for a teacher-forced target sequence.

        Args:
            inseq: Integer tensor of source token ids, shape ``(B, T_e)``.
            outseq: Integer tensor of target token ids, shape ``(B, T_d)``.
            mask: Indexable whose first element is a 2-D attention mask;
                ``mask[0][:T_d, :T_d]`` is passed as the decoder ``tgt_mask``.
                NOTE(review): the leading dimension looks like a per-replica
                copy (e.g. for data-parallel training) -- confirm with the
                training code.

        Returns:
            Tensor of logits with shape ``(B, T_d, n_vocab)``.
        """
        _, T_e = inseq.size()
        inseq = self.tok_emb_enc(inseq)
        inseq = self.pos_emb_enc[:, :T_e, :] + inseq
        # The torch Transformer modules are used sequence-first here, hence
        # the transposes around every call.
        inseq = self.encoder(inseq.transpose(0, 1)).transpose(0, 1)

        _, T_d = outseq.size()
        outseq = self.tok_emb_dec(outseq)
        outseq = self.pos_emb_dec[:, :T_d, :] + outseq

        mask = mask[0][:T_d, :T_d]

        output = self.decoder(tgt=outseq.transpose(0, 1), memory=inseq.transpose(0, 1), tgt_mask=mask).transpose(0, 1)
        output = self.head(output)
        return output

    @torch.no_grad()
    def toSOP(self, eq1, mask, gen_len = 63, T = 0.02):
        """Rank the rule tokens predicted as the first output token.

        Encodes ``eq1``, runs ONE decoder step from the start token (id 1)
        and returns every vocabulary entry whose name contains ``'rule'``,
        ordered from most to least probable.

        Args:
            eq1: Integer tensor of source token ids, shape ``(1, T_e)``.
                Only a single sequence is supported because the result is
                one flat ranking.
            mask: 2-D attention mask; ``mask[:1, :1]`` is used as the decoder
                ``tgt_mask``. Unlike ``forward`` it is NOT indexed with
                ``[0]`` first.
            gen_len: Kept for interface compatibility. Only one decoding step
                is ever performed; values below 1 skip it and yield ``[]``.
            T: Kept for interface compatibility and unused. A positive
                temperature would not change the ranking returned here.

        Returns:
            List of rule-name strings in descending order of probability.

        Raises:
            ValueError: If ``eq1`` holds more than one sequence.
        """
        batch_size, src_len = eq1.size()
        if gen_len < 1:
            return []
        if batch_size != 1:
            raise ValueError(
                "toSOP ranks rules for a single sequence; got batch size %d" % batch_size
            )

        # Encode the source; the result is kept sequence-first for the decoder.
        memory = self.tok_emb_enc(eq1)
        memory = self.pos_emb_enc[:, :src_len, :] + memory
        memory = self.encoder(memory.transpose(0, 1))

        # Single decoder step starting from the '<s>' token (id 1).
        start = torch.ones(batch_size, 1, dtype=torch.long, device=self.device)
        tgt = self.tok_emb_dec(start)
        tgt = self.pos_emb_dec[:, :1, :] + tgt
        out = self.decoder(tgt=tgt.transpose(0, 1), memory=memory, tgt_mask=mask[:1, :1]).transpose(0, 1)

        # Distribution over the whole vocabulary for the first output token.
        probs = F.softmax(self.head(out)[:, -1, :], dim=-1)
        _, ranked = torch.sort(probs, descending=True)

        # Flatten to plain ints; the vocabulary size is taken from the tensor
        # itself instead of being hard-coded.
        names = (self.id2word[idx] for idx in ranked.reshape(-1).tolist())
        return [name for name in names if 'rule' in name]

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# script does not crash on CUDA-less machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = vanillatransformer(d_model = 512, num_layers = 6, n_vocab = 166, dim_feedforward = 1024, n_head = 8, max_len = 384, device = device, id2word = id2word)
model.to(device)

# map_location remaps tensors saved on a GPU onto the device selected above;
# without it a CUDA-saved checkpoint cannot be deserialized on a CPU-only host.
checkpoint = torch.load('symmath/Multi_gpu_384_symmath_fulldata.pth', map_location = device)
model.load_state_dict(checkpoint['state_dict'])
# Inference mode: disables dropout inside the Transformer layers.
model.eval()
criterion = nn.CrossEntropyLoss()

def gettoken(tlist):
    """Convert token ids to token strings, stopping at the first pad/eos id.

    Args:
        tlist: Iterable of integer token ids.

    Returns:
        List of token strings for the ids that precede the first occurrence
        of id 0 ('<pad>') or id 2 ('</s>'). If neither id occurs, the whole
        sequence is converted.

    Raises:
        KeyError: If an id before the terminator is not in ``id2word``.
    """
    tok = []
    for i in tlist:
        # 0 and 2 are the ids of '<pad>' and '</s>' in this file's vocabulary.
        if i == 0 or i == 2:
            break
        tok.append(id2word[i])
    # Returned unconditionally: a sequence with no terminator previously fell
    # off the end of the loop and yielded None.
    return tok

def getindex(tlist):
    """Map each token string in ``tlist`` to its integer id via ``word2id``.

    Raises ``KeyError`` for a token that is not in the vocabulary.
    """
    return [word2id[token] for token in tlist]
