import os
import sys
import requests
import torch
from multiprocessing import Process

# Put the project root (the parent of the directory containing this file) on
# sys.path so the SLICER / LLM_models packages imported below can be resolved.
try:
    # Script / module execution: use this file's real location so the result
    # does not depend on the current working directory.
    notebook_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # Interactive session or notebook: ``__file__`` is undefined, so fall back
    # to the current working directory.
    notebook_dir = os.getcwd()
slicer_root = os.path.dirname(notebook_dir)
if slicer_root not in sys.path:
    sys.path.append(slicer_root)

from SLICER.SLICER_opti import *
from SLICER.SLICER_config import *
from SLICER.SLICER_utils import *
from LLM_models.datautils import *
from LLM_models.evaluation import *


def load_LLM(model_name, tokenizer_name, device):
    """Load a Llama causal LM in float16 on ``device``, together with its tokenizer.

    Args:
        model_name: Local checkpoint directory (a leading ``~`` is expanded) or
            a Hugging Face Hub id. Anything that does not exist on disk is handed
            to ``from_pretrained`` as-is.
        tokenizer_name: Path or Hub id for ``AutoTokenizer.from_pretrained``
            (a leading ``~`` is expanded as well).
        device: Device the model weights are moved to.

    Returns:
        ``(model, tokenizer)``.
    """
    hf_model_dir = os.path.expanduser(model_name)
    if os.path.exists(hf_model_dir):
        print(f'Loading "{hf_model_dir}"')
    else:
        print(
            f"Model path {hf_model_dir!r} not found locally. "
            "You should download the LLaMa-2-7b-hf model from huggingface."
        )

    print("CUDA is available?", torch.cuda.is_available())
    print("CUDA version:", torch.version.cuda)

    # Load from the same ``~``-expanded path that was checked above so the
    # existence check and the actual load agree. For Hub ids expanduser is a
    # no-op, so those still resolve as before.
    model = LlamaForCausalLM.from_pretrained(
        hf_model_dir, torch_dtype=torch.float16
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(os.path.expanduser(tokenizer_name))

    return model, tokenizer


import io


def serialize_tensors_binary(encoded_pkv_dict: dict) -> bytes:
    """Return the ``torch.save`` encoding of ``encoded_pkv_dict`` as raw bytes.

    The payload is exactly what ``torch.save`` writes into an in-memory
    buffer, so it can be restored with ``torch.load(io.BytesIO(payload))``.
    """
    with io.BytesIO() as sink:
        torch.save(encoded_pkv_dict, sink)
        payload = sink.getvalue()
    return payload


def _sample_next_token(next_token_logits, temperature=0.8):
    """Draw one token id per row from temperature-scaled logits.

    Args:
        next_token_logits: 2-D logits for the last position, one row per sequence.
        temperature: Divisor applied to the logits before the softmax.

    Returns:
        A ``(rows, 1)`` tensor of sampled token ids, as returned by
        ``torch.multinomial(..., num_samples=1)``.
    """
    probs = torch.nn.functional.softmax(next_token_logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def _post_to_server(url, timeout=None, **request_kwargs):
    """POST to ``url`` and return the decoded JSON response body.

    ``request_kwargs`` (e.g. ``json=`` or ``data=``) are forwarded to
    ``requests.post``. ``timeout=None`` is requests' default: wait indefinitely.

    Raises:
        RuntimeError: if the server replies with any status other than 200.
    """
    response = requests.post(url, timeout=timeout, **request_kwargs)
    if response.status_code != 200:
        raise RuntimeError(f"Server error: {response.status_code}, {response.text}")
    return response.json()


def edge_generate_text(
    model,
    tokenizer,
    device,
    prompt: str,
    w_bar: int,
    max_new_tokens: int = 50,
    sc_config=None,
    server_base_url: str = "http://127.0.0.1:8000",
    request_timeout=None,
):
    """Generate a continuation of ``prompt``, splitting work between edge and server.

    * If the tokenized prompt already has at least ``w_bar`` tokens, the decoded
      text is sent to ``/raw_continue_generation`` and the server's text is
      returned.
    * Otherwise tokens are sampled locally (temperature 0.8, using the KV
      cache) until the sequence reaches ``w_bar`` tokens. If EOS is sampled on
      the way, the server's ``/just_call`` endpoint is notified and the fixed
      string ``"ED Good job!"`` is returned instead of the generated text.
    * Once ``w_bar`` is reached, one un-cached forward pass is run with
      ``sc_config`` and ``slicer_ftn=SLICER`` and one more token is sampled.
      The last entry of ``hidden_states``, the token ids and the remaining
      token budget are then serialized with ``torch.save`` and sent to
      ``/continue_generation``.

    The edge-side loop does not enforce ``max_new_tokens``; it always runs
    until ``w_bar`` is reached. The remaining budget sent to the server is
    clamped at 0.

    Args:
        model: Causal LM whose forward accepts the keyword arguments used
            below, including ``sc_config`` and ``slicer_ftn``.
        tokenizer: Tokenizer matching ``model``.
        device: Device the prompt ids are moved to.
        prompt: Input text; a single sequence is assumed (``.item()`` is used).
        w_bar: Sequence length (in tokens) at which generation is handed off.
        max_new_tokens: Overall new-token budget communicated to the server.
        sc_config: Forwarded to the model's hand-off forward pass.
        server_base_url: Scheme/host/port of the generation server; endpoint
            paths are appended to it.
        request_timeout: ``timeout`` forwarded to every ``requests.post``;
            ``None`` waits indefinitely.

    Returns:
        The ``"generated_text"`` field of the server's JSON reply, or
        ``"ED Good job!"`` on the early-EOS path.

    Raises:
        RuntimeError: if any server call returns a non-200 status.
    """
    base_url = server_base_url.rstrip("/")

    print("==========================================================")
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    generated_ids = input_ids.clone()

    # Prompt already reaches the hand-off length: let the server do everything.
    if generated_ids.size(1) >= w_bar:
        partial_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        result = _post_to_server(
            f"{base_url}/raw_continue_generation",
            timeout=request_timeout,
            json={"partial_text": partial_text, "max_new_tokens": max_new_tokens},
        )
        return result["generated_text"]

    past_key_values = None
    gen_count = 0
    while generated_ids.size(1) < w_bar:
        with torch.no_grad():
            # After the first step the KV cache holds the prefix, so only the
            # newest token has to be fed.
            step_input = (
                generated_ids if past_key_values is None else generated_ids[:, -1:]
            )
            outputs = model(
                input_ids=step_input,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True,
                output_hidden_states=True,
            )
        past_key_values = outputs.past_key_values
        next_token_id = _sample_next_token(outputs.logits[:, -1, :])
        gen_count += 1
        generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)

        if next_token_id.item() == tokenizer.eos_token_id:
            result = _post_to_server(
                f"{base_url}/just_call",
                timeout=request_timeout,
                json={"partial_text": "Good job!", "max_new_tokens": max_new_tokens},
            )
            print("[Edge] just_call:", result)
            return "ED Good job!"

    # Full, un-cached pass over the whole sequence. The model receives
    # sc_config and slicer_ftn=SLICER; presumably it applies the SLICER
    # compression here (confirm against the model implementation).
    with torch.no_grad():
        outputs = model(
            input_ids=generated_ids,
            past_key_values=None,
            use_cache=False,
            return_dict=True,
            output_hidden_states=True,
            sc_config=sc_config,
            slicer_ftn=SLICER,
        )
    next_token_id = _sample_next_token(outputs.logits[:, -1, :])
    generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)

    # Never send a negative budget, which can happen when the edge loop ran
    # past max_new_tokens on its way to w_bar.
    # NOTE(review): the token sampled just above is not counted in gen_count;
    # confirm whether the server's budget is meant to include it.
    remaining_tokens = max(max_new_tokens - gen_count, 0)

    data_to_send = {
        "hidden_state": outputs["hidden_states"][-1].cpu(),
        "input_ids": generated_ids.cpu(),
        "max_new_tokens": remaining_tokens,
    }
    result = _post_to_server(
        f"{base_url}/continue_generation",
        timeout=request_timeout,
        data=serialize_tensors_binary(data_to_send),
    )
    return result["generated_text"]


