import time
import json
import os
import sys
import random
import itertools
from math import ceil


class DecentralizedPlanner:
    """Choose pipeline parameters for one edge device with a bounded random search.

    Candidate parameter sets are dicts with the keys ``'µ_pref'``,
    ``'µ_dec'``, ``'b_draft'`` and ``'k_cand'``. Each candidate is scored by
    the mock analytical models below against the device's mock memory budget
    (``C_mem_edge``) and compute throughput (``T_comp_edge``).
    """

    def __init__(self, device_id):
        self.device_id = device_id
        # Mock VRAM budget, compared directly with _estimate_memory_req
        # (units presumably bytes -- confirm).
        self.C_mem_edge = 12 * 1e9
        self.T_comp_edge = 11  # Mock TFLOPS

    def _estimate_latency(self, N, L, params, T_comp):
        """Return the mock latency of handling *N* sequences of length *L*.

        Every term has the form (per-step cost) * ceil(N / chunk), i.e. the
        number of steps needed to cover N items in chunks of the given size.
        The 0.8 / 0.5 / 1.0 multipliers on *T_comp* are fixed mock factors.

        Args:
            N: Batch size.
            L: Sequence length.
            params: Candidate dict with ``'µ_pref'``, ``'µ_dec'``,
                ``'b_draft'`` and ``'k_cand'``.
            T_comp: Compute throughput that scales every term down.

        Returns:
            Prefill estimate plus decode estimate (arbitrary mock units).
        """
        mu_pref, mu_dec = params['µ_pref'], params['µ_dec']
        b_draft, k_cand = params['b_draft'], params['k_cand']
        prefill_latency = (L / (T_comp * 0.8)) * ceil(N / mu_pref)
        # First decode term scales with k_cand over ceil(N / b_draft) steps;
        # the second is a unit cost over ceil(N / mu_dec) steps.
        decode_latency = (
            (k_cand / (T_comp * 0.5)) * ceil(N / b_draft)
            + (1 / (T_comp * 1.0)) * ceil(N / mu_dec)
        )
        return prefill_latency + decode_latency

    def _estimate_memory_req(self, N, L, params):
        """Return the mock memory footprint of a candidate plan.

        The estimate is ``N * L * 2048`` (the KV-cache term) plus
        ``params['µ_pref'] * L * 1024`` (the activation term). It is compared
        directly against ``C_mem_edge``.
        """
        kv_cache = N * L * 2048
        activations = params['µ_pref'] * L * 1024
        return kv_cache + activations

    def run_pipeline_planning(self, batch_size, seq_length, max_samples=100, rng=None):
        """Randomly sample parameter sets and return the fastest one that fits in memory.

        ``µ_pref``, ``µ_dec`` and ``b_draft`` range over the divisors of
        *batch_size*; ``k_cand`` ranges over ``[1, 2, 4]``. The full product is
        shuffled and the first *max_samples* combinations are evaluated.
        Candidates whose estimated memory exceeds ``C_mem_edge`` are skipped.
        Among the rest, the first one (in shuffled order) with the strictly
        lowest estimated latency wins.

        Args:
            batch_size: Number of sequences in the batch. Values < 1 produce an
                empty search space.
            seq_length: Sequence length passed to the cost models.
            max_samples: Maximum number of combinations to evaluate (default
                100). ``None`` evaluates every combination. Expected to be
                non-negative.
            rng: Object providing ``shuffle`` (e.g. ``random.Random(seed)``)
                for reproducible searches. Defaults to the global ``random``
                module.

        Returns:
            The best parameter dict. If nothing was evaluated or nothing fits,
            this is the all-ones fallback ``{'µ_pref': 1, 'µ_dec': 1,
            'b_draft': 1, 'k_cand': 1}``, returned without checking it against
            the memory budget.
        """
        if rng is None:
            rng = random

        # The same divisor list serves all three batch-splitting dimensions.
        divisors = [i for i in range(1, batch_size + 1) if batch_size % i == 0]
        possible_k_cand = [1, 2, 4]

        search_space = list(itertools.product(divisors, divisors, divisors, possible_k_cand))
        rng.shuffle(search_space)

        best_params = None
        min_latency = float('inf')
        for mu_pref, mu_dec, b_draft, k_cand in search_space[:max_samples]:
            params = {'µ_pref': mu_pref, 'µ_dec': mu_dec, 'b_draft': b_draft, 'k_cand': k_cand}
            if self._estimate_memory_req(batch_size, seq_length, params) > self.C_mem_edge:
                continue

            latency = self._estimate_latency(batch_size, seq_length, params, self.T_comp_edge)
            if latency < min_latency:
                min_latency = latency
                best_params = params

        if best_params is None:
            return {'µ_pref': 1, 'µ_dec': 1, 'b_draft': 1, 'k_cand': 1}
        return best_params


