import enum
from abc import ABC , abstractmethod
from typing import Dict, Union, Optional
import os, sys
from dataclasses import dataclass, field
import copy
sys.path.append("../")

import models_utils.llm.watsonx as watsonx
import models_utils.llm.rits as rits
import models_utils.llm.openai as openai
import models_utils.llm.azure_llm as azure
import models_utils.llm.gcp as gcp
import models_utils.llm.rits_restricted as rits_rest

class ModelTypes(enum.Enum):
    """Backend families that a ``MODEL_MAP`` entry can be routed to.

    Each member's value is the same string as its name.
    """

    rits = "rits"
    watsonx = "watsonx"
    openai = "openai"
    azure = "azure"
    gcp = "gcp"
    # Restricted RITS endpoint (client module: models_utils.llm.rits_restricted).
    rits_rest = "rits_rest"

@dataclass
class ModelConfig:
    """Per-model settings passed to the LLM client wrappers.

    ``to_dict`` returns an independent deep copy of the fields, so callers
    may mutate the result without touching this instance.
    """

    # Key looked up in MODEL_MAP to choose the backend.
    name: str
    # Backend-specific model id; the GCP and watsonx wrappers build their
    # client from this field instead of ``name``.
    identifier: str
    # Forwarded to rits_rest.APICall via to_dict(); units are not shown in
    # this file (presumably seconds — TODO confirm).
    timeout: int = 120
    # Each instance gets its own fresh dict (no shared mutable default).
    prompt_params: Optional[Dict[str, Union[str, int]]] = field(default_factory=dict)

    def to_dict(self):
        """Return a deep copy of this config's attributes as a plain dict."""
        snapshot = vars(self)
        return copy.deepcopy(snapshot)

# Registry of supported models, keyed by ModelConfig.name.
#   "type":        the ModelTypes backend LLMConfiguration dispatches on.
#   "num_workers": not read in this module; presumably the concurrency level
#                  callers use for this model — TODO confirm with consumers.
MODEL_MAP: Dict[str, Dict[str, Union[str, int, ModelTypes]]] = {
    "deepseek-r1": {
        "type": ModelTypes.rits_rest,
        "num_workers": 5,
    },
    "deepseek-v3-h200": {
        "type": ModelTypes.rits,
        "num_workers": 5,
    },
    "llama-3-3-70b-instruct": {
        "type": ModelTypes.rits,
        "num_workers": 5,
    },
    "mistral-large": {
        "type": ModelTypes.watsonx,
        "num_workers": 2,
    },
    "mistral-medium-2505": {
        "type": ModelTypes.watsonx,
        "num_workers": 3,
    },
    "granite-3-3-8b-instruct": {
        "type": ModelTypes.watsonx,
        "num_workers": 3,
    },
    "mistral-small-3-1-24b-instruct-2503": {
        "type": ModelTypes.watsonx,
        "num_workers": 3,
    },
    "llama-4-maverick-17b-128e-instruct-fp8": {
        "type": ModelTypes.watsonx,
        "num_workers": 3,
    },
    "qwen2-5-72b-instruct": {
        "type": ModelTypes.rits,
        "num_workers": 5,
    },
    "mistral-small-3-1-24b-2503": {
        "type": ModelTypes.rits,
        "num_workers": 5,
    },
    "microsoft-phi-4": {
        "type": ModelTypes.rits,
        "num_workers": 5,
    },
    "llama-3-1-8b-instruct": {
        "type": ModelTypes.rits,
        "num_workers": 5,
    },
    "llama-3-1-405b-instruct-fp8": {
        "type": ModelTypes.rits,
        "num_workers": 5,
    },
    # Azure-hosted models.
    "o1": {
        "type": ModelTypes.azure,
        "num_workers": 5,
    },
    "o3-mini": {
        "type": ModelTypes.azure,
        "num_workers": 5,
    },
    "gpt-4.1-mini": {
        "type": ModelTypes.azure,
        "num_workers": 5,
    },
    "gpt-4.1": {
        "type": ModelTypes.azure,
        "num_workers": 5,
    },
    # GCP-hosted models.
    "gemini-2.0-flash": {
        "type": ModelTypes.gcp,
        "num_workers": 5,
    },
    "gemini-1.5-pro": {
        "type": ModelTypes.gcp,
        "num_workers": 5,
    },
    "claude-3-5-haiku": {
        "type": ModelTypes.gcp,
        "num_workers": 5,
    },
    "claude-3-7-sonnet": {
        "type": ModelTypes.gcp,
        "num_workers": 5,
    },
}

class LLMType(ABC):
    """Abstract base class for the model client wrappers in this module.

    Subclasses pass their ModelConfig to this initializer and must
    implement ``get_response``.
    """

    def __init__(self, llm_config: ModelConfig) -> None:
        # Stored for subclasses and callers to inspect.
        self.llm_config = llm_config

    @abstractmethod
    def get_response(self, prompt):
        """Send ``prompt`` to the backing model and return its reply."""

