from abc import ABC, abstractmethod
import json
import os
import sys
try:
    from openai import OpenAI
    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False
    print("Warning: OpenAI package not available")

try:
    import tiktoken
    TIKTOKEN_AVAILABLE = True
except ImportError:
    TIKTOKEN_AVAILABLE = False
    print("Warning: tiktoken package not available")

# Treat the directory two levels above the one containing this file as the
# project root (three dirname() calls on this file's absolute path).
# NOTE(review): presumably done so project-local modules are importable no
# matter where the script is launched from — confirm. The config files opened
# below are still resolved relative to the current working directory, not
# this root.
project_root_path = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)

# Prepend (rather than append) so the project root wins over same-named
# installed packages; the membership check avoids duplicate entries on reload.
if project_root_path not in sys.path:
    sys.path.insert(0, project_root_path)


def chat_template(messages):
    """Render chat messages into a single chat-template-style string.

    Each message becomes ``<|role|>\\ncontent\\n``, and a trailing
    ``<|assistant|>\\n`` generation prompt is appended. The result is meant for
    estimating token counts, not for sending to a model.

    Args:
        messages: Iterable of dicts with ``"role"`` and ``"content"`` keys.

    Returns:
        The concatenated template string. An empty ``messages`` yields just
        the assistant prompt.
    """
    # Collect the pieces and join once; repeated ``+=`` on a str is quadratic.
    parts = [f"<|{msg['role']}|>\n{msg['content']}\n" for msg in messages]
    parts.append("<|assistant|>\n")
    return "".join(parts)


def merge_repeated_role(messages):
    """Collapse runs of consecutive messages that share the same role.

    The contents of each run are joined with ``"\\n"`` into the run's first
    message. The operation is in place: the first dict of each run has its
    ``"content"`` extended, and the list object itself is rewritten and
    returned.

    Args:
        messages: List of dicts with ``"role"`` and ``"content"`` keys
            (content is concatenated with ``+``, so it should be a str).

    Returns:
        The same list object, now without adjacent duplicate roles.
    """
    merged = []
    for msg in messages:
        # Compare against the previous surviving message directly. A sentinel
        # such as "" would collide with a genuine empty role.
        if merged and merged[-1]["role"] == msg["role"]:
            merged[-1]["content"] += "\n" + msg["content"]
        else:
            merged.append(msg)
    # Slice-assign to keep the caller's list identity (in-place contract).
    messages[:] = merged
    return messages


def repair_json_simple(json_str):
    """Best-effort parse of a model reply that was supposed to be JSON.

    1. If ``json_str`` is valid JSON, return the parsed value.
    2. Otherwise, if the stripped text does not look like an object or array
       (does not start with ``{`` or ``[``), wrap it verbatim as
       ``{"content": <stripped text>}``.
    3. Otherwise, return ``{"raw_response": <stripped text>,
       "error": "JSON parsing failed"}``.

    Args:
        json_str: The text to parse. A non-str input raises ``TypeError``
            from ``json.loads``, as before.

    Returns:
        The parsed JSON value or one of the fallback dicts described above.
    """
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        pass

    json_str = json_str.strip()
    if not json_str.startswith(("{", "[")):
        # Build the dict directly. Splicing the text into a JSON literal broke
        # on newlines, tabs and backslashes, and silently decoded sequences
        # like "\n" or "\u00e9" found in the text.
        return {"content": json_str}

    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        return {"raw_response": json_str, "error": "JSON parsing failed"}


class AbstractLLM(ABC):
    """Common callable interface for the chat-model wrappers in this module.

    Concrete backends implement ``_get_response``. Callers invoke the instance
    itself, which validates the requested output mode first. The three
    counters start at zero and are maintained by subclasses.
    """

    class ModeError(Exception):
        """Raised when mutually exclusive output modes are requested together."""

    def __init__(self):
        self.input_token_count = 0   # running total of prompt tokens
        self.output_token_count = 0  # running total of completion tokens
        self.input_token_maxx = 0    # largest prompt seen in a single call

    def __call__(self, messages, one_line=True, json_mode=False):
        """Check the mode flags, then delegate to ``_get_response``.

        Raises:
            AbstractLLM.ModeError: if ``one_line`` and ``json_mode`` are both
                truthy. ``one_line`` defaults to True, so JSON callers must pass
                ``one_line=False`` explicitly.
        """
        modes_conflict = one_line and json_mode
        if modes_conflict:
            raise self.ModeError(
                "one_line and json_mode cannot be True at the same time"
            )
        return self._get_response(messages, one_line, json_mode)

    @abstractmethod
    def _get_response(self, messages, one_line, json_mode):
        """Return the backend's reply for ``messages``; subclasses must override."""


