import re
import string
import random 
import sys 
import os

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from random import sample
from nltk.tokenize import sent_tokenize, word_tokenize
from reader.vllm_reader import vllm_reader, vllm_reader_batch
from module_04_upgrade.prompt_template import *
from module_04_upgrade.example import *

def return_rand_list(selection_list, required_number):
    """Return ``required_number`` items drawn at random from ``selection_list``.

    Thin wrapper around ``random.sample``: the items are picked without
    replacement and returned as a new list.
    """
    return sample(selection_list, required_number)

def extract_phrases(text):
    """Split ``text`` into phrases, preferring numbered list items.

    Every "<digits>. <rest of line>" occurrence contributes the rest of its
    line.  When the text contains no such numbered item, each non-empty line
    is returned as a phrase instead.
    """
    numbered_items = re.findall(r'\d+\.\s([^\n]+)', text)
    if numbered_items:
        return numbered_items
    # Fallback: no numbering present, so treat every non-empty line as a phrase.
    return re.findall(r'([^\n]+)', text)

class rule_rewrite:
    def __init__(self, constraint_instruction, constraint_type, modification, question, new_question, question_refine, answer, answer_split, original_answer, short_answer, keywords, title, sentence_count, words_count, nlp_tool, model, tokenizer, params):
        self.constraint_instruction = constraint_instruction
        self.constraint_type = constraint_type
        self.modification = modification
        self.question = question
        self.new_question = new_question
        self.question_refine = question_refine
        self.answer = answer
        self.answer_split = [element for element in answer_split if element!=""]
        self.original_answer = original_answer
        self.short_answer = short_answer
        self.keywords = [x for x in list(dict.fromkeys([word.lower() for word in keywords])) if " " not in x]
        self.keywords = [x for x in self.keywords if not any(c.isdigit() for c in x)]  # Remove keywords with digits
        self.sentence_count = sentence_count
        self.paragraphs_count = len(answer_split)
        self.title = title
        self.nlp_tool = nlp_tool
        self.model = model
        self.tokenizer = tokenizer
        self.params = params
        self.words_count = words_count
        self.sequence_dict = {0: ["first", "1st"], 1: ["second", "2nd"], 2: ["third", "3rd"], 3: ["fourth", "4th"], 4: ["fifth", "5th"], 5: ["sixth", "6th"], 6: ["seventh", "7th"], 7: ["eighth", "8th"], 8: ["ninth", "9th"], 9: ["tenth", "10th"], 10: ["eleventh", "11th"], 11: ["twelfth", "12th"], 12: ["thirteenth", "13th"], 13: ["fourteenth", "14th"], 14: ["fifteenth", "15th"], 15: ["sixteenth", "16th"], 16: ["seventeenth", "17th"], 17: ["eighteenth", "18th"], 18: ["nineteenth", "19th"], 19: ["twentieth", "20th"], 20: ["twenty-first", "21st"], 21: ["twenty-second", "22nd"], 22: ["twenty-third", "23rd"], 23: ["twenty-fourth", "24th"], 24: ["twenty-fifth", "25th"], 25: ["twenty-sixth", "26th"], 26: ["twenty-seventh", "27th"], 27: ["twenty-eighth", "28th"], 28: ["twenty-ninth", "29th"], 29: ["thirtieth", "30th"], 30: ["thirty-first", "31st"], 31: ["thirty-second", "32nd"], 32: ["thirty-third", "33rd"], 33: ["thirty-fourth", "34th"], 34: ["thirty-fifth", "35th"], 35: ["thirty-sixth", "36th"], 36: ["thirty-seventh", "37th"], 37: ["thirty-eighth", "38th"], 38: ["thirty-ninth", "39th"], 39: ["fortieth", "40th"], 40: ["forty-first", "41st"], 41: ["forty-second", "42nd"], 42: ["forty-third", "43rd"], 43: ["forty-fourth", "44th"], 44: ["forty-fifth", "45th"], 45: ["forty-sixth", "46th"], 46: ["forty-seventh", "47th"], 47: ["forty-eighth", "48th"], 48: ["forty-ninth", "49th"], 49: ["fiftieth", "50th"]}
    
    def constraint_selection(self):
        """Run the handler method for every entry of ``self.constraint_type``.

        Each entry maps a constraint category to a constraint id.  The handler
        registered for the category is called with that id and the entry's
        position in the mapping.  Categories without a registered handler are
        ignored.
        """
        # Category -> handler method name.  Names are resolved with getattr only
        # when needed, so a handler is looked up only for categories present.
        handler_names = {
            "1_annotation": "annotation",
            "1_annotation_paragraph": "annotation_paragraph",
            "2_caps": "caps",
            "2_caps_paragraph": "caps_paragraph",
            "3_decoration": "decoration",
            "3_decoration_symbol": "decoration_symbol",
            "4_decoration_paragraph": "decoration_paragraph",
            "4_decoration_paragraph_symbol": "decoration_paragraph_symbol",
            "6_keywords": "keyword",
            "7_punctuation": "punctuation",
            "8_structure": "structure",
            # NOTE(review): the paragraph variant shares the "8_structure"
            # handler - confirm this routing is intended.
            "8_structure_paragraph": "structure",
        }

        for position, (category, constraint_id) in enumerate(self.constraint_type.items()):
            method_name = handler_names.get(category)
            if method_name is None:
                continue
            getattr(self, method_name)(constraint_id, position)

