import re
from langchain_openai.embeddings import OpenAIEmbeddings
import os
import matplotlib.pyplot as plt
import networkx as nx
import plotly.graph_objs as go
from plotly.offline import plot
import base64
import hashlib
import logging
from io import BytesIO
from datetime import datetime
from PIL import Image

# Make the parent of this file's directory the process working directory, so
# relative paths used in this module (e.g. "./database/vectorDB") resolve
# against it. NOTE: this is an import-time side effect on the whole process.
current_path = os.path.dirname(os.path.abspath(__file__))
parent_path = os.path.split(current_path)[0]
os.chdir(parent_path)

# Default seed handed to ModelConfig.
SEED = 1024

#       
def image_to_base64(image_path):
    """Read the file at *image_path* and return its contents base64-encoded.

    The file is read in binary mode, so any file type works. The result is a
    text ``str`` (the base64 bytes decoded as UTF-8).
    """
    with open(image_path, "rb") as handle:
        raw = handle.read()
    encoded = base64.b64encode(raw)
    return encoded.decode('utf-8')

def extract_tag_content(text: str, label: str) -> str:
    """Return the text enclosed by the first ``<label>...</label>`` pair in *text*.

    The enclosed text may span several lines. Leading and trailing newlines
    inside the tags are removed; all other whitespace is preserved.

    Args:
        text: Text to search.
        label: Tag name. It is matched literally: regex metacharacters such
            as ``.``, ``+`` or ``(`` are escaped before building the pattern.

    Returns:
        The enclosed content, or ``''`` when no matching tag pair is found.
    """
    # Escape the label so it cannot inject regex syntax (and, e.g., "f(x)"
    # cannot add an extra capture group that would shift group(1)).
    tag = re.escape(label)
    # Non-greedy body: the first closing tag ends the match.
    # DOTALL: the body may contain newlines.
    pattern = rf'<{tag}>\n?(.*?)\n?</{tag}>'
    match = re.search(pattern, text, re.DOTALL)
    if match is None:
        return ''
    return match.group(1).strip('\n')

def yaml_decoder(output: str) -> str:
    """Return the body of the last fenced code block found in *output*.

    Delegates to ``decode_md(output, "YAML")``, which strips a leading
    ``YAML`` tag (case-insensitive) but treats that tag as optional. The
    result may therefore come from an untagged or differently tagged block
    and is not guaranteed to be YAML.

    Raises:
        IndexError: If *output* contains no fenced code block.
    """
    blocks = decode_md(output, "YAML")
    if not blocks:
        # Same exception type as the previous bare ``[-1]`` lookup, but with
        # a message that says what was actually missing.
        raise IndexError("yaml_decoder: no fenced code block found in output")
    return blocks[-1]
    
def decode_cot(text: str) -> str:
    """Strip fenced code blocks from *text* and normalize paragraph spacing.

    Processing steps:
      1. A line ending in ``:`` (optionally followed by whitespace or blank
         lines) that directly introduces a triple-backtick block is removed,
         together with that block.
      2. Any remaining triple-backtick blocks are removed.
      3. The text is split into paragraphs on blank lines. Each paragraph is
         stripped, empty ones are dropped, and the rest are rejoined with a
         single blank line between them.
    """
    # Step 1: "intro line + fenced block" pairs.
    without_intro_blocks = re.sub(
        r'(\n|^)([^\n]*?:\s*)\n```.*?```', '', text, flags=re.DOTALL
    )
    # Step 2: any fenced block still present.
    without_blocks = re.sub(
        r'```.*?```', '', without_intro_blocks, flags=re.DOTALL
    )
    # Step 3: paragraph clean-up.
    paragraphs = []
    for chunk in re.split(r'\n\s*\n+', without_blocks):
        cleaned = chunk.strip()
        if cleaned:
            paragraphs.append(cleaned)
    return '\n\n'.join(paragraphs)

def decode_md(output: str, code_type: str) -> list[str]:
    """Extract the bodies of triple-backtick fenced blocks from Markdown text.

    A leading *code_type* tag right after the opening fence is stripped from
    each block, together with the whitespace that follows it. The match is
    case-insensitive. The tag is optional, so blocks with no tag, or with a
    different tag, are returned as well:

    - For a block with a different tag, that tag stays at the start of the
      returned text.
    - For an untagged block, the newline after the opening fence is kept.

    Whitespace before the closing fence is always removed.

    Args:
        output: Text to scan (e.g. model output).
        code_type: Language tag to strip, e.g. ``"YAML"``. Matched literally.

    Returns:
        The block bodies in order of appearance; ``[]`` when there are none.
    """
    # re.escape: tags such as "c++" would otherwise be parsed as regex syntax
    # (a possessive quantifier, or a "multiple repeat" error on older Pythons).
    pattern = rf'```(?:{re.escape(code_type)}\s*)?(.*?)\s*```'
    return re.findall(pattern, output, re.DOTALL | re.IGNORECASE)

def extract_all_code_blocks(text: str) -> list[str]:
    """Return the raw contents of every triple-backtick span in *text*.

    Spans are matched non-greedily and without overlap, and may cross
    newlines. Any language tag after the opening fence is kept in the
    returned text. No stripping is done.
    """
    return [
        found.group(1)
        for found in re.finditer(r'```(.*?)```', text, flags=re.DOTALL)
    ]

def normalize_string(string: str) -> str:
    """Return *string* without surrounding whitespace, case-folded.

    ``casefold`` is more aggressive than ``lower``: for example, "ß" becomes
    "ss". This makes it suitable for caseless comparisons.
    """
    trimmed = string.strip()
    return trimmed.casefold()