class Deepseek(AbstractLLM):
    """Chat client for the ``deepseek-chat`` model over HTTP.

    The endpoint URL and API key come from the ``"deepseek-chat"`` entry of
    ``config/llms_config.json``, a path resolved relative to the current
    working directory. The reply is read from an OpenAI-style
    chat-completions payload (``choices[0].message.content``). Request and
    parsing failures are *returned* as error strings, never raised.
    """

    def __init__(self, max_model_len=None):
        """Load endpoint settings and bind the optional ``requests`` module.

        Args:
            max_model_len: Value sent as ``max_tokens``. ``None`` (or any falsy
                value) keeps the previous fixed default of 10000.

        Raises:
            OSError, ValueError (incl. ``json.JSONDecodeError``) or KeyError if
            the config file is missing, unreadable, malformed, or lacks the
            model entry.
        """
        super().__init__()
        self.model_name = "deepseek-chat"
        # Parse the config once instead of once per field.
        model_config = self._load_config()[self.model_name]
        self.api_url = model_config["url"]
        self.api_key = model_config["key"]
        self.max_model_len = max_model_len

        # requests is optional; without it every call returns an error string.
        try:
            import requests
            self.requests = requests
        except ImportError:
            print("Warning: requests not available, deepseek API calls will fail")
            self.requests = None

    def _load_config(self):
        """Return the parsed ``config/llms_config.json`` (relative to the CWD)."""
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        with open('config/llms_config.json', "r", encoding="utf-8") as f:
            return json.load(f)

    def _count_tokens(self, text: str) -> int:
        """Approximate token count (~4 characters per token).

        Not used by ``_get_response``, which applies its own word-based estimate.
        """
        return len(text) // 4  # Rough approximation

    def _get_response(self, messages, one_line, json_mode):
        """Send ``messages`` to the Deepseek API and return the reply.

        Args:
            messages: List of ``{"role", "content"}`` dicts, or a bare string
                that is treated as a single user message.
            one_line: Accepted for interface compatibility; not used here.
            json_mode: When true, the reply is returned as a JSON *string*.
                Invalid JSON is passed through ``repair_json_simple`` first.

        Returns:
            The reply text (a JSON string in ``json_mode``), or an error
            message string if the request or response handling fails.
        """
        # Accept a bare prompt string as a single user turn.
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        if not self.requests:
            return "Error: requests library not available"

        try:
            headers = {
                'Authorization': f'Bearer {self.api_key}',
                'Content-Type': 'application/json'
            }
            data = {
                "model": self.model_name,
                "messages": messages,
                "temperature": 0.6,
                "max_tokens": self.max_model_len if self.max_model_len else 10000
            }
            response = self.requests.post(self.api_url, headers=headers, json=data, timeout=600)

            if response.status_code != 200:
                error_msg = f"Deepseek API error: HTTP {response.status_code} - {response.text}"
                print(error_msg)
                return error_msg

            content = response.json()['choices'][0]['message']['content']
            self._record_usage(messages, content)
            return self._format_output(content, json_mode)

        except Exception as e:
            # Deliberate best-effort contract: callers receive errors as text.
            error_msg = f"Deepseek API error: {e}"
            print(error_msg)
            return error_msg

    def _record_usage(self, messages, content):
        """Add heuristic token counts (~1.3 per whitespace word) to the running totals."""
        prompt_text = " ".join([msg["content"] for msg in messages])
        input_tokens = int(len(prompt_text.split()) * 1.3)
        output_tokens = int(len(content.split()) * 1.3)
        self.input_token_count += input_tokens
        self.output_token_count += output_tokens
        self.input_token_maxx = max(self.input_token_maxx, input_tokens)

    @staticmethod
    def _format_output(content, json_mode):
        """Return ``content`` as-is, or normalised to a JSON string in ``json_mode``."""
        if not json_mode:
            return content
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            # Repair, but still serialise: json_mode always yields a str
            # (the repair path used to leak a dict).
            parsed = repair_json_simple(content)
        return json.dumps(parsed, ensure_ascii=False)