# ---------------------internal function(start)---------------------

    def to_dict(self):
        return {
            "constraint_instruction": self.constraint_instruction,
            "constraint_type": self.constraint_type,
            "modification": self.modification,
            "question": self.question,
            "new_question": self.new_question,
            "question_refine": self.question_refine,
            "original_answer": self.original_answer,
            "answer": self.answer,
            "short_answer": self.short_answer,
            "answer_split": self.answer_split,
            "sentence_count": self.sentence_count,
            "keywords": self.keywords,
            "title": self.title,
            "paragraphs_count": self.paragraphs_count,
            "words_count": self.words_count,
        }
            
    def caps_phrase(self, selected_phrases, num, instruction_id):
        """
        Uppercases multi-word phrases (like "honey tea" or "john walker") in the text.
        """

        # 1. Sort the phrases by length in descending order
        # 1. Sort phrases by length in descending order to avoid partial overlaps
        selected_phrases = sorted(selected_phrases, key=len, reverse=True)

        modified_sentences = []

        for sentence in self.answer_split:
            modified_sentence = sentence.replace("\"", "")  # Remove quotes if desired

            # 2. Apply each phrase replacement
            for phrase in selected_phrases:
                escaped_phrase = re.escape(phrase).replace(r'\ ', r'\s+')
                pattern = re.compile(escaped_phrase, re.IGNORECASE)

                # Replace matched text with its UPPERCASE version
                modified_sentence = pattern.sub(
                    lambda m: m.group(0).upper(),
                    modified_sentence
                )

            modified_sentences.append(modified_sentence.strip())

        # 3. Join modified sentences back into one string
        concatenated_string = " ".join(modified_sentences)

        # 4. Update object attributes
        self.answer_split = modified_sentences
        self.answer = concatenated_string

        # 5. Update instruction with number if needed
        if num > 1:
            self.constraint_instruction[instruction_id] = \
                self.constraint_instruction[instruction_id][0].replace("{num}", str(num))
        else:
            self.constraint_instruction[instruction_id] = \
                self.constraint_instruction[instruction_id][1]
            
    def caps_paragraph_phrase(self, selected_phrases, num, instruction_id):
        """
        Uppercases multi-word phrases in a list of paragraphs (list of lists of sentences).
        """
        # 1. Sort phrases by length in descending order to avoid partial overlaps
        selected_phrases = sorted(selected_phrases, key=len, reverse=True)

        modified_paragraphs = []
        concatenated_paragraphs_list = []

        for paragraph in self.answer_split:  # each paragraph is a list of sentences
            modified_sentences = []

            for sentence in paragraph:
                # Remove quotes, if desired
                modified_sentence = sentence

                # 2. For each phrase, create a flexible-regex pattern and replace
                for phrase in selected_phrases:
                    # Escape regex characters, then replace the escaped space with \s+
                    escaped_phrase = re.escape(phrase)
                    escaped_phrase = escaped_phrase.replace(r'\,', r'\s*,\s*')  # optional space around commas
                    escaped_phrase = escaped_phrase.replace(r'\ ', r'\s+')      # flexible whitespace
                    pattern = re.compile(escaped_phrase, re.IGNORECASE)

                    # Replace matched text with its UPPERCASE version
                    modified_sentence = pattern.sub(
                        lambda m: m.group(0).upper(),
                        modified_sentence
                    )

                # Strip trailing whitespace and store result
                modified_sentences.append(modified_sentence.strip())

            # Store modified sentences (list of strings) in modified_paragraphs
            modified_paragraphs.append(modified_sentences)

            # Also join each paragraph's sentences for final concatenation
            concatenated_paragraphs_list.append(" ".join(modified_sentences))

        # 3. Join all paragraphs with double newlines
        concatenated_paragraphs = "\n\n".join(concatenated_paragraphs_list)

        # 4. Update the object attributes
        self.answer_split = modified_paragraphs
        self.answer = concatenated_paragraphs

        # 5. Replace {num} placeholder in constraint_instruction
        if num > 1:
            self.constraint_instruction[instruction_id] = \
                self.constraint_instruction[instruction_id][0].replace("{num}", str(num))
        else:
            self.constraint_instruction[instruction_id] = \
                self.constraint_instruction[instruction_id][1]


    def highlight_phrase(self, selected_phrases, num, instruction_id, constraint_id):
        """
        Highlights multi-word phrases (like "honey tea" or "john walker") by enclosing
        them with symbols specified in self.modification[constraint_id].
        """

        # 1. Sort the phrases by length in descending order
        #    This avoids conflicts where a smaller phrase is part of a bigger one.
        selected_phrases = sorted(selected_phrases, key=len, reverse=True)
        modified_sentences = []

        for sentence in self.answer_split:
            # If you do NOT want to remove quotes, comment out the next line:
            modified_sentence = sentence.replace("\"", "")

            # 2. Apply each phrase replacement
            for phrase in selected_phrases:
                # Escape special regex chars, but allow spaces to become "\s+" (one or more whitespace)
                escaped_phrase = re.escape(phrase).replace(r'\ ', r'\s+')

                # Compile a case-insensitive pattern
                pattern = re.compile(escaped_phrase, re.IGNORECASE)

                # Enclose the match with the symbols from self.modification[constraint_id]
                modified_sentence = pattern.sub(
                    lambda m: f"{self.modification[constraint_id][0]}{m.group(0).strip(string.punctuation)}{self.modification[constraint_id][1]}",
                    modified_sentence
                )

            # Strip trailing/leading whitespace (optional), then collect the result
            modified_sentences.append(modified_sentence.strip())

        # 3. Join the modified sentences back into a single string
        concatenated_string = " ".join(modified_sentences)

        # 4. Update the object's attributes
        self.answer_split = modified_sentences
        self.answer = concatenated_string

        # 5. Replace {num} with the actual number in constraint_instruction
        if num>1:
            self.constraint_instruction[instruction_id] = \
                self.constraint_instruction[instruction_id][0].replace("{num}", str(num))
        else:
            self.constraint_instruction[instruction_id] = \
                self.constraint_instruction[instruction_id][1]

    def highlight_paragraph_phrase(self, selected_phrases, num, instruction_id, constraint_id):
        """
        Highlights multi-word phrases in a list of paragraphs (list of lists of sentences).
        Each matched phrase is enclosed with symbols specified in self.modification[constraint_id].
        """

        # 1. Sort phrases by length in descending order to avoid partial overlaps
        selected_phrases = sorted(selected_phrases, key=len, reverse=True)

        modified_paragraphs = []
        concatenated_paragraphs_list = []

        for paragraph in self.answer_split:  # each paragraph is a list of sentences
            modified_sentences = []
            
            for sentence in paragraph:
                # Remove quotes, if desired
                modified_sentence = sentence.replace('"', '')

                # 2. For each phrase, create a flexible-regex pattern and replace
                for phrase in selected_phrases:
                    # Escape regex characters, then replace the escaped space with \s+
                    escaped_phrase = re.escape(phrase).replace(r'\ ', r'\s+')
                    pattern = re.compile(escaped_phrase, re.IGNORECASE)
                    
                    # Surround the matched phrase with the desired modification symbols
                    start_sym, end_sym = self.modification[constraint_id]
                    modified_sentence = pattern.sub(
                        lambda m: f"{start_sym}{m.group(0).strip(string.punctuation)}{end_sym}",
                        modified_sentence
                    )
                
                # Strip trailing whitespace and store result
                modified_sentences.append(modified_sentence.strip())
            
            # Store modified sentences (list of strings) in modified_paragraphs
            modified_paragraphs.append(modified_sentences)
            
            # Also join each paragraph's sentences for final concatenation
            concatenated_paragraphs_list.append(" ".join(modified_sentences))

        # 3. Join all paragraphs with double newlines
        concatenated_paragraphs = "\n\n".join(concatenated_paragraphs_list)

        # 4. Update the object attributes
        self.answer_split = modified_paragraphs
        self.answer = concatenated_paragraphs

        # 5. Replace {num} placeholder in constraint_instruction
        if num>1:

            self.constraint_instruction[instruction_id] = \
                self.constraint_instruction[instruction_id][0].replace("{num}", str(num))
        else:
            self.constraint_instruction[instruction_id] = \
                self.constraint_instruction[instruction_id][1]

