import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import LlamaAttention
import time
import random
import itertools
from math import ceil
from pulp import LpProblem, LpVariable, lpSum, LpMinimize, LpBinary
import collections


class Request:
    """One unit of work submitted to the offloading scheduler.

    The constructor stores its three arguments verbatim as attributes;
    no validation or conversion is performed.
    """

    def __init__(self, request_id, arrival_time, task_type):
        # Assigned left-to-right: request_id, arrival_time, task_type.
        self.request_id, self.arrival_time, self.task_type = (
            request_id,
            arrival_time,
            task_type,
        )


class CentralOffloadingScheduler:
    """Central K-priority FCFS scheduler that hands queued requests to idle servers.

    Requests wait in a single global FIFO queue. Each call to :meth:`schedule`
    assigns at most ``kq`` requests from the head of that queue, pairing them
    in order with the servers the testbed reports as idle and marking each
    chosen server ``'busy'``.
    """

    def __init__(self, testbed, kq=10):
        """Create an empty scheduler bound to *testbed*.

        Args:
            testbed: object exposing ``get_idle_servers()`` (returning a
                sequence of server names) and ``set_server_status(name, status)``.
            kq: maximum number of assignments per ``schedule()`` call. Must be a
                positive integer; a float with an integral value is accepted.

        Raises:
            ValueError: if *kq* is not a positive integer. With ``kq < 1`` the
                scheduler could never assign work, so a caller looping until
                all requests complete would spin forever.
        """
        kq_int = int(kq)
        if kq_int != kq or kq_int < 1:
            raise ValueError(f"kq must be a positive integer, got {kq!r}")
        self.global_queue = collections.deque()
        self.testbed = testbed
        self.kq = kq_int
        print(f"K-Priority FCFS Scheduler initialized with Kq={self.kq}.")

    def add_request(self, request):
        """Append *request* to the tail of the global FIFO queue."""
        self.global_queue.append(request)

    def schedule(self):
        """Run one scheduling round.

        Returns:
            A list of ``(request, server_name)`` tuples in queue order. It is
            empty when no server is idle or no request is queued. Every server
            in the result has been set to ``'busy'`` on the testbed, and every
            request in it has been removed from the queue.
        """
        idle_servers = self.testbed.get_idle_servers()
        # The round is bounded by free servers, waiting work, and the window kq.
        num_assignments = min(len(idle_servers), len(self.global_queue), self.kq)

        assignments = []
        for server_name in idle_servers[:num_assignments]:
            request = self.global_queue.popleft()
            self.testbed.set_server_status(server_name, 'busy')
            assignments.append((request, server_name))
        return assignments


class ExperimentalTestbed:
    """Static description of the simulated hardware plus per-server busy/idle state.

    The testbed models four edge servers (``edge_server_1`` .. ``edge_server_4``)
    and one ``cloud_server``. Each edge server has an explicit link to the cloud,
    and every server starts out ``'idle'``.
    """

    # Bandwidth returned for device pairs without an explicit link entry
    # (including a device paired with itself).
    DEFAULT_BANDWIDTH_BPS = 1e12

    def __init__(self):
        self.edge_servers = {
            f"edge_server_{i}": {"gpu": "NVIDIA GeForce GTX TITAN X", "vram_gb": 12, "tflops": 11} for i in range(1, 5)
        }
        self.cloud_server = {"cloud_server": {"gpu": "NVIDIA RTX 4090", "vram_gb": 24, "tflops": 83}}
        self.all_devices = {**self.edge_servers, **self.cloud_server}

        # Keyed by (edge, cloud) tuples; lookups go through
        # get_interconnect_bandwidth, which accepts either ordering.
        self.network_bandwidth_bps = {
            ('edge_server_1', 'cloud_server'): 50e6, ('edge_server_2', 'cloud_server'): 50e6,
            ('edge_server_3', 'cloud_server'): 50e6, ('edge_server_4', 'cloud_server'): 50e6,
        }
        self.server_status = {name: 'idle' for name in self.all_devices.keys()}
        print("Experimental testbed initialized.")

    def get_interconnect_bandwidth(self, d1, d2):
        """Return the link bandwidth between devices *d1* and *d2*.

        The lookup is order-independent. The previous implementation sorted the
        pair, which put ``'cloud_server'`` first and therefore never matched the
        ``(edge, cloud)`` keys. Pairs with no entry fall back to
        :attr:`DEFAULT_BANDWIDTH_BPS`. Units follow ``network_bandwidth_bps``
        (presumably bits/s — confirm).
        """
        links = self.network_bandwidth_bps
        for pair in ((d1, d2), (d2, d1)):
            if pair in links:
                return links[pair]
        return self.DEFAULT_BANDWIDTH_BPS

    def get_idle_servers(self):
        """Return the names of servers whose status is ``'idle'``, in insertion order."""
        return [name for name, status in self.server_status.items() if status == 'idle']

    def set_server_status(self, server_name, status):
        """Set *server_name*'s status; unknown server names are silently ignored."""
        if server_name in self.server_status:
            self.server_status[server_name] = status

    def reset_all_servers_to_idle(self):
        """Mark every known device as ``'idle'``."""
        self.server_status = {name: 'idle' for name in self.all_devices.keys()}