class GPT4o(AbstractLLM):
    """Chat client for the ``796-gpt-4o__2024-11-20`` deployment over HTTP.

    The endpoint URL and API key come from that model's entry in
    ``config/llms_config.json``, a path resolved relative to the current
    working directory. The reply is read from an OpenAI-style
    chat-completions payload (``choices[0].message.content``). Failures are
    *returned* as error strings, never raised.
    """

    def __init__(self, max_model_len=None):
        """Load endpoint settings and bind the optional ``requests`` module.

        Args:
            max_model_len: Value sent as ``max_tokens``. ``None`` (or any falsy
                value) falls back to 10000.

        Raises:
            OSError, ValueError (incl. ``json.JSONDecodeError``) or KeyError if
            the config file is missing, unreadable, malformed, or lacks the
            model entry.
        """
        super().__init__()
        self.model_name = "796-gpt-4o__2024-11-20"
        # Parse the config once instead of once per field.
        model_config = self._load_config()[self.model_name]
        self.api_url = model_config["url"]
        self.api_key = model_config["key"]
        self.max_model_len = max_model_len

        # requests is optional; without it every call returns an error string.
        try:
            import requests
            self.requests = requests
        except ImportError:
            print("Warning: requests not available, GPT4O API calls will fail")
            self.requests = None

    def _load_config(self):
        """Return the parsed ``config/llms_config.json`` (relative to the CWD)."""
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        with open('config/llms_config.json', "r", encoding="utf-8") as f:
            return json.load(f)

    def _count_tokens(self, text: str) -> int:
        """Approximate token count (~4 characters per token).

        Not used by ``_get_response``, which applies its own word-based estimate.
        """
        return len(text) // 4  # Rough approximation

    def _get_response(self, messages, one_line, json_mode):
        """Send ``messages`` to the GPT-4o endpoint and return the reply.

        Args:
            messages: List of ``{"role", "content"}`` dicts, or a bare string
                that is treated as a single user message.
            one_line: Accepted for interface compatibility; not used here.
            json_mode: When true, the reply is returned as a JSON *string*.
                Invalid JSON is passed through ``repair_json_simple`` first.

        Returns:
            The reply text (a JSON string in ``json_mode``), or an error
            message string if the request or response handling fails.
        """
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        if not self.requests:
            return "Error: requests library not available"

        try:
            headers = {
                'Authorization': f'Bearer {self.api_key}',
                'Content-Type': 'application/json'
            }
            data = {
                "model": self.model_name,
                "messages": messages,
                "temperature": 0.6,
                "max_tokens": self.max_model_len if self.max_model_len else 10000
            }
            # NOTE(review): 6000 s (~100 min) timeout — confirm this is intended.
            response = self.requests.post(self.api_url, headers=headers, json=data, timeout=6000)

            if response.status_code != 200:
                error_msg = f"GPT4o API error: HTTP {response.status_code} - {response.text}"
                print(error_msg)
                return error_msg

            content = response.json()['choices'][0]['message']['content']
            self._record_usage(messages, content)
            return self._format_output(content, json_mode)

        except Exception as e:
            # Deliberate best-effort contract: callers receive errors as text.
            error_msg = f"GPT4o API error: {e}"
            print(error_msg)
            return error_msg

    def _record_usage(self, messages, content):
        """Add heuristic token counts (~1.3 per whitespace word) to the running totals."""
        prompt_text = " ".join([msg["content"] for msg in messages])
        input_tokens = int(len(prompt_text.split()) * 1.3)
        output_tokens = int(len(content.split()) * 1.3)
        self.input_token_count += input_tokens
        self.output_token_count += output_tokens
        self.input_token_maxx = max(self.input_token_maxx, input_tokens)

    @staticmethod
    def _format_output(content, json_mode):
        """Return ``content`` as-is, or normalised to a JSON string in ``json_mode``."""
        if not json_mode:
            return content
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            # Repair, but still serialise: json_mode always yields a str
            # (the repair path used to leak a dict).
            parsed = repair_json_simple(content)
        return json.dumps(parsed, ensure_ascii=False)


class EmptyLLM(AbstractLLM):
    """Placeholder backend that never contacts a model and returns fixed text."""

    def __init__(self):
        super().__init__()
        self.name = "EmptyLLM"

    def _get_response(self, messages, one_line, json_mode):
        # The messages and one_line are ignored; only the output mode matters.
        return (
            '{"response": "Empty LLM response", "method": "placeholder"}'
            if json_mode
            else "Empty LLM response"
        )


