import nltk
import random
import re
from collections import Counter
from nltk.corpus import stopwords
from nltk.grammar import Nonterminal
from nltk.grammar import ProbabilisticProduction
from nltk.parse.generate import generate
from nltk.parse import RecursiveDescentParser
from nltk.stem.porter import PorterStemmer
from nltk.stem.wordnet import WordNetLemmatizer
from nltk.tokenize import sent_tokenize
from nltk.tokenize import word_tokenize
from nltk import induce_pcfg
from nltk import Nonterminal
from nltk import PCFG
from nltk import CFG
from nltk import pos_tag, ne_chunk
from nltk import tree
from nltk import Tree
from typing import Iterator, List, Tuple, Union
import numpy as np

# Read the vocabulary file and keep only the purely alphabetic entries.
with open('../../bert_new/vocab.txt') as f:
    words = f.readlines()

words_list = [
    entry
    for entry in (line.strip('\n') for line in words)
    if entry.isalpha()
]

# Read the raw grammar rules, one whitespace-separated rule per line.
with open('rule_deep_mlm.txt') as f:
    x = f.readlines()

# Rewrite every rule as a "LHS -> RHS [prob]" line; the lines are later joined
# and handed to Generator.fromstring.  Three layouts are recognised:
#   LHS RHS PROB        with LHS starting with "S" -> unary rule, RHS unquoted
#   LHS RHS PROB        otherwise                  -> lexical rule, RHS quoted
#   LHS RHS1 RHS2 PROB                             -> binary rule
xp = []
for xx in x:
    fields = xx.split()
    if not fields:
        # A blank line carries no rule; previously it crashed with IndexError.
        continue
    lhs = fields[0]
    if len(fields) == 3:
        prob = str(float(fields[2]))
        if lhs[0] == "S":
            xp.append(lhs + ' -> ' + fields[1] + ' [' + prob + ']\n')
        else:
            xp.append(lhs + ' -> ' + '\'' + fields[1] + '\'  [' + prob + ']\n')
    else:
        # Any other field count is treated as a binary rule; lines with fewer
        # than four fields still fail loudly (IndexError) rather than being
        # silently dropped.
        prob = str(float(fields[3]))
        xp.append(lhs + ' -> ' + fields[1] + '  ' + fields[2] + '  [' + prob + ']\n')
# A right-hand-side symbol: a terminal (plain string) or a Nonterminal.
Symbol = Union[str, Nonterminal]


class Generator(nltk.grammar.PCFG):
    """PCFG that can sample random derivations from its own productions.

    Every sample is a ``(sentence, probability, tree)`` triple: the generated
    words joined by single spaces, the product of the probabilities of all
    productions used, and the derivation as a bracketed string.
    """

    def generate(self, n: int) -> Iterator[Tuple[str, float, str]]:
        """Lazily yield ``n`` independent samples rooted at the start symbol."""
        for _ in range(n):
            yield self._generate_derivation(self.start())

    def _generate_derivation(
        self, nonterminal: Nonterminal
    ) -> Tuple[str, float, str]:
        """Expand ``nonterminal`` recursively until only terminals remain.

        Returns:
            ``(sentence, probability, tree)`` for the sampled sub-derivation.
        """
        rhs, proba = self._reduce_once(nonterminal)
        sentence: List[str] = []
        pieces: List[str] = ['(', str(nonterminal), ' ']
        derivation: str
        for symbol in rhs:
            if isinstance(symbol, str):
                # Terminal: emitted verbatim in both the sentence and the tree.
                derivation = symbol
                pieces.append(derivation)
            else:
                derivation, sub_prob, sub_tree = self._generate_derivation(symbol)
                proba *= sub_prob
                pieces.append(sub_tree)

            # Empty expansions contribute nothing to the surface string.
            if derivation != "":
                sentence.append(derivation)
        pieces.append(')')

        return " ".join(sentence), proba, "".join(pieces)

    def _reduce_once(
        self, nonterminal: Nonterminal
    ) -> Tuple[Tuple[Symbol, ...], float]:
        """Sample one production for ``nonterminal``; return ``(rhs, prob)``."""
        production, prob = self._choose_production_reducing(nonterminal)
        return production.rhs(), prob

    def _choose_production_reducing(
        self, nonterminal: Nonterminal
    ) -> Tuple[ProbabilisticProduction, float]:
        """Draw a production of ``nonterminal`` weighted by its probability.

        Returns:
            ``(production, probability)`` for the chosen production.

        Raises:
            KeyError: if ``nonterminal`` has no entry in ``_lhs_index``.
        """
        productions: List[ProbabilisticProduction] = self._lhs_index[nonterminal]
        probabilities: List[float] = [production.prob() for production in productions]
        pairs = list(zip(productions, probabilities))
        return random.choices(pairs, weights=probabilities)[0]