class LLMExperiment:
    """Drives offloading-scheduler simulations on a testbed.

    Attributes:
        testbed: object whose server states the simulation resets and toggles.
        models: mapping of model name -> model object (currently a placeholder string).
        offloading_scheduler: scheduler used by the latest simulation run, or one
            attached externally; ``None`` until set.
    """

    def __init__(self, testbed):
        self.testbed = testbed
        self.models = {}
        self.offloading_scheduler = None
        print("LLMExperiment initialized.")

    def load_models(self):
        """Register model entries in ``self.models``.

        NOTE: no weights are loaded; only a placeholder string is stored under
        the ``meta-llama/Llama-2-7b-hf`` key.
        """
        print("Loading model structures...")
        self.models["meta-llama/Llama-2-7b-hf"] = "loaded_model_placeholder"
        print("Model structures loaded.")

    def run_offloading_simulation(self, total_requests, kq_value, time_step=0.01):
        """Simulate *total_requests* inference requests through a fresh scheduler.

        Request ``i`` arrives at ``i * 0.1`` seconds. A request is handed to the
        scheduler only once the simulated clock reaches its arrival time, so it
        can never be scheduled before it exists. Each assigned request occupies
        its server for a random duration drawn from ``random.uniform(0.5, 1.2)``.
        The method prints total simulated time, average wait, and throughput.

        Args:
            total_requests: number of requests to simulate.
            kq_value: per-round assignment window passed to the scheduler as ``kq``.
            time_step: simulated seconds advanced per loop iteration.

        Raises:
            ValueError: if *time_step* is not positive, because the clock would
                never advance.
        """
        if time_step <= 0:
            raise ValueError(f"time_step must be positive, got {time_step!r}")

        print(f"\n--- Running Offloading Simulation with {total_requests} requests and Kq={kq_value} ---")

        # Each run builds its own scheduler, replacing any previously attached one.
        self.offloading_scheduler = CentralOffloadingScheduler(self.testbed, kq=kq_value)
        self.testbed.reset_all_servers_to_idle()

        # Requests not yet released to the scheduler, in arrival order.
        pending = collections.deque(
            Request(i, i * 0.1, "inference") for i in range(total_requests)
        )

        tick = 0
        current_time = 0.0
        processing_tasks = []  # (completion_time, server_name) of in-flight tasks
        total_wait_time = 0
        completed_tasks = 0

        while completed_tasks < total_requests:
            # Free servers whose task has finished by the current time.
            still_running = []
            for completion_time, server_name in processing_tasks:
                if current_time >= completion_time:
                    self.testbed.set_server_status(server_name, 'idle')
                    completed_tasks += 1
                else:
                    still_running.append((completion_time, server_name))
            processing_tasks = still_running

            # Release requests that have arrived by now; earlier releases let the
            # scheduler start work "before" arrival and produced negative waits.
            while pending and pending[0].arrival_time <= current_time:
                self.offloading_scheduler.add_request(pending.popleft())

            for request, server_name in self.offloading_scheduler.schedule():
                processing_time = random.uniform(0.5, 1.2)
                processing_tasks.append((current_time + processing_time, server_name))
                total_wait_time += current_time - request.arrival_time

            # Derive time from an integer tick to avoid cumulative float drift.
            tick += 1
            current_time = tick * time_step

        avg_wait_time = total_wait_time / total_requests if total_requests > 0 else 0
        # With zero requests the loop never runs and current_time stays 0.
        throughput = total_requests / current_time if current_time > 0 else 0.0
        print(f"Simulation Finished. Total Time: {current_time:.2f}s")
        print(f"Average Request Wait Time: {avg_wait_time:.4f}s")
        print(f"System Throughput: {throughput:.2f} req/s")


def add_offloading_scheduler_to_experiment(experiment, kq_value):
    """Attach a new K-priority FCFS scheduler to *experiment* and return it.

    The scheduler is built on ``experiment.testbed`` with ``kq=kq_value`` and
    stored as ``experiment.offloading_scheduler``. The same *experiment* object
    is returned so calls can be chained. Note that
    ``LLMExperiment.run_offloading_simulation`` constructs its own scheduler,
    which replaces the one attached here.
    """
    print("\n--- Injecting Optimization Offloading Strategy ---")
    scheduler = CentralOffloadingScheduler(experiment.testbed, kq=kq_value)
    experiment.offloading_scheduler = scheduler
    print("--- K-P-FCFS Scheduler injected successfully ---")
    return experiment


def main():
    """Build the testbed and experiment, then compare simulations for Kq=10 and Kq=1."""
    experiment = LLMExperiment(ExperimentalTestbed())
    experiment.load_models()

    experiment = add_offloading_scheduler_to_experiment(experiment, kq_value=10)

    # Same workload size, two decision-window sizes.
    for kq in (10, 1):
        experiment.run_offloading_simulation(total_requests=100, kq_value=kq)


# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()