import time
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from datasets import load_dataset


class ExperimentalTestbed:
    """Bookkeeping description of the edge/cloud testbed used by the experiments.

    Holds four edge servers (GTX TITAN X, 12 GB VRAM limit each), one cloud
    server (RTX 4090, 24 GB VRAM limit) and the current network parameters.
    Nothing in this class talks to real hardware or shapes real traffic; the
    values are only recorded and printed.
    """

    def __init__(self):
        """Build the device inventory and print a summary of it."""
        self.edge_servers = {
            f"edge_server_{i}": {"gpu": "NVIDIA GeForce GTX TITAN X", "vram_limit_gb": 12}
            for i in range(1, 5)
        }
        self.cloud_server = {"cloud_server": {"gpu": "NVIDIA RTX 4090", "vram_limit_gb": 24}}
        # One lookup table over every device, edge and cloud alike.
        self.all_devices = {**self.edge_servers, **self.cloud_server}

        self.network_bandwidth_mbps = 50
        # No latency is set until configure_network() is called.
        self.network_latency_ms = None
        print("refresh：")
        for device, config in self.all_devices.items():
            print(f"- {device} set {config['gpu']}")
        print(f"bandwidth: {self.network_bandwidth_mbps} Mbps\n")

    def configure_network(self, bandwidth_mbps, latency_ms):
        """Record new network conditions and announce them.

        Args:
            bandwidth_mbps: link bandwidth in Mbps; must be positive.
            latency_ms: network latency in milliseconds; must be non-negative.

        Raises:
            ValueError: if either value is out of range.
        """
        if bandwidth_mbps <= 0:
            raise ValueError(f"bandwidth_mbps must be positive, got {bandwidth_mbps!r}")
        if latency_ms < 0:
            raise ValueError(f"latency_ms must be non-negative, got {latency_ms!r}")
        print(f"--- reset：bandwidth {bandwidth_mbps} Mbps,  {latency_ms} ms ---")
        # Persist the new conditions so later readers of the testbed see them
        # instead of the construction-time defaults.
        self.network_bandwidth_mbps = bandwidth_mbps
        self.network_latency_ms = latency_ms