# Build the sampler from the reformatted rules.
generator = Generator.fromstring(''.join(xp))

# Keep drawing one sample at a time until enough short sentences are collected.
n = 1000000
sentences = []
while len(sentences) < n:
    # generate(1) yields exactly one (sentence, probability, tree) triple.
    sentence = next(generator.generate(1))
    text = sentence[0]
    # Only sentences of fewer than 32 space-separated tokens are kept.
    if len(text.split(' ')) >= 32:
        continue
    sentences.append(text)


generated = sentences


def cyk(sentence, grammar):
    """Return the most probable parse of ``sentence`` as a bracketed string.

    A probabilistic CYK chart is filled bottom-up.  ``table[j][i]`` holds
    ``(lhs, probability)`` pairs for the words ``sentence[j:i]``; for every
    left-hand side only its best-scoring derivation is kept, and
    ``pointer[j][i]`` stores, at the same index, the two child cells that
    produced it.

    Args:
        sentence: whitespace-separated tokens.
        grammar: object with a ``_rhs_index`` mapping that is looked up by a
            word or by the first symbol of a pair (NLTK grammars key it by the
            first right-hand-side symbol).  Its productions must provide
            ``lhs()``, ``rhs()`` and ``prob()``, and left-hand sides must
            provide ``symbol()``.

    Returns:
        The parse as a string such as ``"(S ((A a)(B b)))"``, rooted at
        whichever symbol spanning the whole sentence scores highest (it is not
        restricted to the grammar's start symbol), or ``None`` when no parse
        exists.  An empty sentence or a word unknown to the grammar also
        yields ``None``.
    """
    words = sentence.split()
    length = len(words)
    if length == 0:
        # No words means no chart cell to read a root from.
        return None

    rhs_index = grammar._rhs_index

    def binary_producers(left, right, prob):
        """Return ``(lhs, rule_prob * prob)`` for rules starting ``left right``."""
        return [
            (production.lhs(), production.prob() * prob)
            for production in rhs_index.get(left, ())
            if len(production.rhs()) > 1 and production.rhs()[1] == right
        ]

    table = [[[] for _ in range(length + 1)] for _ in range(length)]
    pointer = [[[] for _ in range(length + 1)] for _ in range(length)]

    def build_tree(start, end, idx):
        """Follow the back-pointers from entry ``idx`` of cell (start, end)."""
        backpointers = pointer[start][end]
        if backpointers:
            (ls, le, li), (rs, re_, ri) = backpointers[idx]
            body = '(' + build_tree(ls, le, li) + build_tree(rs, re_, ri) + ')'
        else:
            # Cells without back-pointers are single-word (diagonal) cells.
            body = words[end - 1]
        return '(' + table[start][end][idx][0].symbol() + ' ' + body + ')'

    # Diagonal: every word with the symbols that rewrite to it.  A word the
    # grammar does not know leaves its cell empty instead of raising KeyError.
    for pos, word in enumerate(words):
        table[pos][pos + 1].extend(
            (production.lhs(), production.prob())
            for production in rhs_index.get(word, ())
        )

    # Wider spans, built from every split point and every pair of children.
    for end in range(2, length + 1):
        for start in range(end - 2, -1, -1):
            best = {}
            for split in range(start + 1, end):
                for l_idx, (left_sym, left_prob) in enumerate(table[start][split]):
                    for r_idx, (right_sym, right_prob) in enumerate(table[split][end]):
                        joint = left_prob * right_prob
                        for lhs, prob in binary_producers(left_sym, right_sym, joint):
                            # Strict ">" keeps the first derivation on ties.
                            if lhs not in best or prob > best[lhs][0]:
                                best[lhs] = (
                                    prob,
                                    ((start, split, l_idx), (split, end, r_idx)),
                                )
            # Both lists follow the dict order, so their indices stay aligned.
            table[start][end].extend((lhs, entry[0]) for lhs, entry in best.items())
            pointer[start][end].extend(entry[1] for entry in best.values())

    root_cell = table[0][length]
    if not root_cell:
        return None

    # max() returns the first maximal index, matching a strict ">" scan.
    best_idx = max(range(len(root_cell)), key=lambda idx: root_cell[idx][1])
    return build_tree(0, length, best_idx)


