import torch
import time
from pynvml import (
    nvmlInit,
    nvmlDeviceGetHandleByIndex,
    nvmlDeviceGetMemoryInfo,
)


def print_gpu_usage():
    """Print used, free and total memory for each GPU, as reported by NVML.

    The number of GPUs is taken from ``torch.cuda.device_count()`` and each
    index is then looked up through NVML. Sizes are converted from bytes to
    GiB (division by ``1024 ** 3``) before printing.
    """
    # Imported locally so the NVML session opened below can be released
    # without touching the module-level import list.
    from pynvml import nvmlShutdown

    nvmlInit()
    try:
        # NOTE(review): torch's device count presumably honours
        # CUDA_VISIBLE_DEVICES while NVML indexes physical devices, so the
        # two numberings may disagree when that variable is set -- confirm.
        n_gpus = torch.cuda.device_count()

        print('========== GPU Utilization ==========')
        for gpu_id in range(n_gpus):
            h = nvmlDeviceGetHandleByIndex(gpu_id)
            info = nvmlDeviceGetMemoryInfo(h)
            print(f'GPU {gpu_id}')
            # The values are divided by 1024 ** 3, so the unit is GiB, not bytes.
            print(f'- Used:       {info.used / 1024 ** 3:>8.2f} GiB ({info.used / info.total * 100:.1f}%)')
            print(f'- Available:  {info.free / 1024 ** 3:>8.2f} GiB ({info.free / info.total * 100:.1f}%)')
            print(f'- Total:      {info.total / 1024 ** 3:>8.2f} GiB')
        print('=====================================')
    finally:
        # Pair every nvmlInit() with a shutdown so repeated calls do not
        # accumulate NVML references.
        nvmlShutdown()


def make_chat_call(client, model, message, max_tokens):
    """Send one user message to a chat model and retry until a usable reply arrives.

    The request style is chosen by substring match on ``model``: names
    containing ``"gpt"`` are sent through ``client.chat.completions.create``
    and names containing ``"claude"`` through ``client.messages.create``.
    ``"gpt"`` is checked first, so it wins if both substrings are present.

    A reply is rejected and the request is re-sent immediately when:
      * (gpt) any choice has ``finish_reason == "content_filter"``;
      * (claude) ``stop_reason`` is anything other than ``"end_turn"``.
    Any exception raised while making or inspecting the request is printed
    and the request is retried after a fixed 5-second pause. The number of
    retries is unbounded.

    Args:
        client: Object exposing ``chat.completions.create`` (gpt models) or
            ``messages.create`` (claude models).
        model: Model name; must contain ``"gpt"`` or ``"claude"``.
        message: Text sent as the single user turn.
        max_tokens: Forwarded unchanged as the ``max_tokens`` request argument.

    Returns:
        The response object returned by the client's ``create`` call.

    Raises:
        ValueError: If ``model`` contains neither ``"gpt"`` nor ``"claude"``.
            (Previously this case looped forever without making any request.)
    """
    if "gpt" in model:
        use_gpt = True
        full_prompt = [{'role': 'user', 'content': message}]
    elif "claude" in model:
        use_gpt = False
        full_prompt = [{'role': 'user', 'content': [{"type": "text", "text": message}]}]
    else:
        # Fail fast: with no matching branch the retry loop below would spin
        # forever without ever issuing a request.
        raise ValueError(
            f"Unsupported model {model!r}: expected a name containing 'gpt' or 'claude'."
        )

    wait_time = 5
    while True:
        try:
            if use_gpt:
                response = client.chat.completions.create(
                    model=model,
                    messages=full_prompt,
                    max_tokens=max_tokens,
                    n=1
                )
                if any(choice.finish_reason == "content_filter" for choice in response.choices):
                    print("Content filter triggered, trying again")
                    continue
            else:
                response = client.messages.create(
                    model=model,
                    messages=full_prompt,
                    max_tokens=max_tokens,
                    temperature=1
                )
                if response.stop_reason != "end_turn":
                    print("End turn not triggered, trying again")
                    continue

            return response

        except Exception as e:
            # Best-effort retry: report the failure, back off, and try again.
            print(f'Caught exception {e}.')
            time.sleep(wait_time)