import matplotlib.cm as cm
import matplotlib.colors as colors
import os
import torch
import subprocess
from pathlib import Path

def _apply_colormap(relevance, cmap):
    """
    Map a relevance score to an RGBA color.

    The score is normalized linearly with ``vmin=-1`` and ``vmax=1`` and the
    result is looked up in the requested matplotlib colormap.

    Parameters
    ----------
    relevance : float or array-like
        The relevance value(s) to color; callers are expected to pass values
        between -1 and 1.
    cmap : str
        The name of a registered matplotlib colormap (e.g. "bwr").

    Returns
    -------
    The RGBA value(s) returned by the colormap.
    """
    import matplotlib

    # `matplotlib.cm.get_cmap` was deprecated in matplotlib 3.7 and removed in
    # 3.9. Prefer the colormap registry (its `get_cmap` exists since 3.6) and
    # only fall back to the legacy function on older releases.
    registry = getattr(matplotlib, "colormaps", None)
    if hasattr(registry, "get_cmap"):
        colormap = registry.get_cmap(cmap)
    else:
        colormap = cm.get_cmap(cmap)

    return colormap(colors.Normalize(vmin=-1, vmax=1)(relevance))

def _generate_latex(words, relevances, cmap="bwr"):
    """
    Generate LaTeX code for a sentence with colored words based on their relevances.

    Every token is wrapped in a ``\\colorbox`` whose background color is the
    colormap value of its relevance. Tokens are inserted verbatim, so any
    LaTeX special characters must already be escaped by the caller.

    Parameters
    ----------
    words : iterable of str
        The tokens to render. A token that starts with a '▁' or 'Ġ' marker has
        every such marker replaced by a space.
    relevances : iterable of float
        One relevance per token, forwarded to `_apply_colormap`.
    cmap : str
        The name of the colormap to use.

    Returns
    -------
    str
        The source of a complete standalone LaTeX document.
    """

    # NOTE(review): `arwidth` looks like a typo for the standalone class option
    # `varwidth`; it is left untouched because changing it alters the layout of
    # the generated PDF -- confirm the intended option.
    preamble = r'''

    \documentclass[arwidth=200mm]{standalone} 
    \usepackage[dvipsnames]{xcolor}
    
    \begin{document}
    \fbox{
    \parbox{\textwidth}{
    \setlength\fboxsep{0pt}
    '''

    # Collect the boxes and join once instead of growing a string with `+=`.
    boxes = []
    for word, relevance in zip(words, relevances):
        rgb = _apply_colormap(relevance, cmap)
        r, g, b = int(rgb[0]*255), int(rgb[1]*255), int(rgb[2]*255)

        if word.startswith(('▁', 'Ġ')):
            word = word.replace('▁', ' ').replace('Ġ', ' ')

        # NOTE(review): the original emitted the same box, including the
        # leading space, for word-initial and continuation tokens alike; that
        # is preserved here -- confirm whether continuation tokens were meant
        # to be attached without the separating space.
        boxes.append(f' \\colorbox[RGB]{{{r},{g},{b}}}{{\\strut {word}}}')

    return preamble + ''.join(boxes) + r'}}\end{document}'

    
def _compile_latex_to_pdf(latex_code, path='word_colors.pdf', delete_aux_files=True, backend='xelatex'):
    """
    Compile LaTeX code to a PDF file using pdflatex or xelatex.

    The code is written next to `path` with a ``.tex`` suffix and compiled
    with the selected backend, using the parent directory of `path` as the
    output directory.

    Parameters
    ----------
    latex_code : str
        The LaTeX source to compile.
    path : str or Path
        The path of the PDF file; its parent directory is created if needed.
    delete_aux_files : bool
        Whether to delete the ``.aux``, ``.log`` and ``.tex`` files after a
        successful compilation.
    backend : str
        The LaTeX backend to use (pdflatex or xelatex).

    Raises
    ------
    ValueError
        If `backend` is neither 'pdflatex' nor 'xelatex'.
    """

    # Validate before touching the filesystem so a bad argument has no side effects.
    if backend not in ('pdflatex', 'xelatex'):
        raise ValueError(f"Unsupported LaTeX backend {backend!r}; expected 'pdflatex' or 'xelatex'.")

    path = Path(path)
    tex_path = path.with_suffix(".tex")
    os.makedirs(path.parent, exist_ok=True)

    # The document may contain non-ASCII token text, so do not depend on the
    # platform's default encoding.
    with open(tex_path, 'w', encoding='utf-8') as f:
        f.write(latex_code)

    status = subprocess.call([backend, '--output-directory', str(path.parent), str(tex_path)])

    if status != 0:
        import warnings

        # Keep the auxiliary files: the .log is what explains the failure.
        warnings.warn(
            f"{backend} exited with status {status}; see {path.with_suffix('.log')} for details.",
            RuntimeWarning,
        )
        return

    if delete_aux_files:
        for suffix in ('.aux', '.log', '.tex'):
            try:
                os.remove(path.with_suffix(suffix))
            except FileNotFoundError:
                # Not every backend run produces every auxiliary file.
                pass