# ---------------------internal function(end)---------------------        
    def caps(self, constraint_id, instruction_id):
        if constraint_id == "caps_keywords":
            
            if self.keywords:
                selected_keywords = self.keywords
                selected_keywords = return_rand_list(selected_keywords,return_rand_list([i for i in range(1,len(selected_keywords)+1)],1)[0])
                if len(selected_keywords)>5:
                    selected_keywords = selected_keywords[:5]
                
                num = len(selected_keywords)
                self.caps_phrase(selected_keywords, num, instruction_id)

            else:
                self.constraint_type["2_caps"] = "ERROR_404"
                self.constraint_instruction[instruction_id] = "ERROR_404"
        
        # no capital letters
        if constraint_id=="caps_no_caps":
            modified_sentences = []
            for sentence in self.answer_split:
                modified_sentences.append(sentence.lower())
            
            modified_sentence = " ".join(modified_sentences)

            self.answer_split = modified_sentences
            self.answer = modified_sentence
        
        # all capital letters
        if constraint_id=="caps_only":
            modified_sentences = []
            for sentence in self.answer_split:
                modified_sentences.append(sentence.upper().strip())
            
            modified_sentence = " ".join(modified_sentences)

            self.answer_split = modified_sentences
            self.answer = modified_sentence

        if constraint_id=="caps_only_capital":
            modified_sentences = []
            for sentence in self.answer_split:
                new_sentence_split = []
                for word in sentence.strip().split():
                    new_sentence_split.append(word.capitalize())
                modified_sentences.append(" ".join(new_sentence_split))
                    
            
            modified_sentence = " ".join(modified_sentences)

            self.answer_split = modified_sentences
            self.answer = modified_sentence
        
        # make the first word of each sentence all uppercase
        if constraint_id=="caps_first_word":
            modified_sentences = []
            for sentence in self.answer_split:
                words = sentence.split()
                modified_words = [
                    word.upper() if i==0 else word
                    for i, word in enumerate(words)
                ]
                modified_sentence = " ".join(modified_words)
                modified_sentences.append(modified_sentence.strip())
            
            modified_sentence = " ".join(modified_sentences)

            self.answer_split = modified_sentences
            self.answer = modified_sentence
        
        # make the last word of each sentence all uppercase
        if constraint_id=="caps_last_word":
            modified_sentences = []
            for sentence in self.answer_split:
                words = sentence.split()
                modified_words = [
                    word.upper() if i==len(words)-1 else word
                    for i, word in enumerate(words)
                ]
                modified_sentence = " ".join(modified_words)
                modified_sentences.append(modified_sentence.strip())
            
            modified_sentence = " ".join(modified_sentences)

            self.answer_split = modified_sentences
            self.answer = modified_sentence
        
        if constraint_id=="caps_sentence":
            if self.sentence_count>1:
                modified_sentences = []
                selected_sentence = return_rand_list([i for i in range(0, len(self.answer_split))], 1)[0]
                for i, sentence in enumerate(self.answer_split):
                    if i == selected_sentence:
                        modified_sentences.append(sentence.upper())
                    else:
                        modified_sentences.append(sentence)

                # Join the modified sentences into a single string
                self.answer_split = modified_sentences
                self.answer = " ".join(modified_sentences)
                self.constraint_instruction[instruction_id] = \
                    self.constraint_instruction[instruction_id].replace("{num}", random.choice(self.sequence_dict[selected_sentence]))
            
            else:
                self.constraint_type["2_caps"] = "ERROR_404"
                self.constraint_instruction[instruction_id] = "ERROR_404"


    def caps_paragraph(self, constraint_id, instruction_id):
        if constraint_id == "caps_keywords":

            if self.keywords:
                selected_keywords = self.keywords
                selected_keywords = return_rand_list(selected_keywords,random.choice([i for i in range(1,len(selected_keywords)+1)]))

                if len(selected_keywords)>5:
                    selected_keywords = selected_keywords[:5]
                
                num = len(selected_keywords)
                self.caps_paragraph_phrase(selected_keywords, num, instruction_id)
            
            else:   
                self.constraint_type["2_caps_paragraph"] = "ERROR_404"
                self.constraint_instruction[instruction_id] = "ERROR_404"

        # Caps first word of each paragraph
        if constraint_id == "caps_first_word":
            modified_paragraphs = []
            for paragraph in self.answer_split:
                if paragraph:  # Ensure the paragraph is not empty
                    words = paragraph[0].split()
                    if words:
                        words[0] = words[0].upper()  # Capitalize only the first word
                        paragraph[0] = " ".join(words)
                modified_paragraphs.append(paragraph)  # Store modified paragraph

            self.answer_split = modified_paragraphs
            self.answer = "\n\n".join([" ".join(paragraph) for paragraph in self.answer_split])
        
        # no capital letters
        if constraint_id == "caps_no_caps":
            modified_paragraphs = []
            concatenated_paragraphs = []

            for paragraph in self.answer_split:
                modified_sentences = [sentence.lower() for sentence in paragraph]
                modified_paragraphs.append(modified_sentences)
                concatenated_paragraphs.append(" ".join(modified_sentences))

            self.answer_split = modified_paragraphs
            self.answer = "\n\n".join(concatenated_paragraphs)

        # all capital letters
        if constraint_id == "caps_only":
            modified_paragraphs = []
            concatenated_paragraphs = []

            for paragraph in self.answer_split:
                modified_sentences = [sentence.upper() for sentence in paragraph]
                modified_paragraphs.append(modified_sentences)
                concatenated_paragraphs.append(" ".join(modified_sentences))

            self.answer_split = modified_paragraphs
            self.answer = "\n\n".join(concatenated_paragraphs)

        if constraint_id=="caps_only_capital":
            modified_paragraphs = []
            concatenated_paragraphs = ""
            for paragraph in self.answer_split:
                modified_sentences = []
                concatenated_string = []
                for sentence in paragraph:
                    new_sentence_split = []
                    for word in sentence.strip().split():
                        new_sentence_split.append(word.capitalize())
                    modified_sentences.append(" ".join(new_sentence_split)) 

                concatenated_string.append(" ".join(modified_sentences))
                modified_paragraphs.append(modified_sentences)
            concatenated_paragraphs = "\n\n".join(concatenated_string)

            self.answer_split = modified_paragraphs
            self.answer = concatenated_paragraphs
        
        if constraint_id == "caps_paragraph":

            new_paragraph = []
            # return random paragraphs number
            selected_paragraphs = return_rand_list([i for i in range(0, len(self.answer_split))], 1)[0]

            for sentence in self.answer_split[selected_paragraphs]:
                new_paragraph.append(sentence.upper())
            
            self.answer_split[selected_paragraphs] = new_paragraph
            self.answer = "\n\n".join([" ".join(paragraph) for paragraph in self.answer_split])

            # replace the placeholder in the constraint instruction
            if selected_paragraphs!=len(self.answer_split)-1:
                self.constraint_instruction[instruction_id] = \
                    self.constraint_instruction[instruction_id].replace("{num}", random.choice(self.sequence_dict[selected_paragraphs]))
            
            else:
                self.constraint_instruction[instruction_id] = \
                    self.constraint_instruction[instruction_id].replace("{num}", "last")

    
    def annotation(self, constraint_id, instruction_id):
        if constraint_id == "highlight":

            if self.keywords:
                selected_keywords = self.keywords
                selected_keywords = return_rand_list(selected_keywords,return_rand_list([i for i in range(1,len(selected_keywords)+1)],1)[0])
                num = len(selected_keywords)

                if len(selected_keywords)>5:
                    selected_keywords = selected_keywords[:5]
                    num = len(selected_keywords)
                
                self.highlight_phrase(selected_keywords, num, instruction_id, constraint_id)
                self.keywords = selected_keywords


            else:
                self.constraint_type["1_annotation"] = "ERROR_404"
                self.constraint_instruction[instruction_id] = "ERROR_404"

    def annotation_paragraph(self, constraint_id, instruction_id):
        if constraint_id == "highlight":

            if self.keywords:
                selected_keywords = self.keywords
                selected_keywords = return_rand_list(selected_keywords,return_rand_list([i for i in range(1,len(selected_keywords)+1)],1)[0])
                num = len(selected_keywords)

                if len(selected_keywords)>5:
                    selected_keywords = selected_keywords[:5]
                    num = len(selected_keywords)
                
                self.highlight_paragraph_phrase(selected_keywords, num, instruction_id, constraint_id)
                self.keywords = selected_keywords

            else:
                self.constraint_type["1_annotation_paragraph"] = "ERROR_404"
                self.constraint_instruction[instruction_id] = "ERROR_404"
        
    def decoration_symbol(self, constraint_id, instruction_id):
        if constraint_id == "start_self":
            word_split = self.answer.split()
            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[0])
        
        if constraint_id == "end_self":
            word_split = self.answer.split()
            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[-2])

        if constraint_id == "start_sentence_self":
            if self.sentence_count>1:
                selected_sentence = return_rand_list([i for i in range(0, len(self.answer_split))], 1)[0]
                word_split = self.answer_split[selected_sentence].split()
                
                if selected_sentence!=len(self.answer_split)-1:
                    self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[0]).replace("{num}", random.choice(self.sequence_dict[selected_sentence]))
                else:
                    self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[0]).replace("{num}", "last")
                
            else:
                self.constraint_type["3_decoration_symbol"] = "ERROR_404"
                self.constraint_instruction[instruction_id] = "ERROR_404"

        if constraint_id == "end_sentence_self":
            if self.sentence_count>1:
                selected_sentence = return_rand_list([i for i in range(0, len(self.answer_split))], 1)[0]
                word_split = self.answer_split[selected_sentence].split()
                if selected_sentence!=len(self.answer_split)-1:
                    self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[-2]).replace("{num}", random.choice(self.sequence_dict[selected_sentence]))
                else:
                    self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[-2]).replace("{num}", "last")

            else:
                self.constraint_type["3_decoration_symbol"] = "ERROR_404"
                self.constraint_instruction[instruction_id] = "ERROR_404"

        if constraint_id=="start":
            self.answer = self.modification[constraint_id][0] + self.answer
            new_first_sentence = self.modification[constraint_id][0]+self.answer_split[0]
            self.answer_split[0] = new_first_sentence

        if constraint_id=="end":
            self.answer = self.answer + self.modification[constraint_id][0]
            new_last_sentence = self.answer_split[-1]+self.modification[constraint_id][0]
            self.answer_split[-1] = new_last_sentence

        if constraint_id in ["enclose","start_end"]:
            self.answer = self.modification[constraint_id][0] + self.answer + self.modification[constraint_id][1]
            new_first_sentence = self.modification[constraint_id][0]+self.answer_split[0]
            new_last_sentence = self.answer_split[-1] +self.modification[constraint_id][1]
            self.answer_split[0] = new_first_sentence
            self.answer_split[-1] = new_last_sentence

        if constraint_id=="title":
            self.answer = self.title[0].strip()+'\n\n'+self.answer
        
        if constraint_id=="title_bracket":

            self.title = [f"{self.modification[constraint_id][0]}{self.title[0].strip()}{self.modification[constraint_id][1]}"]
            self.answer = f"{self.title[0]}\n\n{self.answer}"

        if constraint_id =="sentence_end":
            new_answer_split = []
            for sentence in self.answer_split:
                new_answer_split.append(sentence.strip()+self.modification[constraint_id][0])
            
            self.answer_split = new_answer_split
            self.answer = " ".join(new_answer_split)
        
        if constraint_id =="sentence_start":
            new_answer_split = []
            for sentence in self.answer_split:
                new_answer_split.append(self.modification[constraint_id][0]+sentence.strip())
            
            self.answer_split = new_answer_split
            self.answer = " ".join(new_answer_split)
        
        if constraint_id =="sentence_enclose":
            new_answer_split = []
            for sentence in self.answer_split:
                new_answer_split.append(self.modification[constraint_id][0]+sentence.strip()+self.modification[constraint_id][1])
            
            self.answer_split = new_answer_split
            self.answer = " ".join(new_answer_split)
        
        if constraint_id =="sentence_enclose_first":
            if len(self.answer_split)>1:
                new_answer_split = []
                for i, sentence in enumerate(self.answer_split):
                    if i == 0:
                        new_answer_split.append(self.modification[constraint_id][0]+sentence.strip()+self.modification[constraint_id][1])
                    else:
                        new_answer_split.append(sentence.strip())
                self.answer_split = new_answer_split
                self.answer = " ".join(new_answer_split)
            
            else:
                self.constraint_type["3_decoration_symbol"] = "ERROR_404"
                self.constraint_instruction[instruction_id] = "ERROR_404"
        
        if constraint_id =="sentence_enclose_last":
            if len(self.answer_split)>1:
                new_answer_split = []
                for i, sentence in enumerate(self.answer_split):
                    if i == len(self.answer_split)-1:
                        new_answer_split.append(self.modification[constraint_id][0]+sentence.strip()+self.modification[constraint_id][1])
                    else:
                        new_answer_split.append(sentence.strip())
                self.answer_split = new_answer_split
                self.answer = " ".join(new_answer_split)
            
            else:
                self.constraint_type["3_decoration_symbol"] = "ERROR_404"
                self.constraint_instruction[instruction_id] = "ERROR_404"
            
    def decoration(self, constraint_id, instruction_id):
        if constraint_id == "start_self":
            word_split = self.answer.split()
            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[0])
        
        if constraint_id == "end_self":
            word_split = self.answer.split()
            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[-2])

        if constraint_id == "start_sentence_self":
            if self.sentence_count>1:
                selected_sentence = return_rand_list([i for i in range(0, len(self.answer_split))], 1)[0]
                word_split = self.answer_split[selected_sentence].split()
                
                if selected_sentence!=len(self.answer_split)-1:
                    self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[0]).replace("{num}", random.choice(self.sequence_dict[selected_sentence]))
                else:
                    self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[0]).replace("{num}", "last")
                
            else:
                self.constraint_type["3_decoration"] = "ERROR_404"
                self.constraint_instruction[instruction_id] = "ERROR_404"

        if constraint_id == "end_sentence_self":
            if self.sentence_count>1:
                selected_sentence = return_rand_list([i for i in range(0, len(self.answer_split))], 1)[0]
                word_split = self.answer_split[selected_sentence].split()
                if selected_sentence!=len(self.answer_split)-1:
                    self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[-2]).replace("{num}", random.choice(self.sequence_dict[selected_sentence]))
                else:
                    self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[-2]).replace("{num}", "last")

            else:
                self.constraint_type["3_decoration"] = "ERROR_404"
                self.constraint_instruction[instruction_id] = "ERROR_404"
        
        if constraint_id=="start":
            selected_phrase = return_rand_list(self.modification[constraint_id], 1)[0]
            self.answer = selected_phrase + self.answer
            new_first_sentence = selected_phrase+self.answer_split[0]
            self.answer_split[0] = new_first_sentence
            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", selected_phrase.strip())

        if constraint_id=="end":
            selected_phrase = return_rand_list(self.modification[constraint_id], 1)[0]
            self.answer = self.answer + selected_phrase
            new_last_sentence = self.answer_split[-1] + selected_phrase
            self.answer_split[-1] = new_last_sentence
            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", selected_phrase.strip())

        if constraint_id == "enclose":
            self.answer = self.modification[constraint_id][0] + self.answer + self.modification[constraint_id][1]
            new_first_sentence = self.modification[constraint_id][0]+self.answer_split[0]
            new_last_sentence = self.answer_split[-1]+self.modification[constraint_id][1]
            self.answer_split[0] = new_first_sentence
            self.answer_split[-1] = new_last_sentence

        if constraint_id=="title":
            self.answer = self.title[0].strip()+'\n\n'+self.answer
        
        if constraint_id =="title_all_caps":
            self.title = [self.title[0].strip().upper()]
            self.answer = f"{self.title[0]}\n\n{self.answer}"
        
        if constraint_id =="title_no_caps":
            self.title = [self.title[0].strip().lower()]
            self.answer = f"{self.title[0]}\n\n{self.answer}"
        
        if constraint_id =="sentence_end":
            new_answer_split = []
            for sentence in self.answer_split:
                new_answer_split.append(sentence.strip()+self.modification[constraint_id][0])
            
            self.answer_split = new_answer_split
            self.answer = " ".join(new_answer_split)
        
        if constraint_id =="sentence_start":
            new_answer_split = []
            for sentence in self.answer_split:
                new_answer_split.append(self.modification[constraint_id][0]+sentence.strip())
            
            self.answer_split = new_answer_split
            self.answer = " ".join(new_answer_split)
        
        if constraint_id =="sentence_separate":

            if len(self.answer_split)>1:
                new_answer_split = []
                # add the modification tag at the end of each sentence except the last sentence
                for i, sentence in enumerate(self.answer_split):
                    if i < len(self.answer_split)-1:
                        new_answer_split.append(sentence.strip() + self.modification[constraint_id][0])
                    else:
                        new_answer_split.append(sentence.strip())
                self.answer_split = new_answer_split
                self.answer = " ".join(new_answer_split)
            
            else:
                self.constraint_type["3_decoration"] = "ERROR_404"
                self.constraint_instruction[instruction_id] = "ERROR_404"

    def decoration_paragraph(self, constraint_id, instruction_id):

        if constraint_id == "start_self":
            word_split = self.answer.split()
            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[0])

        if constraint_id == "end_self":
            word_split = self.answer.split()
            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[-2])

        if constraint_id == "paragraph_start_self":
            selected_paragraph = return_rand_list([i for i in range(0, len(self.answer_split))], 1)[0]
            word_split = self.answer_split[selected_paragraph][0].split()
            if selected_paragraph!=len(self.answer_split)-1:
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[0]).replace("{num}", random.choice(self.sequence_dict[selected_paragraph]))
            else:
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[0]).replace("{num}", "last")

        if constraint_id == "paragraph_end_self":
            selected_paragraph = return_rand_list([i for i in range(0, len(self.answer_split))], 1)[0]
            word_split = self.answer_split[selected_paragraph][-1].split()
            if selected_paragraph!=len(self.answer_split)-1:
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[-1]).replace("{num}", random.choice(self.sequence_dict[selected_paragraph]))
            else:
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[-1]).replace("{num}", "last")

        if constraint_id == "paragraph_start":
            new_answer_split = []
            selected_phrase = return_rand_list(self.modification[constraint_id], 1)[0]

            for paragraph in self.answer_split:
                new_paragraph_split = []
                for i, sentence in enumerate(paragraph):
                    if i == 0:
                        new_paragraph_split.append(selected_phrase +sentence.strip())
                    else:
                        new_paragraph_split.append(sentence.strip())
                new_answer_split.append(new_paragraph_split)
            
            concatenated_paragraphs = "\n\n".join([" ".join(paragraph) for paragraph in new_answer_split])
            self.answer_split = new_answer_split
            self.answer = concatenated_paragraphs

            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", selected_phrase.strip())

        if constraint_id == "paragraph_end":
            new_answer_split = []
            selected_phrase = return_rand_list(self.modification[constraint_id], 1)[0]

            for paragraph in self.answer_split:
                new_paragraph_split = []
                for i, sentence in enumerate(paragraph):
                    if i == len(paragraph)-1:
                        new_paragraph_split.append(sentence.strip()+selected_phrase)
                    else:
                        new_paragraph_split.append(sentence.strip())
                new_answer_split.append(new_paragraph_split)
            
            concatenated_paragraphs = "\n\n".join([" ".join(paragraph) for paragraph in new_answer_split])
            self.answer_split = new_answer_split
            self.answer = concatenated_paragraphs

            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", selected_phrase.strip())   
        
        if constraint_id == "separation":
            new_answer_split = []

            for j, paragraph in enumerate(self.answer_split):
                new_paragraph_split = []
                if j == len(self.answer_split)-1:
                    for sentence in paragraph:
                        new_paragraph_split.append(sentence.strip())
                else:
                    for i, sentence in enumerate(paragraph):
                        if i == len(paragraph)-1:
                            new_paragraph_split.append(sentence.strip()+"\n"+self.modification[constraint_id][0])
                        else:
                            new_paragraph_split.append(sentence.strip())
                new_answer_split.append(new_paragraph_split)
            
            concatenated_paragraphs = "\n".join([" ".join(paragraph) for paragraph in new_answer_split])
            self.answer_split = new_answer_split
            self.answer = concatenated_paragraphs

        
        if constraint_id == "paragraph_label":
            new_answer_split = []
            alpha_list = list(string.ascii_uppercase)

            for paragraph,alpha in zip(self.answer_split, alpha_list):
                new_paragraph_split = []
                for i, sentence in enumerate(paragraph):
                    if i==0:
                        new_paragraph_split.append(f"{self.modification[constraint_id][0]}{alpha}{self.modification[constraint_id][1]} {sentence.strip()}")
                    else:
                        new_paragraph_split.append(sentence.strip())
                new_answer_split.append(new_paragraph_split)

            concatenated_paragraphs = "\n\n".join([" ".join(paragraph) for paragraph in new_answer_split])
            self.answer_split = new_answer_split
            self.answer = concatenated_paragraphs
        
        if constraint_id == "paragraph_label_num":
            new_answer_split = []

            for j,paragraph in enumerate(self.answer_split):
                new_paragraph_split = []
                for i, sentence in enumerate(paragraph):
                    if i==0:
                        new_paragraph_split.append(f"{self.modification[constraint_id][0]}{j+1}{self.modification[constraint_id][1]} {sentence.strip()}")
                    else:
                        new_paragraph_split.append(sentence.strip())
                new_answer_split.append(new_paragraph_split)

            concatenated_paragraphs = "\n\n".join([" ".join(paragraph) for paragraph in new_answer_split])
            self.answer_split = new_answer_split
            self.answer = concatenated_paragraphs

        if constraint_id == "paragraph_title":
            
            new_answer_split = []
            for i, paragraph in enumerate(self.answer_split):
                current_paragraph = " ".join(paragraph)
                paragraph = self.title[i].strip()+"\n\n"+current_paragraph
                new_answer_split.append(paragraph)
            
            concatenated_paragraphs = "\n\n".join(new_answer_split)
            self.answer = concatenated_paragraphs
    
        
        if constraint_id == "paragraph_title_enclose":
            new_answer_split = []

            for i, paragraph in enumerate(self.answer_split):
                current_paragraph = " ".join(paragraph)
                paragraph = self.modification[constraint_id][0]+self.title[i].strip()+self.modification[constraint_id][1]+"\n\n"+current_paragraph
                new_answer_split.append(paragraph)

            concatenated_paragraphs = "\n\n".join(new_answer_split)
            self.answer = concatenated_paragraphs

        if constraint_id == "paragraph_title_all_caps":
            new_answer_split = []
            for i, paragraph in enumerate(self.answer_split):
                current_paragraph = " ".join(paragraph)
                paragraph = self.title[i].strip().upper()+"\n\n"+current_paragraph
                new_answer_split.append(paragraph)

            concatenated_paragraphs = "\n\n".join(new_answer_split)
            self.answer = concatenated_paragraphs
        
        if constraint_id == "paragraph_title_no_caps":
            new_answer_split = []
            for i, paragraph in enumerate(self.answer_split):
                current_paragraph = " ".join(paragraph)
                paragraph = self.title[i].strip().lower()+"\n\n"+current_paragraph
                new_answer_split.append(paragraph)
            
            concatenated_paragraphs = "\n\n".join(new_answer_split)
            self.answer = concatenated_paragraphs

        if constraint_id == "paragraph_enclose":
            new_answer_split = []

            for paragraph in self.answer_split:
                new_paragraph_split = []
                if len(paragraph)>1:
                    for i, sentence in enumerate(paragraph):
                        if i == 0:
                            new_paragraph_split.append(self.modification[constraint_id][0]+sentence.strip())
                        elif i == len(paragraph)-1:
                            new_paragraph_split.append(sentence.strip()+self.modification[constraint_id][1])
                        else:
                            new_paragraph_split.append(sentence.strip())
                else:
                    new_paragraph_split.append(self.modification[constraint_id][0]+paragraph[0].strip()+self.modification[constraint_id][1])
                
                new_answer_split.append(new_paragraph_split)
            
            concatenated_paragraphs = "\n\n".join([" ".join(paragraph) for paragraph in new_answer_split])
            self.answer_split = new_answer_split
            self.answer = concatenated_paragraphs

        if constraint_id == "paragraph_enclose_random":
            new_answer_split = []
            total_paragraph = len(self.answer_split)
            selected_paragraph = return_rand_list([i for i in range(total_paragraph)],random.choice([i for i in range(1,total_paragraph)]))

            for x, paragraph in enumerate(self.answer_split):
                new_paragraph_split = []
                if x in selected_paragraph:
                    if len(paragraph)>1:
                        for i, sentence in enumerate(paragraph):
                            if i == 0:
                                new_paragraph_split.append(self.modification[constraint_id][0]+sentence.strip())
                            elif i == len(paragraph)-1:
                                new_paragraph_split.append(sentence.strip()+self.modification[constraint_id][1])
                            else:
                                new_paragraph_split.append(sentence.strip())
                    else:
                        new_paragraph_split.append(self.modification[constraint_id][0]+paragraph[0].strip()+self.modification[constraint_id][1])
                    
                    new_answer_split.append(new_paragraph_split)
                
                else:
                    new_answer_split.append(paragraph)
            
            concatenated_paragraphs = "\n\n".join([" ".join(paragraph) for paragraph in new_answer_split])
            self.answer_split = new_answer_split
            self.answer = concatenated_paragraphs

            if len(selected_paragraph)==1:
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id][1]
            else:
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id][0].replace("{num}", str(len(selected_paragraph)))
               

    def decoration_paragraph_symbol(self, constraint_id, instruction_id):

        if constraint_id == "start_self":
            word_split = self.answer.split()
            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[0])
        
        if constraint_id == "end_self":
            word_split = self.answer.split()
            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[-2])

        if constraint_id == "paragraph_start_self":
            selected_paragraph = return_rand_list([i for i in range(0, len(self.answer_split))], 1)[0]
            word_split = self.answer_split[selected_paragraph][0].split()
            if selected_paragraph!=len(self.answer_split)-1:
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[0]).replace("{num}", random.choice(self.sequence_dict[selected_paragraph]))
            else:
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[0]).replace("{num}", "last")

        if constraint_id == "paragraph_end_self":
            selected_paragraph = return_rand_list([i for i in range(0, len(self.answer_split))], 1)[0]
            word_split = self.answer_split[selected_paragraph][-1].split()
            if selected_paragraph!=len(self.answer_split)-1:
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[-1]).replace("{num}", random.choice(self.sequence_dict[selected_paragraph]))
            else:
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", word_split[-1]).replace("{num}", "last")

        if constraint_id == "paragraph_start":
            new_answer_split = []

            for paragraph in self.answer_split:
                new_paragraph_split = []
                for i, sentence in enumerate(paragraph):
                    if i == 0:
                        new_paragraph_split.append(self.modification[constraint_id][0]+sentence.strip())
                    else:
                        new_paragraph_split.append(sentence.strip())
                new_answer_split.append(new_paragraph_split)
            
            concatenated_paragraphs = "\n\n".join([" ".join(paragraph) for paragraph in new_answer_split])
            self.answer_split = new_answer_split
            self.answer = concatenated_paragraphs
        
        if constraint_id == "paragraph_end":
            new_answer_split = []

            for paragraph in self.answer_split:
                new_paragraph_split = []
                for i, sentence in enumerate(paragraph):
                    if i == len(paragraph)-1:
                        new_paragraph_split.append(sentence.strip()+self.modification[constraint_id][0])
                    else:
                        new_paragraph_split.append(sentence.strip())
                new_answer_split.append(new_paragraph_split)
            
            concatenated_paragraphs = "\n\n".join([" ".join(paragraph) for paragraph in new_answer_split])
            self.answer_split = new_answer_split
            self.answer = concatenated_paragraphs
        
        if constraint_id == "separation":
            new_answer_split = []

            for j, paragraph in enumerate(self.answer_split):
                new_paragraph_split = []
                if j == len(self.answer_split)-1:
                    for sentence in paragraph:
                        new_paragraph_split.append(sentence.strip())
                else:
                    for i, sentence in enumerate(paragraph):
                        if i == len(paragraph)-1:
                            new_paragraph_split.append(sentence.strip()+"\n"+self.modification[constraint_id][0])
                        else:
                            new_paragraph_split.append(sentence.strip())
                new_answer_split.append(new_paragraph_split)
            
            concatenated_paragraphs = "\n".join([" ".join(paragraph) for paragraph in new_answer_split])
            self.answer_split = new_answer_split
            self.answer = concatenated_paragraphs
        
        if constraint_id == "paragraph_enclose":
            new_answer_split = []

            for paragraph in self.answer_split:
                new_paragraph_split = []
                if len(paragraph)>1:
                    for i, sentence in enumerate(paragraph):
                        if i == 0:
                            new_paragraph_split.append(self.modification[constraint_id][0]+sentence.strip())
                        elif i == len(paragraph)-1:
                            new_paragraph_split.append(sentence.strip()+self.modification[constraint_id][1])
                        else:
                            new_paragraph_split.append(sentence.strip())
                else:
                    new_paragraph_split.append(self.modification[constraint_id][0]+paragraph[0].strip()+self.modification[constraint_id][1])
                
                new_answer_split.append(new_paragraph_split)
            
            concatenated_paragraphs = "\n\n".join([" ".join(paragraph) for paragraph in new_answer_split])
            self.answer_split = new_answer_split
            self.answer = concatenated_paragraphs

        if constraint_id == "paragraph_label":
            new_answer_split = []
            alpha_list = list(string.ascii_uppercase)

            for paragraph,alpha in zip(self.answer_split, alpha_list):
                new_paragraph_split = []
                for i, sentence in enumerate(paragraph):
                    if i==0:
                        new_paragraph_split.append(f"{self.modification[constraint_id][0]}{alpha}{self.modification[constraint_id][1]} {sentence.strip()}")
                    else:
                        new_paragraph_split.append(sentence.strip())
                new_answer_split.append(new_paragraph_split)

            concatenated_paragraphs = "\n\n".join([" ".join(paragraph) for paragraph in new_answer_split])
            self.answer_split = new_answer_split
            self.answer = concatenated_paragraphs
        
        if constraint_id == "paragraph_label_num":
            new_answer_split = []

            for j,paragraph in enumerate(self.answer_split):
                new_paragraph_split = []
                for i, sentence in enumerate(paragraph):
                    if i==0:
                        new_paragraph_split.append(f"{self.modification[constraint_id][0]}{j+1}{self.modification[constraint_id][1]} {sentence.strip()}")
                    else:
                        new_paragraph_split.append(sentence.strip())
                new_answer_split.append(new_paragraph_split)

            concatenated_paragraphs = "\n\n".join([" ".join(paragraph) for paragraph in new_answer_split])
            self.answer_split = new_answer_split
            self.answer = concatenated_paragraphs
            
        if constraint_id == "paragraph_title_enclose":
            new_answer_split = []
            for i, paragraph in enumerate(self.answer_split):
                current_paragraph = " ".join(paragraph)
                paragraph = self.modification[constraint_id][0]+self.title[i].strip()+self.modification[constraint_id][1]+"\n\n"+current_paragraph
                new_answer_split.append(paragraph)

            concatenated_paragraphs = "\n\n".join(new_answer_split)
            self.answer = concatenated_paragraphs
        
        if constraint_id == "paragraph_title":
            
            new_answer_split = []
            for i, paragraph in enumerate(self.answer_split):
                current_paragraph = " ".join(paragraph)
                paragraph = self.title[i].strip()+"\n\n"+current_paragraph
                new_answer_split.append(paragraph)

            concatenated_paragraphs = "\n\n".join(new_answer_split)
            self.answer = concatenated_paragraphs

        if constraint_id == "paragraph_enclose_random":
            new_answer_split = []
            total_paragraph = len(self.answer_split)
            selected_paragraph = return_rand_list([i for i in range(total_paragraph)],random.choice([i for i in range(1,total_paragraph)]))

            for x, paragraph in enumerate(self.answer_split):
                new_paragraph_split = []
                if x in selected_paragraph:
                    if len(paragraph)>1:
                        for i, sentence in enumerate(paragraph):
                            if i == 0:
                                new_paragraph_split.append(self.modification[constraint_id][0]+sentence.strip())
                            elif i == len(paragraph)-1:
                                new_paragraph_split.append(sentence.strip()+self.modification[constraint_id][1])
                            else:
                                new_paragraph_split.append(sentence.strip())
                    else:
                        new_paragraph_split.append(self.modification[constraint_id][0]+paragraph[0].strip()+self.modification[constraint_id][1])
                    
                    new_answer_split.append(new_paragraph_split)
                
                else:
                    new_answer_split.append(paragraph)
            
            concatenated_paragraphs = "\n\n".join([" ".join(paragraph) for paragraph in new_answer_split])
            self.answer_split = new_answer_split
            self.answer = concatenated_paragraphs

            if len(selected_paragraph)==1:
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id][1]
            else:
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id][0].replace("{num}", str(len(selected_paragraph)))
               
    
    def keyword(self, constraint_id, instruction_id):
        if constraint_id == "keywords":
            # remove keywords if number is in string

            if self.keywords:
                selected_keywords = self.keywords
                selected_keywords = return_rand_list(selected_keywords,return_rand_list([i for i in range(1,len(selected_keywords)+1)],1)[0])

                if len(selected_keywords)>3:
                    selected_keywords = selected_keywords[:3]

                keywords = ""
                for i, keyword in enumerate(selected_keywords):
                    if i == len(selected_keywords)-1:
                        keywords += f"\"{keyword}\""
                    else:
                        keywords += f"\"{keyword}\", "
                if len(selected_keywords)>1:
                    self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id][0].replace("{words}", keywords)
                else:
                    self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id][1].replace("{words}", keywords)

            else:
                self.constraint_type["6_keywords"] = "ERROR_404"
                self.constraint_instruction[instruction_id] = "ERROR_404"
        
        if constraint_id == "keywords_forbidden":

            if self.keywords:
                selected_keywords = self.keywords
                selected_keywords = return_rand_list(selected_keywords,return_rand_list([i for i in range(1,len(selected_keywords)+1)],1)[0])

                if len(selected_keywords)>3:
                    selected_keywords = selected_keywords[:3]
                    
                keywords_input = "\n".join([str(i+1)+". "+s for i, s in enumerate(selected_keywords)])
                keywords = ""
                query = forbidden_keywords_prompt.format(keywords_input = keywords_input)
                system_prompt = "You are a linguistics expert specializing in English literature."
                output = vllm_reader(self.model, self.tokenizer, self.params, query, system_prompt)
                output = output.strip()
                new_keywords = extract_phrases(output)

                # check if the new keywords are in the text.
                new_keywords = [s.lower().strip() for s in new_keywords if s not in self.answer]
                for i, keyword in enumerate(new_keywords):
                    if i == len(new_keywords)-1:
                        keywords += f"\"{keyword}\""
                    else:
                        keywords += f"\"{keyword}\", "
                if len(new_keywords)>1:
                    self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id][0].replace("{words}", keywords)
                else:
                    self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id][1].replace("{words}", keywords)
            
            else:
                self.constraint_type["6_keywords"] = "ERROR_404"
                self.constraint_instruction[instruction_id] = "ERROR_404"
        
        if constraint_id == "keywords_frequency":

            if self.keywords:
                keywords_frequency = {s.lower(): 0 for s in self.keywords}
                selected_keywords = self.keywords
                # count the occurrence of each keyword in the text
                word_split = self.answer.lower().split()
                for key, value in keywords_frequency.items():
                    keywords_frequency[key] = word_split.count(key)
                
                # sort the keywords by frequency
                sorted_keywords = sorted(keywords_frequency, key=keywords_frequency.get, reverse=True)

                if keywords_frequency[sorted_keywords[0]]> 1:
                    self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id][0].replace("{words}", f"\"{sorted_keywords[0]}\"").replace("{num}", f"{str(keywords_frequency[sorted_keywords[0]])}")
                else:
                    self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id][1].replace("{words}", f"\"{sorted_keywords[0]}\"")
            
            else:
                self.constraint_type["6_keywords"] = "ERROR_404"
                self.constraint_instruction[instruction_id] = "ERROR_404"
                
              
    def punctuation(self, constraint_id, instruction_id):
        
        if constraint_id == "end":
            new_answer_split = []
            new_answer = ""
            selected_period = return_rand_list(["!","?"], 1)[0]

            # If there's nothing in self.answer_split, safely return an empty result
            if not self.answer_split:
                self.answer_split = []
                self.answer = ""

            else:
                # Check if we have a "list of lists" (paragraphs) or just a "list" (sentences)
                if isinstance(self.answer_split[0], list):
                    # We have paragraphs
                    for paragraph in self.answer_split:
                        cleaned_paragraph = []
                        for sentence in paragraph:
                            # 1) Replace '.' when it's immediately followed by a newline => '!'
                            sentence = re.sub(r'\.(?=\n)', selected_period, sentence)
                            
                            # 2) Replace one or more '.' at the end of the sentence (plus optional spaces) => '!'
                            cleaned_sentence = re.sub(r'\.+\s*$', selected_period, sentence.strip())
                            
                            cleaned_paragraph.append(cleaned_sentence)
                        new_answer_split.append(cleaned_paragraph)
                    
                    # Join each paragraph’s sentences, then join paragraphs with double newlines
                    new_answer = "\n\n".join(" ".join(paragraph) for paragraph in new_answer_split)
                
                else:
                    # We have a simple list of sentences
                    for sentence in self.answer_split:
                        sentence = re.sub(r'\.(?=\n)', selected_period, sentence)
                        cleaned_sentence = re.sub(r'\.+\s*$', selected_period, sentence.strip())
                        new_answer_split.append(cleaned_sentence)
                    
                    # Join all sentences into one big string
                    new_answer = " ".join(new_answer_split)

                self.answer_split = new_answer_split
                self.answer = new_answer
                if selected_period == "!":
                    self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", "exclamation")
                else:
                    self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{words}", "question")

        if constraint_id == "comma":
            
            new_answer_split = []
            new_answer = ""

            if type(self.answer_split[0]) == list:
                new_answer_split = []
                for paragraph in self.answer_split:
                    cleaned_paragraph = []
                    for sentence in paragraph:
                        # Remove one or more dots at the end of the sentence (with optional whitespace)
                        cleaned_sentence = re.sub(r',', '', sentence.strip())
                        cleaned_paragraph.append(cleaned_sentence.strip())
                    new_answer_split.append(cleaned_paragraph)
                new_answer = "\n\n".join([" ".join(paragraph) for paragraph in new_answer_split])
            
            else:
                new_answer_split = []
                for sentence in self.answer_split:
                    # Remove one or more dots at the end of the sentence (with optional whitespace)
                    cleaned_sentence = re.sub(r',', '', sentence.strip())
                    new_answer_split.append(cleaned_sentence.strip())
                new_answer = " ".join(new_answer_split)
            
            self.answer_split = new_answer_split
            self.answer = new_answer
        
        if constraint_id == "comma_frequency":
            comma_counts = self.answer.count(",")
            if comma_counts > 0:
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{num}", str(comma_counts))
            else:
                self.constraint_type["7_punctuation"] = "ERROR_404"
                self.constraint_instruction[instruction_id] = "ERROR_404"

    def structure(self, constraint_id, instruction_id):
        
        # initialize the counts
        paragraph_count = 0
        sentence_count = 0

        # word count
        doc = self.nlp_tool(self.answer)
        word_list = [token.text for token in doc if token.is_alpha or token.like_num]
        word_count = len(word_list)

        if type(self.answer_split[0]) == list:
            for paragraph in self.answer_split:
                paragraph_count += 1
                for sentence in paragraph:
                    sentence_count += 1

        else:
            for sentence in self.answer_split:
                sentence_count += 1

        self.paragraphs_count = paragraph_count
        self.sentences_count = sentence_count
        self.words_count = word_count
        
        if constraint_id == "length_constraint_less":
            new_word_count = 10
            if self.words_count < 50:
                if self.words_count > 10:
                    new_word_count = (int(self.words_count/10)+1)*10
            else:
                new_word_count = (int(self.words_count/50)+1)*50

            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{num}", str(new_word_count))
        
        if constraint_id == "length_constraint_more":
            new_word_count = self.words_count
            if self.words_count < 50:
                if self.words_count > 10:
                    new_word_count = (int(self.words_count/10))*10

            else:
                new_word_count = (int(self.words_count/50))*50

            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{num}", str(new_word_count))
        
        if constraint_id == "length_constraint_exact":
            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{num}", str(self.words_count))
        
        if constraint_id == "paragraph_less":
            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{num}", str(self.paragraphs_count+random.choice([1,2])))
        
        if constraint_id == "paragraph_more":
            if self.paragraphs_count >= 3 :
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{num}", str(self.paragraphs_count-random.choice([1,2])))
            else:
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{num}", str(self.paragraphs_count))
        
        if constraint_id == "paragraph_exact":
            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{num}", str(self.paragraphs_count))
        
        if constraint_id == "sentence_less":
            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{num}", str(self.sentences_count+random.choice([1,2,3,4,5])))
        
        if constraint_id == "sentence_more":
            # think about whether to incorporate singular constraints            
            if self.sentences_count >=6:
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id][0].replace("{num}", str(self.sentences_count-random.choice([0,1,2,3,4,5])))
            else:
                if self.sentences_count == 1:
                    self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id][1]
                else:
                    self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id][0].replace("{num}", str(self.sentences_count))
        
        if constraint_id == "sentence_exact":
            if self.sentence_count>1:
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id][0].replace("{num}", str(self.sentences_count))
            else:
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id][1]
        
        if constraint_id == "sentence_less_each":
            highest_sentence_count = 0
            for paragraph in self.answer_split:
                if len(paragraph)>highest_sentence_count:
                    highest_sentence_count = len(paragraph)

            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{num}", str(highest_sentence_count+random.choice([1,2,3,4,5])))
        
        if constraint_id == "sentence_more_each":
            lowest_sentence_count = 10000
            for paragraph in self.answer_split:
                if len(paragraph)<lowest_sentence_count:
                    lowest_sentence_count = len(paragraph)
            if self.sentences_count >=6:
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{num}", str(lowest_sentence_count-random.choice([0,1,2,3,4,5])))
            else:
                self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{num}", str(lowest_sentence_count))

        if constraint_id == "sentence_word_less":
            highest_word_count = 0
            for sentence in self.answer_split:
                doc = self.nlp_tool(sentence)
                word_list = [token.text for token in doc if token.is_alpha or token.like_num]
                word_count = len(word_list)
                if word_count>highest_word_count:
                    highest_word_count = word_count
            
            new_word_count = 10
            if highest_word_count < 50:
                if highest_word_count > 10:
                    new_word_count = (int(highest_word_count/10)+1)*10
            else:
                new_word_count = (int(highest_word_count/50)+1)*50

            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{num}", str(new_word_count))

            if self.sentence_count== 1:
                self.constraint_type["8_structure"] = "ERROR_404"
                self.constraint_instruction[instruction_id] = "ERROR_404"
        
        if constraint_id == "sentence_word_more":
            lowest_word_count = 10000
            for sentence in self.answer_split:
                doc = self.nlp_tool(sentence)
                word_list = [token.text for token in doc if token.is_alpha or token.like_num]
                word_count = len(word_list)

                if word_count<=lowest_word_count:
                    lowest_word_count = word_count
            
            new_word_count = lowest_word_count
            if lowest_word_count < 50:
                if lowest_word_count > 10:
                    new_word_count = (int(lowest_word_count/10))*10

            else:
                new_word_count = (int(lowest_word_count/50))*50

            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{num}", str(lowest_word_count-random.choice([i for i in range(1,2)])))
            
            if self.sentence_count == 1:
                self.constraint_type["8_structure"] = "ERROR_404"
                self.constraint_instruction[instruction_id] = "ERROR_404"

        if constraint_id == "sentence_specific":
            selected_paragraph = return_rand_list([i for i in range(0, len(self.answer_split))], 1)[0]
            selected_sentence_count = len(self.answer_split[selected_paragraph])

            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{num_1}", random.choice(self.sequence_dict[selected_paragraph])).replace("{num_2}", str(selected_sentence_count))
        
        if constraint_id == "sentence_word_specific":
            selected_sentence = return_rand_list([i for i in range(0, len(self.answer_split))], 1)[0]
            doc = self.nlp_tool(self.answer_split[selected_sentence])
            word_list = [token.text for token in doc if token.is_alpha or token.like_num]
            word_count = len(word_list)
            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{num_1}", random.choice(self.sequence_dict[selected_sentence])).replace("{num_2}", str(word_count))
     
        if constraint_id == "paragraph_word_less":
            highest_word_count = 0
            for paragraph in self.answer_split:
                doc = self.nlp_tool(" ".join(paragraph))
                word_list = [token.text for token in doc if token.is_alpha or token.like_num]
                word_count = len(word_list)
                if word_count>highest_word_count:
                    highest_word_count = word_count
            
            new_word_count = 10
            if highest_word_count < 50:
                if highest_word_count > 10:
                    new_word_count = (int(highest_word_count/10)+1)*10
            else:
                new_word_count = (int(highest_word_count/50)+1)*50


            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{num}", str(new_word_count))
        
        if constraint_id == "paragraph_word_more":
            lowest_word_count = 10000
            for paragraph in self.answer_split:
                doc = self.nlp_tool(" ".join(paragraph))
                word_list = [token.text for token in doc if token.is_alpha or token.like_num]
                word_count = len(word_list)
                if word_count<=lowest_word_count:
                    lowest_word_count = word_count

            new_word_count = lowest_word_count
            if lowest_word_count < 50:
                if lowest_word_count > 10:
                    new_word_count = (int(lowest_word_count/10))*10
            else:
                new_word_count = (int(lowest_word_count/50))*50

            self.constraint_instruction[instruction_id] = self.constraint_instruction[instruction_id].replace("{num}", str(new_word_count))