class LLMConfiguration:
    """Factory that builds the concrete ``LLMType`` client for a model.

    Calling ``LLMConfiguration(cfg)`` does not produce an ``LLMConfiguration``
    object. ``__new__`` looks ``cfg.name`` up in ``MODEL_MAP`` and returns an
    instance of the matching ``LLMType`` subclass. That result is not an
    instance of this class, so Python never runs ``LLMConfiguration.__init__``
    on it.
    """

    def __new__(cls, llm_config: ModelConfig) -> LLMType:
        """Return the ``LLMType`` wrapper registered for ``llm_config.name``.

        Args:
            llm_config: configuration whose ``name`` must be a ``MODEL_MAP`` key.

        Returns:
            An instance of the client class matching the entry's ``"type"``.

        Raises:
            NotImplementedError: the model name is not in ``MODEL_MAP``, or its
                backend type has no client class.
        """
        data = MODEL_MAP.get(llm_config.name)
        if not data:
            raise NotImplementedError(f"Model not in the MAP: {llm_config.name!r}")

        # Built per call, not at class-definition time, because the client
        # classes are defined further down the module.
        # ModelTypes.openai is included so LLMOpenAI is reachable.
        backends = {
            ModelTypes.rits: LLMRits,
            ModelTypes.rits_rest: LLMRestRits,
            ModelTypes.watsonx: LLMWatsonX,
            ModelTypes.azure: LLMAzure,
            ModelTypes.gcp: LLMGCP,
            ModelTypes.openai: LLMOpenAI,
        }
        model_type = data["type"]
        client_cls = backends.get(model_type)
        if client_cls is None:
            raise NotImplementedError(f"Model Type not implemented: {model_type!r}")
        return client_cls(llm_config)

    def get_response(self, prompt):
        # Unreachable in practice: __new__ never returns an LLMConfiguration
        # instance. Kept only so the class attribute still exists.
        pass

class LLMAzure(LLMType):
    """Client wrapper around ``azure.APICall``, addressed by the config's ``name``."""

    def __init__(self, llm_config: ModelConfig) -> None:
        model_name = llm_config.name
        self.client = azure.APICall(model_name)
        super().__init__(llm_config)

    def get_response(self, prompt):
        """Return the Azure client's ``complete_response`` for ``prompt``."""
        client = self.client
        return client.complete_response(prompt)
    
class LLMGCP(LLMType):
    """Client wrapper around ``gcp.APICall``, addressed by the config's ``identifier``."""

    def __init__(self, llm_config: ModelConfig) -> None:
        # The GCP client takes the backend identifier, not the registry name.
        model_id = llm_config.identifier
        self.client = gcp.APICall(model_id)
        super().__init__(llm_config)

    def get_response(self, prompt):
        """Return the GCP client's ``complete_response`` for ``prompt``."""
        client = self.client
        return client.complete_response(prompt)
      
class LLMOpenAI(LLMType):
    """Client wrapper around ``openai.APICall``, addressed by the config's ``name``."""

    def __init__(self, llm_config: ModelConfig) -> None:
        model_name = llm_config.name
        self.client = openai.APICall(model_name)
        super().__init__(llm_config)

    def get_response(self, prompt):
        """Return the OpenAI client's ``complete_response`` for ``prompt``."""
        client = self.client
        return client.complete_response(prompt)
        
class LLMRestRits(LLMType):
    """Client wrapper around ``rits_rest.APICall`` (restricted RITS endpoint).

    Unlike the other wrappers, this one forwards the full config to the
    client and lets ``get_response`` pass extra keyword arguments through.
    """

    def __init__(self, llm_config: ModelConfig) -> None:
        options = llm_config.to_dict()
        # NOTE(review): ``options`` also contains a "name" key, so APICall
        # gets the model name both positionally and as ``name=``. That only
        # works if APICall's first parameter is not itself called ``name``.
        # Confirm against rits_restricted.APICall.
        self.client = rits_rest.APICall(llm_config.name, **options)
        super().__init__(llm_config)

    def get_response(self, prompt, **kwargs):
        """Return ``complete_response(prompt, **kwargs)`` from the restricted RITS client."""
        client = self.client
        return client.complete_response(prompt, **kwargs)
    
class LLMRits(LLMType):
    """Client wrapper around ``rits.APICall``, addressed by the config's ``name``."""

    def __init__(self, llm_config: ModelConfig) -> None:
        model_name = llm_config.name
        self.client = rits.APICall(model_name)
        super().__init__(llm_config)

    def get_response(self, prompt):
        """Return the RITS client's ``complete_response`` for ``prompt``."""
        client = self.client
        return client.complete_response(prompt)

class LLMWatsonX(LLMType):
    """Client wrapper around ``watsonx.APICall``, addressed by the config's ``identifier``."""

    def __init__(self, llm_config: ModelConfig) -> None:
        # watsonx takes the backend identifier via the ``model_name`` keyword.
        self.client = watsonx.APICall(model_name=llm_config.identifier)
        super().__init__(llm_config)

    def get_response(self, prompt):
        """Return the watsonx client's ``complete_response`` for ``prompt``."""
        return self.client.complete_response(prompt)