class Gemini(AbstractLLM):
    """Chat client for ``gemini-2.5-flash-latest`` over HTTP.

    The endpoint URL and API key come from that model's entry in
    ``config/llms_config.json``, a path resolved relative to the current
    working directory. The request and response use an OpenAI-style
    chat-completions shape (``choices[0].message.content``). Failures are
    *returned* as error strings, never raised.
    """

    def __init__(self, max_model_len=None):
        """Load endpoint settings and bind the optional ``requests`` module.

        Args:
            max_model_len: Value sent as ``max_tokens``. ``None`` (or any falsy
                value) falls back to 10000.

        Raises:
            OSError, ValueError (incl. ``json.JSONDecodeError``) or KeyError if
            the config file is missing, unreadable, malformed, or lacks the
            model entry.
        """
        super().__init__()
        self.model_name = "gemini-2.5-flash-latest"
        # Parse the config once instead of once per field.
        model_config = self._load_config()[self.model_name]
        self.api_url = model_config["url"]
        self.api_key = model_config["key"]
        self.max_model_len = max_model_len

        # requests is optional; without it every call returns an error string.
        try:
            import requests
            self.requests = requests
        except ImportError:
            print("Warning: requests not available, Gemini API calls will fail")
            self.requests = None

    def _load_config(self):
        """Return the parsed ``config/llms_config.json`` (relative to the CWD)."""
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        with open('config/llms_config.json', "r", encoding="utf-8") as f:
            return json.load(f)

    def _count_tokens(self, text: str) -> int:
        """Approximate token count (~4 characters per token).

        Not used by ``_get_response``, which applies its own word-based estimate.
        """
        return len(text) // 4  # Rough approximation

    def _get_response(self, messages, one_line, json_mode):
        """Send ``messages`` to the Gemini endpoint and return the reply.

        Args:
            messages: List of ``{"role", "content"}`` dicts, or a bare string
                that is treated as a single user message.
            one_line: Accepted for interface compatibility; not used here.
            json_mode: When true, the reply is returned as a JSON *string*.
                Invalid JSON is passed through ``repair_json_simple`` first.

        Returns:
            The reply text (a JSON string in ``json_mode``), or an error
            message string if the request or response handling fails.
        """
        # Accept a bare prompt string as a single user turn.
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        if not self.requests:
            return "Error: requests library not available"

        try:
            headers = {
                'Authorization': f'Bearer {self.api_key}',
                'Content-Type': 'application/json'
            }
            data = {
                "model": self.model_name,
                "messages": messages,
                "temperature": 0.6,
                "max_tokens": self.max_model_len if self.max_model_len else 10000
            }
            # NOTE(review): 6000 s (~100 min) timeout — confirm this is intended.
            response = self.requests.post(self.api_url, headers=headers, json=data, timeout=6000)

            if response.status_code != 200:
                error_msg = f"Gemini API error: HTTP {response.status_code} - {response.text}"
                print(error_msg)
                return error_msg

            content = response.json()['choices'][0]['message']['content']
            self._record_usage(messages, content)
            return self._format_output(content, json_mode)

        except Exception as e:
            # Deliberate best-effort contract: callers receive errors as text.
            error_msg = f"Gemini API error: {e}"
            print(error_msg)
            return error_msg

    def _record_usage(self, messages, content):
        """Add heuristic token counts (~1.3 per whitespace word) to the running totals."""
        prompt_text = " ".join([msg["content"] for msg in messages])
        input_tokens = int(len(prompt_text.split()) * 1.3)
        output_tokens = int(len(content.split()) * 1.3)
        self.input_token_count += input_tokens
        self.output_token_count += output_tokens
        self.input_token_maxx = max(self.input_token_maxx, input_tokens)

    @staticmethod
    def _format_output(content, json_mode):
        """Return ``content`` as-is, or normalised to a JSON string in ``json_mode``."""
        if not json_mode:
            return content
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            # Repair, but still serialise: json_mode always yields a str
            # (the repair path used to leak a dict).
            parsed = repair_json_simple(content)
        return json.dumps(parsed, ensure_ascii=False)