def simulate_edge(
    edge_id,
    samples,
    w_bar,
    max_new_tokens,
    model_name,
    tokenizer_name,
    datasets,
    dataset_name="boolq",
):
    """Run one simulated edge device over the first ``samples`` examples of a dataset.

    Used as a ``multiprocessing.Process`` target. It loads its own model copy,
    feeds each example's ``"query"`` to ``edge_generate_text`` and prints the
    output, then moves the model off the GPU and empties the CUDA cache.
    Per-example errors are printed and skipped.

    Args:
        edge_id: Identifier used in log lines. Ids below 4 run on ``cuda:1``,
            all others on ``cuda:2``.
        samples: Maximum number of examples to process; ``<= 0`` processes none.
        w_bar: Hand-off length forwarded to ``edge_generate_text``.
        max_new_tokens: Token budget forwarded to ``edge_generate_text``.
        model_name: Model path or id forwarded to ``load_LLM``.
        tokenizer_name: Tokenizer path or id forwarded to ``load_LLM``.
        datasets: Mapping from dataset name to an iterable of examples that
            carry a ``"query"`` field.
        dataset_name: Key of ``datasets`` to iterate. Defaults to ``"boolq"``.
    """
    import itertools

    device_in_use = torch.device("cuda:1" if edge_id < 4 else "cuda:2")

    print(f"[Edge {edge_id}] Start on {device_in_use}")

    edge_model, edge_tokenizer = load_LLM(model_name, tokenizer_name, device_in_use)
    edge_model.eval()

    try:
        sc_config = SCConfig(
            split_layer=20, s=0.6, lambd=0.0, delta=0.0, Q=[8, 8, 8], Q_n=[8, 8, 8]
        )

        # islice stops after `samples` items without pulling an extra example
        # and yields nothing at all when samples <= 0.
        for example in itertools.islice(datasets[dataset_name], max(samples, 0)):
            prompt = example["query"]
            print(f"[Edge {edge_id}] Prompt: {prompt}")
            try:
                output_text = edge_generate_text(
                    model=edge_model,
                    tokenizer=edge_tokenizer,
                    device=device_in_use,
                    prompt=prompt,
                    w_bar=w_bar,
                    max_new_tokens=max_new_tokens,
                    sc_config=sc_config,
                )
                print(f"[Edge {edge_id}] Output: {output_text}")
            except Exception as e:
                # Best effort per example: report the failure and move on.
                print(f"[Edge {edge_id}] Error: {e}")
    finally:
        # Release GPU memory even if something outside the per-example
        # handler (e.g. a missing "query" key) aborts the loop.
        edge_model.to("cpu")
        del edge_model
        torch.cuda.empty_cache()

    print(f"[Edge {edge_id}] Terminated")