def safe_read_json(filepath, default_value):
    """Load JSON from *filepath*, falling back to *default_value* when unreadable.

    The file is decoded as UTF-8, the encoding JSON interchange requires,
    rather than the platform's locale default. Three cases yield
    *default_value*:

    - a missing file;
    - a malformed document (e.g. empty or truncated);
    - bytes that are not valid UTF-8.

    Other I/O errors, such as permission problems, still propagate.

    Args:
        filepath: Path of the JSON file to read.
        default_value: Value returned when the file cannot be loaded. It is
            returned as-is (not copied), so callers that mutate the result
            should pass a fresh object.

    Returns:
        The decoded JSON value, or *default_value*.
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (FileNotFoundError, UnicodeDecodeError, json.JSONDecodeError):
        return default_value


def safe_write_json(filepath, data):
    """Write *data* to *filepath* as indented JSON without exposing a partial file.

    The payload is serialized before any file is touched, so a value that
    cannot be serialized leaves the existing file unchanged. The text is
    written to a sibling temporary file, which is then moved over *filepath*
    with ``os.replace``. Processes polling the file therefore see either the
    old or the new document, never a truncated one.

    If creating or moving the temporary file raises ``PermissionError``, the
    function falls back to a direct in-place write. That happens, for example,
    in a read-only directory, or possibly on Windows while another process
    holds *filepath* open. The fallback is not atomic.

    Args:
        filepath: Destination path of the JSON file.
        data: Any JSON-serializable value.

    Raises:
        TypeError, ValueError: If *data* cannot be serialized; the destination
            file is left untouched.
        OSError: If neither the atomic nor the fallback write succeeds.
    """
    text = json.dumps(data, indent=4)
    # Same directory keeps os.replace on one filesystem; the pid suffix keeps
    # separate device processes from sharing a temp file.
    tmp_path = f"{filepath}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            f.write(text)
        os.replace(tmp_path, filepath)
    except PermissionError:
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(text)
    finally:
        # The temp file only survives to this point if the replace did not happen.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _find_assigned_task(task_queue, device_id):
    """Return ``(index, task)`` for the first queue entry assigned to *device_id*.

    Returns ``(-1, None)`` when no entry's ``assign_to`` matches.
    """
    matches = (
        (idx, entry)
        for idx, entry in enumerate(task_queue)
        if entry.get("assign_to") == device_id
    )
    return next(matches, (-1, None))


def main(device_id):
    """Poll the shared task queue and process every task assigned to *device_id*.

    Each pass reads ``task_queue.json``. The first entry whose ``assign_to``
    equals *device_id* is removed and the shortened queue is written back.
    A ``TERMINATE`` task ends the loop. Any other task goes through three
    steps:

    - it is planned with ``DecentralizedPlanner``;
    - it is "processed" by sleeping a random 0.8-2.5 s;
    - it is recorded by appending an entry to ``status_updates.json``.

    With nothing assigned, the loop sleeps 0.5 s before polling again.

    NOTE(review): both files are updated with an unlocked read-modify-write,
    so concurrent writers (other devices, or whatever fills the queue) can
    overwrite each other's changes.
    """
    queue_path = "task_queue.json"
    status_path = "status_updates.json"

    planner = DecentralizedPlanner(device_id)
    print(f"Edge Device '{device_id}' is running and waiting for tasks.")

    while True:
        queue = safe_read_json(queue_path, [])
        position, task = _find_assigned_task(queue, device_id)

        if not task:
            time.sleep(0.5)
            continue

        # Claim the task by persisting the queue without it.
        safe_write_json(queue_path, [entry for idx, entry in enumerate(queue) if idx != position])

        if task.get("task_type") == "TERMINATE":
            print(f"Device '{device_id}' received termination signal. Shutting down.")
            break

        request_id = task['request_id']
        print(f"Device '{device_id}' picked up task {request_id}.")

        # Decentralized dynamic pipeline planning for this batch.
        plan = planner.run_pipeline_planning(task['batch_size'], task['seq_length'])
        print(f"Device '{device_id}' generated plan: {plan}")

        # Stand-in for real work.
        elapsed = random.uniform(0.8, 2.5)
        time.sleep(elapsed)

        # Append a completion record to the shared status log.
        history = safe_read_json(status_path, [])
        history.append({
            "completed_by": device_id,
            "task_id": request_id,
            "time_taken": elapsed,
        })
        safe_write_json(status_path, history)
        print(f"Device '{device_id}' completed task {request_id} in {elapsed:.2f}s.")


if __name__ == "__main__":
    # The first positional argument names this device; extra arguments are ignored.
    if len(sys.argv) >= 2:
        device_id = sys.argv[1]
        main(device_id)
    else:
        print("Usage: python edge_device.py <device_id>")
        print("Example: python edge_device.py edge_server_1")
        sys.exit(1)