class Qwen3(AbstractLLM):
    """Chat client for a Qwen3 model served over HTTP.

    The endpoint URL comes from the ``model_name`` entry of
    ``config/llms_config.json``, a path resolved relative to the current
    working directory. The request and response use an OpenAI-style
    chat-completions shape (``choices[0].message.content``). Failures are
    *returned* as error strings, never raised.
    """

    def __init__(self, model_name=None, max_model_len=None):
        """Load endpoint settings and bind the optional ``requests`` module.

        Args:
            model_name: Key into ``config/llms_config.json``, also sent as
                ``"model"``. It must be supplied: the default ``None`` is never a
                JSON object key, so it raises ``KeyError``.
            max_model_len: Value sent as ``max_tokens``. ``None`` (or any falsy
                value) falls back to 5000.

        Raises:
            OSError, ValueError (incl. ``json.JSONDecodeError``) or KeyError if
            the config file is missing, unreadable, malformed, or lacks the
            model entry.
        """
        super().__init__()
        self.model_name = model_name
        # Parse the config once instead of once per field.
        model_config = self._load_config()[model_name]
        self.api_url = model_config["url"]
        # NOTE(review): the key is loaded but no Authorization header is sent
        # (presumably an unauthenticated self-hosted server) — confirm.
        self.api_key = model_config["key"]
        self.max_model_len = max_model_len

        # requests is optional; without it every call returns an error string.
        try:
            import requests
            self.requests = requests
        except ImportError:
            print("Warning: requests not available, Qwen3 API calls will fail")
            self.requests = None

    def _load_config(self):
        """Return the parsed ``config/llms_config.json`` (relative to the CWD)."""
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        with open('config/llms_config.json', "r", encoding="utf-8") as f:
            return json.load(f)

    def _count_tokens(self, text: str) -> int:
        """Approximate token count (~4 characters per token).

        Not used by ``_get_response``, which applies its own word-based estimate.
        """
        return len(text) // 4  # Rough approximation

    def _get_response(self, messages, one_line, json_mode):
        """Send ``messages`` to the Qwen3 endpoint and return the reply.

        Args:
            messages: List of ``{"role", "content"}`` dicts, or a bare string
                that is treated as a single user message.
            one_line: Accepted for interface compatibility; not used here.
            json_mode: When true, the reply is returned as a JSON *string*.
                Invalid JSON is passed through ``repair_json_simple`` first.

        Returns:
            The reply text (a JSON string in ``json_mode``), or an error
            message string if the request or response handling fails.
        """
        # Accept a bare prompt string as a single user turn.
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        if not self.requests:
            return "Error: requests library not available"

        try:
            headers = {
            }
            data = {
                "model": self.model_name,
                "messages": messages,
                "temperature": 0.6,
                "max_tokens": self.max_model_len if self.max_model_len else 5000
            }
            # NOTE(review): 6000 s (~100 min) timeout — confirm this is intended.
            response = self.requests.post(self.api_url, headers=headers, json=data, timeout=6000)

            if response.status_code != 200:
                error_msg = f"Qwen3 API error: HTTP {response.status_code} - {response.text}"
                print(error_msg)
                return error_msg

            content = response.json()['choices'][0]['message']['content']
            self._record_usage(messages, content)
            return self._format_output(content, json_mode)

        except Exception as e:
            # Deliberate best-effort contract: callers receive errors as text.
            error_msg = f"Qwen3 API error: {e}"
            print(error_msg)
            return error_msg

    def _record_usage(self, messages, content):
        """Add heuristic token counts (~1.3 per whitespace word) to the running totals."""
        prompt_text = " ".join([msg["content"] for msg in messages])
        input_tokens = int(len(prompt_text.split()) * 1.3)
        output_tokens = int(len(content.split()) * 1.3)
        self.input_token_count += input_tokens
        self.output_token_count += output_tokens
        self.input_token_maxx = max(self.input_token_maxx, input_tokens)

    @staticmethod
    def _format_output(content, json_mode):
        """Return ``content`` as-is, or normalised to a JSON string in ``json_mode``."""
        if not json_mode:
            return content
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            # Repair, but still serialise: json_mode always yields a str
            # (the repair path used to leak a dict).
            parsed = repair_json_simple(content)
        return json.dumps(parsed, ensure_ascii=False)

