import os
import random
import re
import string
import time

import openai
import torch
from groq import Groq
from huggingface_hub import login
from mistralai import Mistral
from transformers import AutoModel

# Put your Hugging Face token here.
# NOTE(review): this login() call runs at import time; with the empty
# placeholder token it presumably fails or prompts — confirm before importing.
login("")
from tenacity import retry, wait_chain, wait_fixed
openai.api_key = "" # placeholder: set your own OpenAI API key here

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Hugging Face token placeholder; not referenced by any code visible in this file.
hf_token=""


# Disabled setup for the local "lamma" backend: load the Llama-3.2-1B-Instruct
# base model, apply the PEFT adapter with PeftModel, and switch to eval mode.
# Because this code is wrapped in a bare string literal it never executes, so
# the module-level `tokenizer` and `modelL` used by completion_with_llama are
# not defined. Turn the string back into code before selecting model "lamma".
"""
base_model = "meta-llama/Llama-3.2-1B-Instruct"
adapter_model = "adlbh/Llama-3.2-1B-Instruct_ambigqa_grpo3"

tokenizer = AutoTokenizer.from_pretrained(base_model)
modelL = AutoModelForCausalLM.from_pretrained(base_model, device_map="auto")

modelL = PeftModel.from_pretrained(modelL, adapter_model)

modelL.eval()
"""



def completion_with_llama(messages):
    """Generate one sampled chat reply with the local model ``modelL``.

    Uses the module-level ``tokenizer`` and ``modelL`` globals.
    NOTE(review): the code that creates them sits inside a bare string literal
    near the top of the file, so it never runs. This function raises NameError
    until that setup is re-enabled.

    Args:
        messages: chat history passed to ``tokenizer.apply_chat_template``
            (presumably a list of ``{"role": ..., "content": ...}`` dicts).

    Returns:
        ``{"choices": [{"message": {"role": "assistant", "content": str},
        "finish_reason": "stop"}]}``. This is the same shape the other backends
        in this module return. ``finish_reason`` is always ``"stop"``, even
        when generation was cut at ``max_new_tokens``.
    """
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    device = modelL.device

    # apply_chat_template may return either a bare tensor of token ids or a
    # dict-like encoding with "input_ids" (and possibly "attention_mask").
    if isinstance(encoded, torch.Tensor):
        input_ids = encoded.to(device)
        attention_mask = torch.ones_like(input_ids)
    else:
        input_ids = encoded["input_ids"].to(device)
        mask = encoded.get("attention_mask")
        # Build the all-ones fallback lazily, only when no mask was returned.
        attention_mask = (
            mask.to(device) if mask is not None else torch.ones_like(input_ids)
        )

    outputs = modelL.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=200,
        temperature=0.5,
        top_p=0.95,
        top_k=50,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # The generated sequence starts with the prompt; decode only the new tokens.
    prompt_length = input_ids.shape[1]
    response = tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,
    )

    return {
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": response,
                },
                "finish_reason": "stop",
            }
        ]
    }


# Groq API keys; completion_with_groq picks one at random for each call.
# Empty in this copy, so random.choice() on it raises IndexError until keys
# are added.
groq_api_keys = [
]


# Shared tenacity retry decorator for the remote API calls in this module.
# wait_chain picks the wait by attempt number and keeps reusing its last entry
# once the list is exhausted. The pause after each failed attempt is therefore:
#   - 1 s after attempts 1-3,
#   - 2 s after attempts 4-5,
#   - 3 s after every later attempt.
# No `stop=` is passed, so attempts never run out. No `retry=` filter is passed
# either, so under tenacity's defaults every exception triggers another attempt.
# NOTE(review): consider adding e.g. `stop=stop_after_attempt(...)` so that
# permanent errors (bad API key, invalid request) surface instead of looping.
retry_strategy = retry(
    wait=wait_chain(*([wait_fixed(1)] * 3 + [wait_fixed(2)] * 2 + [wait_fixed(3)]))
)