if __name__ == "__main__":
    # Experiment configuration.
    datasets_list = ["boolq"]
    tokenizer_name = "meta-llama/Llama-2-7b-hf"
    model_name = "/data/llm_models"
    datasets = get_dataset_dataset_processors(datasets_list)

    n_edge_device = 3

    samples = 50
    w_bar = 200
    max_new_tokens = 500

    # One OS process per simulated edge device; edge ids are 1-based.
    # NOTE(review): the default multiprocessing start method is used. With
    # "fork" (the Linux default on older Pythons) CUDA in the children only
    # works if this parent never initialised CUDA itself. Confirm that
    # get_dataset_dataset_processors does not touch the GPU.
    processes = [
        Process(
            target=simulate_edge,
            args=(
                i,
                samples,
                w_bar,
                max_new_tokens,
                model_name,
                tokenizer_name,
                datasets,
            ),
        )
        for i in range(1, n_edge_device + 1)
    ]

    for p in processes:
        p.start()

    for p in processes:
        p.join()
        # Without this, an edge process that crashed (e.g. CUDA OOM while
        # loading the model) would go unnoticed.
        if p.exitcode != 0:
            print(f"[Main] Edge process {p.name} exited with code {p.exitcode}")

    res = requests.get("http://127.0.0.1:8000/server_stats")
    if res.status_code != 200:
        raise RuntimeError(f"Server error: {res.status_code}, {res.text}")
    stats = res.json()
    print("Server overhead stats:", stats["summary"])
