import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import ast
import math

########## text cleaning functions ##########
def clean_text(text):
    """Clean a text cell that holds either one string or a stringified list.

    If *text* is a string shaped like a list literal (``"[...]"``) that
    ``ast.literal_eval`` parses into a list, each element is cleaned with
    ``clean_individual_text``. Otherwise the whole value is cleaned as a
    single piece of text. Non-empty cleaned pieces are joined with a blank
    line between them.

    Args:
        text: The raw cell value. ``None`` and float NaN (pandas' missing-value
            marker) yield ``""``. Non-string values are converted with ``str``
            before cleaning.

    Returns:
        The cleaned text, or ``""`` if nothing survives cleaning.
    """
    if text is None or (isinstance(text, float) and math.isnan(text)):
        return ""

    pieces = [text]
    if isinstance(text, str) and text.startswith('[') and text.endswith(']'):
        try:
            parsed = ast.literal_eval(text)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
            # Not a valid Python literal: fall back to cleaning the raw string.
            pass
        else:
            if isinstance(parsed, list):
                pieces = parsed

    items = []
    for piece in pieces:
        # Missing entries inside a parsed list are skipped instead of crashing.
        if piece is None or (isinstance(piece, float) and math.isnan(piece)):
            continue
        if not isinstance(piece, str):
            piece = str(piece)
        cleaned = clean_individual_text(piece)
        if cleaned:
            items.append(cleaned)

    return "\n\n".join(items)

def clean_individual_text(text):
    """Normalize a single piece of free text (e.g. one solution entry).

    Steps, in order:
      1. Turn literal backslash-n sequences into real newlines.
      2. If the text does not start with 'Solution:' and its first line is
         shorter than 30 characters, drop that first line (it is treated as a
         short intro preceding the actual content).
      3. Remove a 'Solution:' prefix (and following whitespace) at the start
         of any line.
      4. Strip HTML tags and markdown bold/italic/underline markers.
      5. Collapse runs of spaces, cap consecutive newlines at two, and trim.

    Args:
        text: The text to clean. ``None`` and float NaN yield ``""``; other
            non-string values are converted with ``str``.

    Returns:
        The cleaned string (possibly empty).
    """
    if text is None or (isinstance(text, float) and math.isnan(text)):
        return ""
    if not isinstance(text, str):
        text = str(text)

    # replace literal '\n' (backslash + n) with actual newlines
    text = text.replace('\\n', '\n')

    # cases where the actual content starts after a short first line
    if '\n' in text and not text.startswith('Solution:'):
        first_line, rest = text.split('\n', 1)
        if len(first_line) < 30:  # assuming actual content is after short intro
            text = rest

    # remove 'Solution:' prefixes at the start of any line. Anchoring with a
    # MULTILINE '^' (instead of consuming the preceding '\n') keeps the
    # previous line from being glued onto the solution text.
    text = re.sub(r'^Solution:\s*', '', text, flags=re.MULTILINE)

    # remove HTML tags; a tag must start with a letter, '/' or '!' after '<',
    # so comparisons such as 'x < 5 and y > 3' are left intact
    text = re.sub(r'</?[A-Za-z!][^<>]*>', '', text)

    # remove markdown-related formatting
    text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)  # bold
    text = re.sub(r'\*(.*?)\*', r'\1', text)      # italic
    # Underscore emphasis does not apply inside words (CommonMark), so the
    # delimiters must not touch word characters on the outside; otherwise
    # identifiers like max_value / min_value get mangled.
    text = re.sub(r'(?<!\w)__(.*?)__(?!\w)', r'\1', text)  # underline
    text = re.sub(r'(?<!\w)_(.*?)_(?!\w)', r'\1', text)    # italic

    # remove multiple spaces
    text = re.sub(r' +', ' ', text)

    # remove extra blank lines (more than 2 consecutive newlines)
    text = re.sub(r'\n{3,}', '\n\n', text)

    return text.strip()


########## extract LabelStudio codes, sometimes nested ##########
# A label is 'CODE[digits]: words', e.g. 'IU1: Density'. The words are
# letter-only tokens separated by spaces. The negative lookahead ends a label
# before the next 'CODE:' so 'IU1: Density SQ2: Foo' yields two labels. No
# trailing spaces are captured, so 'IU1: Density ,' de-duplicates with
# 'IU1: Density'.
LABEL_PAT = re.compile(r"[A-Z]{2,}\d*: +[A-Za-z]+(?: +(?![A-Z]{2,}\d*:)[A-Za-z]+)*")


def extract_codes(obj):
    """Collect unique LabelStudio code labels from an arbitrarily nested value.

    *obj* may be ``None``/NaN, a string, or any nesting of dicts, lists,
    tuples and sets. A string whose first non-blank character is one of
    ``[{(`` is first parsed with ``ast.literal_eval`` and its contents are
    searched recursively. Any other string is scanned with ``LABEL_PAT``. Dict
    keys and values are both searched.

    Returns:
        The labels in first-seen order, without duplicates. Sets are walked in
        a sorted order so the result does not depend on hash randomisation.
    """

    def rec(x):
        if x is None or (isinstance(x, float) and math.isnan(x)):
            return []

        if isinstance(x, str):
            s = x.strip()
            if s and s[0] in "[{(":
                try:
                    parsed = ast.literal_eval(s)
                except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                    # Not a literal: fall back to a plain regex scan below.
                    parsed = None
                if parsed is not None and not isinstance(parsed, str):
                    return rec(parsed)
            return LABEL_PAT.findall(s)

        if isinstance(x, dict):
            out = []
            for k, v in x.items():
                out.extend(rec(k))
                out.extend(rec(v))
            return out

        if isinstance(x, (set, frozenset)):
            # set order varies between runs; sort by repr for reproducibility
            x = sorted(x, key=repr)

        if isinstance(x, (list, tuple)):
            out = []
            for v in x:
                out.extend(rec(v))
            return out

        return []

    # dict preserves insertion order: keeps the first occurrence of each label
    return list(dict.fromkeys(rec(obj)))

############ other utils ##########

def get_category(label: str) -> str:
    """Return the upper-cased alphabetic prefix of a code label.

    The text before the first ':' is taken, and its leading run of ASCII
    letters is returned upper-cased, e.g. 'IU1: Density' -> 'IU'.

    Raises:
        ValueError: if the text before the first ':' does not begin with a letter.
    """
    head, _sep, _tail = label.partition(':')   # 'IU1'
    letters = re.match(r'[A-Za-z]+', head)     # 'IU'
    if letters is None:
        raise ValueError(f"Unrecognized label format: {label}")
    return letters.group().upper()