def completion_with_backoff(model, messages, temperature, max_tokens, n):
    """Route a chat-completion request to the backend that serves ``model``.

    Supported identifiers:
      * ``"lamma"``: local Hugging Face model. Only ``messages`` is forwarded;
        temperature, max_tokens and n are ignored.
      * ``"o4-mini-2025-04-16"``: OpenAI, with temperature pinned to 1 and the
        limit sent as ``max_completion_tokens``.
      * ``"gpt-4o-mini-2024-07-18"``: OpenAI, with ``max_tokens``.
      * ``"llama-4-scout-17b-16e-instruct"`` and ``"qwen3-32b"``: Groq, with the
        model name prefixed by its vendor namespace.
      * ``"magistral-small-2506"``: Mistral.

    Returns:
        Whatever the selected backend helper returns. Each exposes a
        ``"choices"`` list.

    Raises:
        KeyError: if ``model`` is not one of the identifiers above.
    """
    if model == "lamma":
        return completion_with_llama(messages=messages)

    if model == "o4-mini-2025-04-16":
        # Presumably this reasoning model only accepts the default temperature.
        return completion_with_openai(
            model=model,
            messages=messages,
            temperature=1,
            max_completion_tokens=max_tokens,
            n=n,
        )
    if model == "gpt-4o-mini-2024-07-18":
        return completion_with_openai(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            n=n,
        )

    shared = dict(
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        n=n,
    )
    groq_namespaces = {
        "llama-4-scout-17b-16e-instruct": "meta-llama/",
        "qwen3-32b": "qwen/",
    }
    namespace = groq_namespaces.get(model)
    if namespace is not None:
        return completion_with_groq(model=namespace + model, **shared)

    if model == "magistral-small-2506":
        return completion_with_mistralC(model=model, **shared)

    raise KeyError("model not found")
    

@retry_strategy
def completion_with_openai(**kwargs):
    """Create an OpenAI chat completion, retried according to ``retry_strategy``.

    Args:
        **kwargs: forwarded unchanged to ``openai.ChatCompletion.create``.

    Returns:
        The response object from ``openai.ChatCompletion.create``.

    NOTE(review): ``openai.ChatCompletion`` belongs to the pre-1.0 ``openai``
    SDK interface; confirm the installed version still provides it.
    """
    chat_api = openai.ChatCompletion
    return chat_api.create(**kwargs)

def get_completion_mistral(**kwargs):
    """Send one chat request to Mistral; the API call is retried by ``retry_strategy``.

    The API key is read from the ``MISTRAL_API_KEY`` environment variable
    rather than hard-coded in source. A key that was ever committed to source
    should be treated as leaked and revoked. The key check runs before the
    retried call, so a missing key fails immediately instead of being retried.

    Args:
        **kwargs: forwarded unchanged to ``client.chat.complete``.

    Returns:
        The response object returned by ``client.chat.complete``.

    Raises:
        RuntimeError: if ``MISTRAL_API_KEY`` is unset or empty.
    """
    api_key = os.environ.get("MISTRAL_API_KEY")
    if not api_key:
        raise RuntimeError("MISTRAL_API_KEY environment variable is not set")
    client = Mistral(api_key=api_key)
    return _mistral_chat_complete(client, **kwargs)


@retry_strategy
def _mistral_chat_complete(client, **kwargs):
    """Call ``client.chat.complete(**kwargs)``; failures are retried by ``retry_strategy``."""
    return client.chat.complete(**kwargs)