class CustomizedLLM(AbstractLLM):
    """Chat client for an arbitrary model entry in ``config/llms_config.json``.

    The config path is resolved relative to the current working directory.
    The request and response use an OpenAI-style chat-completions shape
    (``choices[0].message.content``), with an extra ``repetition_penalty``
    field. Failures are *returned* as error strings, never raised.
    """

    def __init__(self, max_model_len=None, model_name=None):
        """Load endpoint settings and bind the optional ``requests`` module.

        Args:
            max_model_len: Value sent as ``max_tokens``. ``None`` (or any falsy
                value) falls back to 10000.
            model_name: Key into ``config/llms_config.json``, also sent as
                ``"model"``. It must be supplied: the default ``None`` is never a
                JSON object key, so it raises ``KeyError``.

        Raises:
            OSError, ValueError (incl. ``json.JSONDecodeError``) or KeyError if
            the config file is missing, unreadable, malformed, or lacks the
            model entry.
        """
        super().__init__()
        self.model_name = model_name
        # Parse the config once instead of once per field.
        model_config = self._load_config()[model_name]
        self.api_url = model_config["url"]
        # NOTE(review): the key is loaded but no Authorization header is sent
        # (presumably an unauthenticated self-hosted server) — confirm.
        self.api_key = model_config["key"]
        self.max_model_len = max_model_len

        # requests is optional; without it every call returns an error string.
        try:
            import requests
            self.requests = requests
        except ImportError:
            print("Warning: requests not available, CustomizedLLM API calls will fail")
            self.requests = None

    def _load_config(self):
        """Return the parsed ``config/llms_config.json`` (relative to the CWD)."""
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        with open('config/llms_config.json', "r", encoding="utf-8") as f:
            return json.load(f)

    def _count_tokens(self, text: str) -> int:
        """Approximate token count (~4 characters per token).

        Not used by ``_get_response``, which applies its own word-based estimate.
        """
        return len(text) // 4  # Rough approximation

    def _get_response(self, messages, one_line, json_mode):
        """Send ``messages`` to the configured endpoint and return the reply.

        Args:
            messages: List of ``{"role", "content"}`` dicts, or a bare string
                that is treated as a single user message.
            one_line: Accepted for interface compatibility; not used here.
            json_mode: When true, the reply is returned as a JSON *string*.
                Invalid JSON is passed through ``repair_json_simple`` first.

        Returns:
            The reply text (a JSON string in ``json_mode``), or an error
            message string if the request or response handling fails.
        """
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        if not self.requests:
            return "Error: requests library not available"

        try:
            headers = {
            }
            data = {
                "model": self.model_name,
                "messages": messages,
                "temperature": 0.6,
                "max_tokens": self.max_model_len if self.max_model_len else 10000,
                "repetition_penalty": 1.1,
                # Optional, server-dependent switch kept from the original:
                # "chat_template_kwargs": {"enable_thinking": False}
            }
            # NOTE(review): 6000 s (~100 min) timeout — confirm this is intended.
            response = self.requests.post(self.api_url, headers=headers, json=data, timeout=6000)

            if response.status_code != 200:
                error_msg = f"CustomizedLLM API error: HTTP {response.status_code} - {response.text}"
                print(error_msg)
                return error_msg

            content = response.json()['choices'][0]['message']['content']
            self._record_usage(messages, content)
            return self._format_output(content, json_mode)

        except Exception as e:
            # Deliberate best-effort contract: callers receive errors as text.
            error_msg = f"CustomizedLLM API error: {e}"
            print(error_msg)
            return error_msg

    def _record_usage(self, messages, content):
        """Add heuristic token counts (~1.3 per whitespace word) to the running totals."""
        prompt_text = " ".join([msg["content"] for msg in messages])
        input_tokens = int(len(prompt_text.split()) * 1.3)
        output_tokens = int(len(content.split()) * 1.3)
        self.input_token_count += input_tokens
        self.output_token_count += output_tokens
        self.input_token_maxx = max(self.input_token_maxx, input_tokens)

    @staticmethod
    def _format_output(content, json_mode):
        """Return ``content`` as-is, or normalised to a JSON string in ``json_mode``."""
        if not json_mode:
            return content
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            # Repair, but still serialise: json_mode always yields a str
            # (the repair path used to leak a dict).
            parsed = repair_json_simple(content)
        return json.dumps(parsed, ensure_ascii=False)



if __name__ == "__main__":
    # Manual smoke test: send one greeting to each backend and print the
    # outcome. Any failure (config, network, ...) is reported, not raised.
    smoke_tests = (
        (GPT4o, "GPT4o response:", "GPT4o test failed:"),
        (Deepseek, "Deepseek response:", "Deepseek test failed:"),
    )
    for factory, ok_label, fail_label in smoke_tests:
        try:
            model = factory()
            response = model([{"role": "user", "content": "hello!"}], one_line=False)
            print(ok_label, response)
        except Exception as e:
            print(fail_label, e)