def inside(sentence, grammar):
    """Compute the inside probability of ``sentence`` for each root S1..S9.

    The chart has the same layout as a CYK chart, but the probabilities of all
    derivations of a left-hand side over a span are summed rather than
    maximised.  ``table[j][i]`` holds ``(lhs, probability)`` pairs for the
    words ``sentence[j:i]``.

    Args:
        sentence: whitespace-separated tokens.
        grammar: object with a ``_rhs_index`` mapping that is looked up by a
            word or by the first symbol of a pair (NLTK grammars key it by the
            first right-hand-side symbol).  Its productions must provide
            ``lhs()``, ``rhs()`` and ``prob()``, and left-hand sides must
            provide ``symbol()``.

    Returns:
        A dict with the keys ``"S1"`` .. ``"S9"`` mapping each root symbol to
        the total probability with which it derives the whole sentence.
        Entries stay ``0`` for roots that cannot derive it, for an empty
        sentence, and when a word is unknown to the grammar.
    """
    # "S2" was missing before: the literal listed "S1" twice.
    root_prob = {"S%d" % idx: 0 for idx in range(1, 10)}

    words = sentence.split()
    length = len(words)
    if length == 0:
        return root_prob

    rhs_index = grammar._rhs_index
    table = [[[] for _ in range(length + 1)] for _ in range(length)]

    # Diagonal: every word with the symbols that rewrite to it.  A word the
    # grammar does not know leaves its cell empty instead of raising KeyError.
    for pos, word in enumerate(words):
        table[pos][pos + 1].extend(
            (production.lhs(), production.prob())
            for production in rhs_index.get(word, ())
        )

    # Wider spans: sum over every split point and every pair of children.
    for end in range(2, length + 1):
        for start in range(end - 2, -1, -1):
            totals = {}
            for split in range(start + 1, end):
                for left_sym, left_prob in table[start][split]:
                    candidates = rhs_index.get(left_sym, ())
                    if not candidates:
                        continue
                    for right_sym, right_prob in table[split][end]:
                        joint = left_prob * right_prob
                        for production in candidates:
                            rhs = production.rhs()
                            if len(rhs) > 1 and rhs[1] == right_sym:
                                lhs = production.lhs()
                                prob = production.prob() * joint
                                if lhs in totals:
                                    totals[lhs] += prob
                                else:
                                    totals[lhs] = prob
            table[start][end].extend(totals.items())

    # Only the designated root symbols are reported.
    for symbol, prob in table[0][length]:
        name = symbol.symbol()
        if name in root_prob:
            root_prob[name] += prob
    return root_prob


# Terminal vocabulary: every plain-string symbol found on the right-hand side
# of any production of the grammar.
vocab = {
    symbol
    for production in generator.productions()
    for symbol in production.rhs()
    if isinstance(symbol, str)
}

def cond_prob(sent):
    """Return, for every position of ``sent``, a distribution over ``vocab``.

    For each position ``t`` the word at ``t`` is replaced in turn by every
    word of the module-level ``vocab``, the resulting sentence is scored with
    ``inside`` (summed over all root symbols), and the scores are normalised
    to sum to one.

    Args:
        sent: whitespace-separated tokens.

    Returns:
        A 1-D numpy array of length ``len(sent.split()) * len(vocab)``: the
        per-position distributions concatenated in position order.  Within a
        position the entries follow the iteration order of ``vocab`` (a set).
        An empty sentence yields an empty array.  If every substitution at a
        position scores zero, that position's entries are NaN (0/0).
    """
    grammar = generator
    tokens = sent.split()
    prod_vec = []
    for t in range(len(tokens)):
        # Fresh copy per position so the other words keep their original value.
        candidate = list(tokens)
        prodlst = []
        for w in vocab:
            candidate[t] = w
            prob_dict = inside(' '.join(candidate), grammar)
            prodlst.append(sum(prob_dict.values()))
        prodlst = np.array(prodlst) / np.sum(prodlst)
        prod_vec.append(prodlst)
    if not prod_vec:
        # np.concatenate refuses an empty list of arrays.
        return np.array([])
    return np.concatenate(prod_vec)