def completion_with_mistralC(n, **kwargs):
    """Request ``n`` Mistral chat completions, re-sampling unusable outputs.

    Each completion is re-requested at most 5 extra times while either:
      * it was cut off by the token limit (``finish_reason == "length"``), or
      * after removing ``<think>...</think>`` spans, it mentions
        "Clarifications" more than once.
    The last attempt is kept even if it is still unusable.

    Args:
        n: number of completions to return (0 returns an empty list).
        **kwargs: forwarded unchanged to ``get_completion_mistral``.

    Returns:
        ``{"choices": [...]}`` with one entry per requested completion. Each
        entry has the form ``{"message": {"role", "content"}, "finish_reason"}``,
        where ``content`` has the think spans stripped.
    """

    def visible_text(response):
        # Drop <think>...</think> reasoning spans; only the answer is checked/returned.
        return re.sub(
            r"<[\\/]{0,1}think>.*?</[\\/]{0,1}think>",
            "",
            response.choices[0].message.content,
            flags=re.DOTALL,
        )

    choices = []
    for _ in range(n):
        # Pause 1 s before each request (presumably to respect API rate limits).
        time.sleep(1)
        response = get_completion_mistral(**kwargs)
        content = visible_text(response)

        for _retry in range(5):
            truncated = response.choices[0].finish_reason == "length"
            if not truncated and content.count("Clarifications") <= 1:
                break
            if truncated:
                print("La génération s’est arrêtée car la limite de tokens a été atteinte.")
            else:
                # Count on the stripped text, i.e. the same text the condition checks.
                times = content.count("Clarifications")
                print(f"La chaîne 'Clarifications' apparaît {times} fois.")
            response = get_completion_mistral(**kwargs)
            content = visible_text(response)

        # One choice per requested completion (previously only the last was kept).
        choices.append(
            {
                "message": {
                    "role": response.choices[0].message.role,
                    "content": content,
                },
                "finish_reason": response.choices[0].finish_reason,
            }
        )
    return {"choices": choices}

@retry_strategy
def get_rep_groq(client, **kwargs):
    """Send one chat-completion request through a Groq ``client``, retried by ``retry_strategy``.

    Args:
        client: the Groq client to use.
        **kwargs: forwarded unchanged to ``client.chat.completions.create``.

    Returns:
        The response object returned by ``client.chat.completions.create``.
    """
    completions = client.chat.completions
    return completions.create(**kwargs)

def completion_with_groq(n=1, **kwargs):
    """Request ``n`` Groq chat completions, re-sampling unusable outputs.

    One Groq client is built per call, using a key picked at random from
    ``groq_api_keys``. ``random.choice`` raises IndexError if that list is empty.

    Each completion is re-requested at most 5 extra times while either:
      * it was truncated (``finish_reason == "length"``), or
      * after removing ``<think>...</think>`` spans, it mentions
        "Clarifications" more than once.
    The last attempt is kept regardless.

    Args:
        n: number of completions to return (default 1).
        **kwargs: forwarded unchanged to ``get_rep_groq`` / the Groq client.

    Returns:
        ``{"choices": [...]}`` with one entry per requested completion. Each
        entry has the form ``{"message": {"role", "content"}, "finish_reason"}``,
        where ``content`` has the think spans stripped.
    """

    def visible_text(response):
        # Drop <think>...</think> reasoning spans; only the answer is checked/returned.
        return re.sub(
            r"<[\\/]{0,1}think>.*?</[\\/]{0,1}think>",
            "",
            response.choices[0].message.content,
            flags=re.DOTALL,
        )

    choices = []
    api = random.choice(groq_api_keys)
    client = Groq(api_key=api)
    for _ in range(n):
        # Pause 1 s before each request (presumably to respect API rate limits).
        time.sleep(1)
        response = get_rep_groq(client, **kwargs)
        content = visible_text(response)

        for _retry in range(5):
            truncated = response.choices[0].finish_reason == "length"
            if not truncated and content.count("Clarifications") <= 1:
                break
            if truncated:
                print("La génération s’est arrêtée car la limite de tokens a été atteinte.")
            else:
                # Count on the stripped text, i.e. the same text the condition checks.
                times = content.count("Clarifications")
                print(f"La chaîne 'Clarifications' apparaît {times} fois.")
            response = get_rep_groq(client, **kwargs)
            content = visible_text(response)

        # One choice per requested completion (previously only the last was kept).
        choices.append(
            {
                "message": {
                    "role": response.choices[0].message.role,
                    "content": content,
                },
                "finish_reason": response.choices[0].finish_reason,
            }
        )
    return {"choices": choices}
