class LLMExperiment:
    """Load Llama-2 checkpoints and benchmark generation latency/throughput.

    Models and tokenizers are cached in ``self.models`` / ``self.tokenizers``,
    keyed by their Hugging Face Hub name.
    """

    def __init__(self, testbed):
        """Create an experiment bound to *testbed*.

        Args:
            testbed: the testbed description; stored for reference only, the
                methods of this class do not read it.
        """
        self.testbed = testbed
        self.models = {}
        self.tokenizers = {}

    def load_models(self, model_names=None):
        """Load each model in 8-bit with ``device_map="auto"`` plus its tokenizer.

        Args:
            model_names: Hub identifiers to load; defaults to Llama-2 7B and 13B.
        """
        if model_names is None:
            model_names = ["meta-llama/Llama-2-7b-hf", "meta-llama/Llama-2-13b-hf"]
        for name in model_names:
            print(f"loading: {name} (int8)...")
            # NOTE(review): recent transformers releases may flag `load_in_8bit`
            # as deprecated in favour of `quantization_config=BitsAndBytesConfig(...)`;
            # confirm against the installed version.
            self.models[name] = LlamaForCausalLM.from_pretrained(name, load_in_8bit=True, device_map="auto")
            self.tokenizers[name] = LlamaTokenizer.from_pretrained(name)
            print(f"model {name} .\n")

    def load_datasets(self):
        """Download the four evaluation splits and return them keyed by task name.

        Returns:
            dict mapping task name to a dataset split. The ``"SummEval"`` entry
            is actually the CNN/DailyMail 3.0.0 test split. ``"WikiText-2"``
            keeps only rows with more than 32 whitespace-separated words, so a
            32-word prompt can always be cut from it.
        """
        print("dataset loading...")
        datasets = {
            "HumanEval": load_dataset("openai_humaneval", split="test"),
            "C-Eval": load_dataset("ceval/ceval-exam", name="average", split="val"),
            "SummEval": load_dataset("cnn_dailymail", "3.0.0", split="test"),
            "WikiText-2": load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        }

        datasets["WikiText-2"] = datasets["WikiText-2"].filter(lambda example: len(example['text'].split()) > 32)
        print("dateset ready。\n")
        return datasets

    @staticmethod
    def _build_prompt(sample, task_name):
        """Return the prompt text for one dataset row of *task_name*."""
        if task_name == "SummEval":
            # Character-level truncation keeps the article prompt bounded.
            return sample['article'][:1024]
        if task_name == "HumanEval":
            return sample['prompt']
        if task_name == "WikiText-2":
            # First 32 words of the passage.
            return " ".join(sample['text'].split()[:32])
        # Any other task name is treated as a C-Eval multiple-choice row.
        return (
            f"Question: {sample['question']}\n"
            f"A) {sample['A']}\nB) {sample['B']}\nC) {sample['C']}\nD) {sample['D']}"
        )

    @staticmethod
    def _ensure_pad_token(tokenizer):
        """Make *tokenizer* usable for batched generation.

        The Llama-2 tokenizer defines no pad token by default, and
        ``padding=True`` then raises, so EOS is reused as the pad token.
        Decoder-only models are padded on the left so the last position of
        every row is a real prompt token.
        """
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

    @staticmethod
    def _timed_generate(model, inputs, max_new_tokens, pad_token_id):
        """Run ``model.generate`` on *inputs*; return ``(outputs, elapsed_seconds)``."""
        if torch.cuda.is_available():
            # Make sure no earlier GPU work leaks into the measured interval.
            torch.cuda.synchronize()
        start = time.perf_counter()
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, pad_token_id=pad_token_id)
        if torch.cuda.is_available():
            # CUDA work is asynchronous; wait for it before stopping the clock.
            torch.cuda.synchronize()
        return outputs, time.perf_counter() - start

    def evaluate(self, model_name, dataset, task_name):
        """Benchmark *model_name* on the first row of *dataset* and print the results.

        Two measurements are printed:
          1. single-prompt latency, total and per generated token;
          2. throughput for a batch of 8 copies of the same prompt.

        Args:
            model_name: key into ``self.models`` / ``self.tokenizers``.
            dataset: indexable dataset; only ``dataset[0]`` is used.
            task_name: "SummEval", "HumanEval" or "WikiText-2"; any other value
                is treated as C-Eval.
        """
        model = self.models[model_name]
        tokenizer = self.tokenizers[model_name]
        self._ensure_pad_token(tokenizer)

        print(f"---  {task_name} dateset evaluate {model_name} ---")

        prompt = self._build_prompt(dataset[0], task_name)
        max_new_tokens = 96 if task_name == "WikiText-2" else 100
        # Send inputs to the model's own device instead of hard-coding "cuda".
        device = model.device

        # --- Single-prompt latency ---
        # NOTE(review): the first call on a freshly loaded model also pays
        # one-time CUDA start-up cost; consider a warm-up generate.
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        outputs, elapsed = self._timed_generate(model, inputs, max_new_tokens, tokenizer.pad_token_id)

        # `outputs` contains the prompt followed by the generated tokens.
        generated_tokens = outputs.shape[1] - inputs.input_ids.shape[1]
        latency = elapsed * 1000
        latency_per_token = latency / generated_tokens if generated_tokens > 0 else float('inf')

        print("latency:")
        print(f"  -  {generated_tokens}  {latency:.2f} ms")
        print(f"  -  {latency_per_token:.2f} ms/token")

        # --- Batched throughput ---
        batch_size = 8
        batch = [prompt] * batch_size
        inputs = tokenizer(batch, return_tensors="pt", padding=True).to(device)
        outputs, elapsed = self._timed_generate(model, inputs, max_new_tokens, tokenizer.pad_token_id)

        # Counts prompt and generated positions of every row.
        total_tokens_processed = batch_size * outputs.shape[1]
        throughput = total_tokens_processed / elapsed

        print(f"\n (batchsize={batch_size}):")
        print(f"  -  {elapsed:.2f} work with {total_tokens_processed} ")
        print(f"  - throughput: {throughput:.2f} tokens/s")
        print(f"--- {task_name} down ---\n")


def main():
    """Run the end-to-end experiment.

    Steps: build the testbed, load models, inject DAS into Llama-2-7B, run the
    offloading simulations, then evaluate on the four benchmark datasets under
    a 20 Mbps / 100 ms network setting.

    NOTE(review): `inject_das_into_model` and
    `add_offloading_scheduler_to_experiment` are neither defined nor imported
    in this file; confirm where they are expected to come from.
    """
    testbed = ExperimentalTestbed()

    experiment = LLMExperiment(testbed)

    experiment.load_models()

    print("\n--- DAS LOADING---")
    model_to_test = "meta-llama/Llama-2-7b-hf"

    das_hyperparams = {
        "block_size": 64,
        "theta_B": 0.85,
        "tau": 0.9,
        "theta_V_scale": 1e-4
    }

    llama_model = experiment.models[model_to_test]

    das_model = inject_das_into_model(llama_model, das_hyperparams)
    # Store the DAS-enabled model back so evaluate() benchmarks it rather than
    # the plain checkpoint. A None return is taken to mean "modified in place".
    if das_model is not None:
        experiment.models[model_to_test] = das_model
    print("--- DAS ---\n")

    experiment = add_offloading_scheduler_to_experiment(experiment, kq_value=10)

    experiment.run_offloading_simulation(total_requests=100, kq_value=10)
    experiment.run_offloading_simulation(total_requests=100, kq_value=1)

    datasets = experiment.load_datasets()

    testbed.configure_network(bandwidth_mbps=20, latency_ms=100)

    for task_name in ("HumanEval", "C-Eval", "SummEval", "WikiText-2"):
        experiment.evaluate(model_to_test, datasets[task_name], task_name)


if __name__ == '__main__':
    # Script entry point: run the full benchmark only when executed directly,
    # never as a side effect of importing this module.
    main()