import argparse
from http.server import BaseHTTPRequestHandler, HTTPServer
import json
import os
import subprocess
import threading
import time

import numpy as np
import requests


# Most recent nvidia-smi snapshot, one dict per GPU (index, utilization,
# memory_used). The poller rebinds this name wholesale; HTTP handlers read it.
last_gpu_stats = list()


def _parse_gpu_stats_csv(text: str) -> list:
    """Parse ``nvidia-smi --format=csv,noheader,nounits`` output into dicts.

    Each row is expected to hold ``index, utilization.gpu, memory.used``.
    Rows with fewer than three fields, or whose first three fields are not
    integers (for example a ``[N/A]`` placeholder), are skipped individually.
    A single bad row therefore no longer discards the stats of every other GPU.
    """
    stats = []
    for line in text.strip().splitlines():
        fields = [field.strip() for field in line.split(',')]
        if len(fields) < 3:
            continue
        try:
            index, utilization, memory_used = (int(f) for f in fields[:3])
        except ValueError:
            continue
        stats.append({
            'index': index,
            'utilization': utilization,
            'memory_used': memory_used,
        })
    return stats


def get_gpu_stats(timeout: float = 10.0) -> list:
    """Query nvidia-smi and publish the result in the global ``last_gpu_stats``.

    Args:
        timeout: Seconds to wait for nvidia-smi before giving up. Without a
            bound, a hung driver call would block the polling thread forever.

    Returns:
        The freshly parsed list of per-GPU dicts. The same object is assigned
        to ``last_gpu_stats``. It is empty if nvidia-smi could not be run.
    """
    global last_gpu_stats

    cmd = [
        'nvidia-smi',
        '--query-gpu=index,utilization.gpu,memory.used',
        '--format=csv,noheader,nounits'
    ]
    try:
        raw = subprocess.check_output(cmd, timeout=timeout)
        gpu_stats = _parse_gpu_stats_csv(raw.decode('utf-8'))
    except (OSError, subprocess.SubprocessError, UnicodeDecodeError) as e:
        # Catch: binary missing, non-zero exit, timeout, or undecodable output.
        print("Error running nvidia-smi:", e)
        gpu_stats = []
    # Rebind (never mutate) so readers always see a complete snapshot.
    last_gpu_stats = gpu_stats
    return gpu_stats


def _next_sleep_interval(interval: float, random_shifting: float) -> float:
    """Return ``interval`` plus Gaussian jitter, clamped to be non-negative.

    ``np.random.randn()`` is unbounded. Without the clamp, a large negative
    draw would make ``time.sleep`` raise ValueError and kill the poller.
    """
    return max(0.0, interval + np.random.randn() * random_shifting)


def periodic_gpu_stats_update(interval: float, random_shifting: float):
    """
    Periodically updates the GPU stats every 'interval' seconds.

    Runs forever and is meant to be the target of a daemon thread.

    Args:
        interval: Mean delay in seconds between nvidia-smi polls.
        random_shifting: Standard deviation in seconds of the Gaussian
            jitter added to each delay. Jitter keeps several pollers from
            running in lock-step.
    """
    print(f"Starting GPU stats collection every {interval} seconds.")
    while True:
        get_gpu_stats()
        time.sleep(_next_sleep_interval(interval, random_shifting))


class GPUStatsHTTPRequestHandler(BaseHTTPRequestHandler):
    """
    HTTP request handler that returns the last captured GPU stats.

    ``GET /gpu_stats`` (any query string is ignored) responds 200 with the
    module-level ``last_gpu_stats`` encoded as JSON. Every other path
    responds 404 with a short plain-text body.
    """

    def _send_body(self, status: int, content_type: str, body: bytes):
        """Send a complete response with explicit Content-Type and Content-Length."""
        self.send_response(status)
        self.send_header('Content-Type', content_type)
        # An explicit length lets clients know where the body ends without
        # relying on connection close.
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        # Match on the path component only, so e.g. '/gpu_stats?t=123'
        # (cache-busting) still hits the endpoint.
        route = self.path.partition('?')[0]
        if route == '/gpu_stats':
            payload = json.dumps(last_gpu_stats).encode('utf-8')
            self._send_body(200, 'application/json', payload)
        else:
            self._send_body(404, 'text/plain; charset=utf-8', b'Endpoint not found.')


def run_http_server(port: int):
    """
    Runs an HTTP server that serves the last captured GPU stats.

    Blocks forever, listening on all interfaces at ``port``. Each request is
    handled in its own daemon thread. With the plain single-threaded
    HTTPServer, one client that connected but never sent a request would
    block every other client. The listening socket is closed when serving
    stops, for example on KeyboardInterrupt.
    """
    import socketserver

    class _ThreadingHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
        # In-flight request threads must not keep the process alive on exit.
        daemon_threads = True

    server_address = ('', port)
    with _ThreadingHTTPServer(server_address, GPUStatsHTTPRequestHandler) as httpd:
        print(f"HTTP server is running on port {port}.")
        httpd.serve_forever()


def main():
    """Parse CLI args, start the background GPU poller, then serve HTTP forever."""
    parser = argparse.ArgumentParser(
        description='Serve nvidia-smi GPU utilization and memory stats as JSON at /gpu_stats.'
    )
    # Required: without it, port=None would only fail later inside socket bind
    # with an obscure TypeError.
    parser.add_argument('--port', type=int, required=True,
                        help='TCP port to listen on.')
    parser.add_argument('--interval', type=float, default=1.,
                        help='Mean seconds between nvidia-smi polls (default: 1.0).')
    parser.add_argument('--jitter', type=float, default=.1,
                        help='Std-dev in seconds of Gaussian noise added to each poll delay (default: 0.1).')
    args = parser.parse_args()

    # Start the periodic GPU stats update in a separate thread
    stats_thread = threading.Thread(
        target=periodic_gpu_stats_update,
        args=(args.interval, args.jitter),
        daemon=True,
    )
    stats_thread.start()

    # Run the HTTP server in the main thread
    run_http_server(port=args.port)


def run_gpu_stats_server(port: int):
    """
    Launch the GPU stats server script as a background child process.

    The script path ``extras/gpu_stats_server.py`` is resolved relative to
    the current working directory. The child's stdout and stderr are
    discarded, and it inherits this process's environment.

    Args:
        port: Port passed to the child as ``--port``.

    Returns:
        The ``subprocess.Popen`` handle. The caller is responsible for
        terminating it.
    """
    import sys

    # Use the running interpreter: a bare 'python' may be absent from PATH or
    # resolve to a different interpreter/virtualenv than the caller's.
    return subprocess.Popen(
        [sys.executable, os.path.join('extras', 'gpu_stats_server.py'), f'--port={port}'],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )


def fetch_gpu_stats(port: int, host: str = 'localhost', timeout: float = 5.0):
    """
    Fetch the latest GPU stats from a running stats server.

    Args:
        port: Port the server listens on.
        host: Server host name or address.
        timeout: Seconds to wait for connect/read. ``requests`` waits
            indefinitely by default, so a stuck server would hang the caller.

    Returns:
        The decoded JSON body from ``/gpu_stats``, or ``None`` on a network
        error, an HTTP error status, or an undecodable body.
    """
    url = f'http://{host}:{port}/gpu_stats'
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Raises HTTPError for bad responses (4XX, 5XX)
        return response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError: older requests versions raise a plain JSON decode error
        # from .json() that is not a RequestException.
        print(f"Error fetching GPU stats: {e}")
        return None


if __name__ == "__main__":
    main()