class ModelConfig:
    """Plain settings holder for a model endpoint and its generation options.

    Every constructor argument is stored unchanged as an attribute of the
    same name. No validation is performed.
    """

    def __init__(
        self,
        model: str,
        base_url: str,
        temperature: float = 0.1,
        top_p: float = 0.7,
        max_tokens: int = 4096,
        seed: int = SEED,
        api_key: str = " ",
    ) -> None:
        """Store the given settings.

        Args:
            model: Model identifier.
            base_url: Base URL of the endpoint serving *model*.
            temperature: Temperature value; default 0.1.
            top_p: top_p value; default 0.7.
            max_tokens: Token limit; default 4096.
            seed: Seed value; defaults to the module-level ``SEED``.
            api_key: API key. Defaults to a single space. NOTE(review):
                presumably a non-empty placeholder for endpoints that do
                not need a key; confirm.
        """
        # Assignment order is kept so vars(config) lists the attributes in
        # the same order as before.
        self.model, self.base_url = model, base_url
        self.temperature, self.top_p, self.max_tokens = temperature, top_p, max_tokens
        self.api_key, self.seed = api_key, seed

class VectorDBConfig:
    """Settings holder naming a vector-store collection and its on-disk location."""

    def __init__(
        self,
        collection_name: str = "VectorBatabase",
        persist_directory: str = "./database/vectorDB",
    ) -> None:
        """Store the given settings.

        Args:
            collection_name: Name of the collection. NOTE(review): the
                default "VectorBatabase" looks like a typo of
                "VectorDatabase". It is kept as-is because renaming it would
                point existing deployments at a different collection.
            persist_directory: Storage directory. The relative default
                resolves against the process working directory.
        """
        self.collection_name, self.persist_directory = collection_name, persist_directory

def merge_images_horizontally(base64_image1, base64_image2):
    """Place two base64-encoded images side by side and return the result.

    Both images are converted to RGB and pasted onto a white canvas. The
    canvas width is the sum of the two widths and its height is the larger
    of the two heights; each image is centred vertically. The canvas is
    saved as JPEG (Pillow's default settings) and returned base64-encoded.

    Args:
        base64_image1: Base64 data of the left image.
        base64_image2: Base64 data of the right image.

    Returns:
        The merged JPEG as a base64 ``str``. On any failure this function
        does not raise. It logs the exception and returns a string of the
        form ``"Error: <message>"``, so callers must check for that prefix.
    """
    try:
        img1 = Image.open(BytesIO(base64.b64decode(base64_image1))).convert('RGB')
        img2 = Image.open(BytesIO(base64.b64decode(base64_image2))).convert('RGB')

        width1, height1 = img1.size
        width2, height2 = img2.size

        new_height = max(height1, height2)
        # White background fills the gap above/below the shorter image.
        new_img = Image.new('RGB', (width1 + width2, new_height), (255, 255, 255))
        new_img.paste(img1, (0, (new_height - height1) // 2))
        new_img.paste(img2, (width1, (new_height - height2) // 2))

        buffered = BytesIO()
        new_img.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode()

    except Exception as e:
        # Previously the failure was only encoded in the return value and
        # never logged; keep the returned string but make the error visible.
        logging.exception("Error merging images horizontally")
        return f"Error: {str(e)}"

def generate_image_filename(base64_bytes, file_type: str = 'jpg'):
    """Decode base64 image data and build a timestamped file name for it.

    Args:
        base64_bytes: Base64-encoded data (``str`` or bytes-like).
        file_type: Extension to append, without the leading dot.

    Returns:
        A ``(file_name, image_data)`` tuple. ``file_name`` has the form
        ``YYYYmmdd_HHMMSS.<file_type>`` in local time; ``image_data`` is the
        decoded ``bytes``. Names have one-second resolution, so calls within
        the same second produce the same name.
    """
    decoded = base64.b64decode(base64_bytes)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"{stamp}.{file_type}", decoded

def save_image_from_base64(base64_bytes, save_path, file_type: str = 'jpg') -> str:
    """Decode a base64 image and write it into *save_path* under a timestamped name.

    The file name comes from ``generate_image_filename``
    (``YYYYmmdd_HHMMSS.<file_type>``). If a file with that name already
    exists, e.g. when two images are saved within the same second, a numeric
    suffix (``_1``, ``_2``, ...) is added instead of overwriting the file.

    Args:
        base64_bytes: Base64-encoded image data.
        save_path: Existing directory to write into.
        file_type: File extension to use, without the leading dot.

    Returns:
        The full path of the written file.

    Raises:
        ValueError: If *save_path* does not exist, or if decoding or writing
            fails. In the latter case the underlying error is chained as
            ``__cause__``.
    """
    # Validate outside the try so this error is not caught, logged and
    # wrapped a second time by the generic handler below.
    if not os.path.exists(save_path):
        logging.error(f"Save path does not exist: {save_path}")
        raise ValueError(f"Invalid save path: {save_path}")

    try:
        file_name, image_data = generate_image_filename(base64_bytes, file_type=file_type)
        stem, ext = os.path.splitext(file_name)
        full_path = os.path.join(save_path, file_name)
        suffix = 0
        while True:
            try:
                # 'xb' = exclusive create: fails on an existing file instead
                # of truncating it, with no check-then-write race.
                with open(full_path, 'xb') as f:
                    f.write(image_data)
                break
            except FileExistsError:
                suffix += 1
                full_path = os.path.join(save_path, f"{stem}_{suffix}{ext}")
    except Exception as e:
        logging.error(f"Error saving image from base64: {e}")
        raise ValueError(f"Error saving image from base64: {e}") from e

    logging.info(f"Image saved to {full_path}")
    return full_path
        

## Optional Model Config