def pdf_heatmap(words, relevances, cmap="bwr", path='heatmap.pdf', delete_aux_files=True, backend='xelatex'):
    """
    Generate a PDF file with a heatmap of the relevances of the words in a sentence using LaTeX.

    Parameters
    ----------
    words : list of str
        The words in the sentence.
    relevances : list of float
        The relevances of the words normalized between -1 and 1. Any
        one-dimensional sequence whose elements convert with ``float()``
        is accepted.
    cmap : str
        The name of the colormap to use.
    path : str
        The path to save the PDF file.
    delete_aux_files : bool
        Whether to delete the auxiliary files generated by LaTeX.
    backend : str
        The LaTeX backend to use (pdflatex or xelatex).

    Raises
    ------
    AssertionError
        If `words` and `relevances` differ in length, or if any relevance
        lies outside [-1, 1].
    """

    # The checks are raised explicitly instead of via `assert` so they are not
    # stripped under `python -O`; AssertionError is kept so callers that rely
    # on the previous exception type keep working.
    if len(words) != len(relevances):
        raise AssertionError("The number of words and relevances must be the same.")

    # Checking element by element (rather than calling `.min()` / `.max()`)
    # lets plain Python lists through, as the docstring promises. NaN fails
    # the comparison and is rejected.
    if not all(-1 <= float(relevance) <= 1 for relevance in relevances):
        raise AssertionError("The relevances must be normalized between -1 and 1.")

    latex_code = _generate_latex(words, relevances, cmap=cmap)
    _compile_latex_to_pdf(latex_code, path=path, delete_aux_files=delete_aux_files, backend=backend)


def clean_tokens(words):
    """
    Clean wordpiece tokens by removing special characters and splitting them into words.

    The tokenization scheme is detected from the markers found in the tokens:

    * '▁' markers are replaced by a space;
    * otherwise 'Ġ' markers are replaced by a space;
    * otherwise, if any token contains '##', that marker is removed from the
      tokens containing it, every other token gets a leading space, and the
      first token is stripped of surrounding whitespace.

    Afterwards the LaTeX special characters in every token are escaped.

    Parameters
    ----------
    words : list of str
        The tokens to clean. The list passed in is not modified.

    Returns
    -------
    list of str
        The cleaned, LaTeX-escaped tokens.

    Raises
    ------
    ValueError
        If none of the tokens contains a '▁', 'Ġ' or '##' marker.
    """

    if any("▁" in word for word in words):
        words = [word.replace("▁", " ") for word in words]

    elif any("Ġ" in word for word in words):
        words = [word.replace("Ġ", " ") for word in words]

    elif any("##" in word for word in words):
        words = [word.replace("##", "") if "##" in word else " " + word for word in words]
        words[0] = words[0].strip()

    else:
        raise ValueError("The tokenization scheme is not recognized.")

    # Escape in a single pass over each token. Chained `replace` calls would
    # either lose earlier escapes or re-escape the backslashes they insert.
    latex_escapes = str.maketrans({
        '&': '\\&',
        '%': '\\%',
        '$': '\\$',
        '#': '\\#',
        '_': '\\_',
        '{': '\\{',
        '}': '\\}',
        # A doubled backslash is a line break in LaTeX, not a literal backslash.
        '\\': '\\textbackslash{}',
        '~': '\\textasciitilde{}',
        '^': '\\textasciicircum{}',
    })

    return [word.translate(latex_escapes) for word in words]
