import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from itertools import islice

import numpy as np
import torch
from torch.utils.data import TensorDataset
from torch_geometric.data import Data

from make_graph import extract_graph_features, EmbedValue

# data_root = "./raw_data/"
# csv_files = ["data0000.csv"]
# output_root = "./output/"
# data_list = []
# data_list2 = []

# "meta-llama/Llama-2-7b-hf": {"source": "huggingface"},
# "bigscience/bloom-560m": {"source": "huggingface"},
# "google/gemma-2b": {"source": "huggingface"},
# "google/gemma-2-2b": {"source": "huggingface"},
# "microsoft/Phi-3-mini-128k-instruct": {"source": "huggingface"},
# "Qwen/Qwen2-0.5B": {"source": "huggingface"},

# GPU description as it appears in the first field of raw lines read by
# merge_text -> short GPU label written to the merged rows.
GPUs = dict(
    [
        ("GPU 0: NVIDIA L4", "L4"),
        ("GPU 0: Tesla T4", "T4"),
        ("GPU 0: NVIDIA A100-SXM4-40GB", "A100"),
    ]
)
# torch dtype string (third field of raw lines) -> bit width.
bits = dict([("torch.float16", 16), ("torch.float32", 32)])


def process_csv_files(
    data_root="./raw_data/", csv_files=None, output_root="./output/"
):
    """Build latency- and energy-labelled graph datasets from merged CSV rows.

    Each row is read with the merged column layout
    ``GPU,LLM,bitwidth,activation,hidden,inter,layer,head,vob,batch,plen,tlen,time,energy``
    (the layout written by ``merge_text`` / ``process_bigfile``).  Rows whose
    GPU label is ``T4`` are skipped; blank lines are ignored.

    Args:
        data_root: directory that contains the CSV files.
        csv_files: file names inside ``data_root``; defaults to
            ``["data0000.csv"]``.
        output_root: directory that receives ``latency_data.pt``,
            ``energy_data.pt``, ``global_feature.pt`` and ``length.txt``.
    """
    if csv_files is None:
        csv_files = ["data0000.csv"]

    # Short GPU label -> device name passed to extract_graph_features.
    # Unknown labels are passed through unchanged.
    gpu_names = {"A100": "nvidia_A100_40G", "T4": "nvidia_T4", "L4": "nvidia_L4"}
    width_by_bits = {32: "FP32", 16: "FP16", 8: "INT8"}

    data_list = []  # graphs labelled with latency
    data_list2 = []  # same graphs labelled with energy
    global_list = []  # per-row global feature tensors
    for csv_name in csv_files:
        with open(os.path.join(data_root, csv_name)) as f:
            for line in f:
                line = line.rstrip()
                if not line:
                    continue
                items = line.split(",")
                gpu = items[0]
                if gpu == "T4":
                    # T4 rows are deliberately excluded (process_x_file and
                    # process_xx_file drop them too).
                    continue
                gpu = gpu_names.get(gpu, gpu)

                # Column 1 is the model name, column 2 the bit width.  The
                # previous code read the model name from column 2 (the bit
                # width) and parsed column 1 as an integer GPU count.
                llm = items[1]
                bit = int(items[2])
                act = items[3]
                hidden = int(items[4])
                inter = int(items[5])
                layer = int(items[6])
                head = int(items[7])
                vob = int(items[8])
                batch = int(items[9])
                prompt_length = int(items[10])
                token_length = int(items[11])
                latency = float(items[12])
                energy = float(items[13])

                # Any bit width other than 32/16/8 is treated as INT4.
                width = width_by_bits.get(bit, "INT4")
                inference_config = {
                    "stage": "decode" if token_length > 1 else "prefill",
                    "batch_size": batch,
                    "seq_length": prompt_length,
                    "gen_length": token_length,
                    "w_quant": width,
                    "a_quant": width,
                    "kv_quant": width,
                    "use_flashattention": True,
                    "activation": act,
                    "hidden_size": hidden,
                    "inter_size": inter,
                    "layer_num": layer,
                    "head_num": head,
                    "vob_size": vob,
                }

                nodes, edges, global_f = extract_graph_features(
                    llm, gpu, inference_config
                )
                # Non-zero entries of `edges` become graph edges.
                edge_index = torch.from_numpy(np.array(np.where(edges > 0))).type(
                    torch.long
                )
                x = torch.from_numpy(np.array(nodes, dtype=np.float32)).type(
                    torch.float
                )
                y = torch.FloatTensor([latency])

                data_list.append(Data(x=x, edge_index=edge_index, y=y))
                y2 = torch.FloatTensor([energy])
                data_list2.append(Data(x=x, edge_index=edge_index, y=y2))
                global_feature = torch.from_numpy(global_f).type(torch.float)

                print(f"x: {x.size()}")
                print(f"y: {y.size()}")
                print(f"e: {edge_index.size()}")
                print(f"g: {global_feature.size()}")

                global_list.append(global_feature)

    if not (len(data_list) == len(data_list2) == len(global_list)):
        print("length error")
    torch.save(data_list, os.path.join(output_root, "latency_data.pt"))
    torch.save(data_list2, os.path.join(output_root, "energy_data.pt"))
    torch.save(global_list, os.path.join(output_root, "global_feature.pt"))
    with open(
        os.path.join(output_root, "length.txt"), mode="w", encoding="utf-8"
    ) as ref:
        ref.write(str(len(data_list)))


# Column layout of the merged rows written by merge_text (and read by
# process_bigfile / process_csv_files):
# GPU,LLM,bitwidth,activation,hidden,inter,layer,head,vob,batch,plen,tlen,time,energy


def merge_text(input_file_list, output_file):
    """Merge raw benchmark logs into one comma-separated file.

    Only lines that start with ``"GPU"`` are kept.  Every field is stripped
    of surrounding whitespace.  The first field (a GPU description such as
    ``"GPU 0: NVIDIA L4"``) is replaced by its short label from ``GPUs``.
    The third field (a torch dtype string) is replaced by its bit width
    from ``bits``.  Each output row ends with a trailing comma, which keeps
    the historical format that downstream readers index into.

    Args:
        input_file_list: paths of the raw log files, processed in order.
        output_file: path of the merged file (overwritten).

    Raises:
        KeyError: if a GPU description or dtype string is not in the tables.
    """
    content_list = []
    for input_file in input_file_list:
        with open(input_file, mode="r") as f:
            for line in f:
                line = line.rstrip()
                if not line.startswith("GPU"):
                    continue
                # The previous `for item in items: item = item.strip()` loop
                # discarded the stripped value; build the stripped list instead.
                items = [item.strip() for item in line.split(",")]
                items[0] = str(GPUs[items[0]])
                items[2] = str(bits[items[2]])
                content_list.append(",".join(items) + ",\n")

    with open(output_file, mode="w", encoding="utf-8") as ref:
        ref.writelines(content_list)


# Short GPU label (column 0) -> device name passed to extract_graph_features.
_REAL_GPU_NAMES = {
    "A100": "nvidia_A100_40G",
    "T4": "nvidia_T4",
    "L4": "nvidia_L4",
}
# Short model name (column 2) -> Hugging Face model id.
_REAL_LLM_NAMES = {
    "bloom": "bigscience/bloom-560m",
    "llama": "meta-llama/Llama-2-7b-hf",
    "gemma": "google/gemma-2b",
    "gemma2": "google/gemma-2-2b",
    "phi3": "microsoft/Phi-3-mini-128k-instruct",
    "qwen2": "Qwen/Qwen2-0.5B",
    "mixtral": "mistralai/Mixtral-8x7B-v0.1",
}
# Bit width -> quantization label; any other width is treated as INT4.
_WIDTH_BY_BITS = {32: "FP32", 16: "FP16", 8: "INT8"}


def single_task(line, pad_rows=31):
    """Turn one CSV row into graph features, global features and energy targets.

    Columns read (by index):
      0 GPU label, 2 LLM short name, 3 bit width, 4 activation, 5 hidden,
      6 inter, 7 layer, 8 head, 9 vocab size, 10 batch, 11 prompt length,
      12 generated token length, 13 prefill energy, 14 token energy,
      15 total energy.  Column 1 is not read.

    Args:
        line: one text line of the CSV.
        pad_rows: the node-feature matrix is zero-padded with extra rows until
            it has at least this many, so samples can be stacked together.

    Returns:
        ``(node_features, edge_index, global_data, prefill_energy,
        token_energy, total_energy)``.  ``node_features`` is a float32 array
        and ``edge_index`` an int32 array built from ``np.where(edges > 0)``.
        ``global_data`` is the flattened graph-level features followed by the
        embedded row configuration.  The energies are plain floats.
    """
    items = line.rstrip().split(",")

    gpu_name = items[0]
    llm_name = items[2]
    # Unknown labels map to None, exactly as the previous if/elif chains did.
    real_gpu = _REAL_GPU_NAMES.get(gpu_name)
    real_llm = _REAL_LLM_NAMES.get(llm_name)
    gpu_embed = EmbedValue.embed_GPU(gpu_name)
    llm_embed = EmbedValue.embed_llm(llm_name)

    bit = int(items[3])
    act = items[4]
    hidden = int(items[5])
    inter = int(items[6])
    layer = int(items[7])
    head = int(items[8])
    vob = int(items[9])
    batch = int(items[10])
    prompt_length = int(items[11])
    token_length = int(items[12])
    # Targets are returned raw; they were previously embedded and then
    # thrown away and re-parsed.
    prefill_energy = float(items[13])
    token_energy = float(items[14])
    total_energy = float(items[15])

    width = _WIDTH_BY_BITS.get(bit, "INT4")
    inference_config = {
        "stage": "decode" if token_length > 1 else "prefill",
        "batch_size": batch,
        "seq_length": prompt_length,
        "gen_length": token_length,
        "w_quant": width,
        "a_quant": width,
        "kv_quant": width,
        "use_flashattention": True,
        "activation": act,
        "hidden_size": hidden,
        "inter_size": inter,
        "layer_num": layer,
        "head_num": head,
        "vob_size": vob,
    }

    nodes, edges, global_f = extract_graph_features(
        real_llm, real_gpu, inference_config
    )

    global_data = np.concatenate(
        [
            np.array(global_f, dtype=np.float32).ravel(),
            gpu_embed,
            llm_embed,
            EmbedValue.embed_int(bit),
            EmbedValue.embed_act(act),
            # NOTE(review): hidden uses embed_float while the other sizes use
            # embed_int — confirm this is intentional.
            EmbedValue.embed_float(hidden),
            EmbedValue.embed_int(inter),
            EmbedValue.embed_int(layer),
            EmbedValue.embed_int(head),
            EmbedValue.embed_int(vob),
            EmbedValue.embed_int(batch),
            EmbedValue.embed_int(prompt_length),
            EmbedValue.embed_int(token_length),
        ]
    )

    edge_index = np.array(np.where(edges > 0), dtype=np.int32)
    node_features = np.array(nodes, dtype=np.float32)

    # Pad up to `pad_rows` rows.  The previous code always added 15 rows,
    # which only produced 31 rows for 16-node graphs and left other sizes
    # ragged (breaking np.array stacking in process_train_data).
    missing_rows = pad_rows - node_features.shape[0]
    if missing_rows > 0:
        node_features = np.pad(
            node_features,
            [(0, missing_rows), (0, 0)],
            mode="constant",
            constant_values=0,
        )

    return (
        node_features,
        edge_index,
        global_data,
        prefill_energy,
        token_energy,
        total_energy,
    )


def process_bigfile(input_file, prefill_output_file, decode_output_file):
    """Split merged rows into prefill rows and prefill-subtracted decode rows.

    Rows whose column 11 (generated token length) equals 1 are prefill
    measurements.  They are echoed to stdout and copied verbatim to
    ``prefill_output_file``.

    Every other row is a decode measurement.  It is paired with each prefill
    row that has identical columns 0-10.  For each pairing, a row is written
    to ``decode_output_file`` containing:
      * columns 0-11 of the decode row;
      * decode-minus-prefill latency (column 12);
      * decode-minus-prefill energy (column 13);
    each followed by a comma.  Decode rows with no matching prefill row are
    dropped.  Blank lines are ignored.
    """
    prefill_list = []
    # Prefill rows indexed by their first 11 columns.  Each decode row is then
    # matched with one dict lookup instead of a scan over all prefill rows.
    # Match order is preserved.
    prefill_by_key = {}
    decode_list = []

    with open(input_file, mode="r") as f:
        for line in f:
            if not line.strip():
                continue
            items = line.split(",")
            if int(items[11]) == 1:
                print(line)
                prefill_list.append(line)
                prefill_by_key.setdefault(tuple(items[:11]), []).append(items[:14])
            else:
                decode_list.append(line)

    with open(prefill_output_file, mode="w", encoding="utf-8") as ref:
        ref.writelines(prefill_list)

    gdecode_list = []
    for line in decode_list:
        items = line.split(",")
        for prefill_items in prefill_by_key.get(tuple(items[:11]), ()):
            latency = float(items[12]) - float(prefill_items[12])
            energy = float(items[13]) - float(prefill_items[13])
            gdecode_list.append(",".join(items[:12]) + f",{latency},{energy},\n")

    with open(decode_output_file, mode="w", encoding="utf-8") as ref:
        ref.writelines(gdecode_list)


def process_bigfile_org(input_file, output_file, *, max_rows=50000, max_workers=8):
    """Featurize rows of ``input_file`` with ``single_task`` and save tensors.

    Up to ``max_rows`` non-blank lines are processed on a thread pool
    (``None`` means no limit).  The previous loop kept one row more than
    50000.  Results are collected in input order.  The following files are
    written with ``output_file`` as prefix:
    ``.global.pt``, ``.prefill_energy.pt``, ``.token_energy.pt``,
    ``.total_energy.pt`` and ``.node.pt`` (float32), and ``.edge.pt`` (int32).
    Each stacked array's shape is printed before saving.

    Args:
        input_file: path of the CSV to read.
        output_file: path prefix for the saved tensor files.
        max_rows: maximum number of non-blank lines to process.
        max_workers: thread-pool size.
    """
    with open(input_file, mode="r") as f:
        # islice streams only the needed lines instead of readlines().
        lines = islice((line for line in f if line.strip()), max_rows)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(single_task, line) for line in lines]

    data_list = []
    edge_list = []
    node_list = []
    prefill_energy_list = []
    token_energy_list = []
    total_energy_list = []
    # Consume futures in submission order (not as_completed) so the saved
    # tensors follow the input row order and are reproducible across runs.
    for future in futures:
        (
            node_features,
            edge_index,
            global_data,
            prefill_energy,
            token_energy,
            total_energy,
        ) = future.result()
        data_list.append(global_data)
        edge_list.append(edge_index)
        node_list.append(node_features)
        prefill_energy_list.append(prefill_energy)
        token_energy_list.append(token_energy)
        total_energy_list.append(total_energy)

    outputs = (
        (data_list, np.float32, torch.float32, ".global.pt"),
        (prefill_energy_list, np.float32, torch.float32, ".prefill_energy.pt"),
        (token_energy_list, np.float32, torch.float32, ".token_energy.pt"),
        (total_energy_list, np.float32, torch.float32, ".total_energy.pt"),
        (edge_list, np.int32, torch.int32, ".edge.pt"),
        (node_list, np.float32, torch.float32, ".node.pt"),
    )
    for values, np_dtype, torch_dtype, suffix in outputs:
        # np.array stacking requires every sample to have the same shape.
        stacked = np.array(values, dtype=np_dtype)
        print(stacked.shape)
        torch.save(torch.from_numpy(stacked).type(torch_dtype), output_file + suffix)

def process_x_file(input_file="./raw_data/result_all.csv", output_file="new_all"):
    """Convert Llama-2-70B / Mixtral-8x7B rows of a raw CSV to the merged layout.

    Rows are dropped when any of these hold:
      * the GPU is ``T4``;
      * the model is not ``llama2_70b`` / ``mixtral8x7b``;
      * the line is blank.

    For kept rows:
      * model names are shortened to ``llama`` / ``mixtral``;
      * precision tags ``f32``/``f16``/``i8`` become 32/16/8, and any other
        tag is written unchanged;
      * the input column order ``hidden,layer,inter`` is rewritten as
        ``hidden,inter,layer``;
      * prompt and token lengths are truncated to integers.

    Output columns:
    ``gpu,gpu_num,llm,bit,act,hidden,inter,layer,head,vob,batch,plen,tlen,latency,energy``.
    """
    llm_names = {"llama2_70b": "llama", "mixtral8x7b": "mixtral"}
    bit_widths = {"f32": 32, "f16": 16, "i8": 8}

    new_list = []
    with open(input_file) as f:
        for line in f:
            line = line.rstrip()
            if not line:
                continue
            items = line.split(",")
            gpu = items[0]
            if gpu == "T4":
                continue
            # Filter on the model before any int() parsing, so unrelated or
            # header rows are skipped instead of raising.
            llm = llm_names.get(items[2])
            if llm is None:
                continue

            gpu_num = int(items[1])
            bit = bit_widths.get(items[3], items[3])
            act = items[4]
            hidden = int(items[5])
            layer = int(items[6])
            inter = int(items[7])
            head = int(items[8])
            vob = int(items[9])
            batch = int(items[10])
            prompt_length = int(float(items[11]))
            token_length = int(float(items[12]))
            latency = float(items[13])
            energy = float(items[14])

            new_list.append(
                f"{gpu},{gpu_num},{llm},{bit},{act},{hidden},{inter},{layer},"
                f"{head},{vob},{batch},{prompt_length},{token_length},"
                f"{latency},{energy}\n"
            )

    with open(output_file, mode="w", encoding="utf-8") as ref:
        ref.writelines(new_list)


def process_xx_file(input_file="./raw_data/result_all.csv", output_file="new_all"):
    """Copy ``input_file`` to ``output_file`` without its T4 rows.

    A row is dropped when its first comma-separated field is exactly ``T4``.
    Every kept line has trailing whitespace removed and is re-terminated with
    a single newline.
    """
    with open(input_file) as src:
        stripped_rows = (raw.rstrip() for raw in src)
        kept = [row + "\n" for row in stripped_rows if row.split(",", 1)[0] != "T4"]

    with open(output_file, mode="w", encoding="utf-8") as dst:
        dst.writelines(kept)
  


def process_train_data(input_file, output_file, *, max_rows=50000, max_workers=16):
    """Featurize rows of ``input_file`` with ``single_task`` and save tensors.

    Up to ``max_rows`` non-blank lines are processed on a thread pool
    (``None`` means no limit).  The previous loop kept one row more than
    50000.  Results are collected in input order.  The following files are
    written with ``output_file`` as prefix:
    ``.global.pt``, ``.prefill_energy.pt``, ``.token_energy.pt``,
    ``.total_energy.pt`` and ``.node.pt`` (float32), and ``.edge.pt`` (int32).
    Each stacked array's shape is printed before saving.

    Args:
        input_file: path of the CSV to read.
        output_file: path prefix for the saved tensor files.
        max_rows: maximum number of non-blank lines to process.
        max_workers: thread-pool size.
    """
    with open(input_file, mode="r") as f:
        # islice streams only the needed lines instead of readlines().
        lines = islice((line for line in f if line.strip()), max_rows)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(single_task, line) for line in lines]

    data_list = []
    edge_list = []
    node_list = []
    prefill_energy_list = []
    token_energy_list = []
    total_energy_list = []
    # Consume futures in submission order (not as_completed) so the saved
    # tensors follow the input row order and are reproducible across runs.
    for future in futures:
        (
            node_features,
            edge_index,
            global_data,
            prefill_energy,
            token_energy,
            total_energy,
        ) = future.result()
        data_list.append(global_data)
        edge_list.append(edge_index)
        node_list.append(node_features)
        prefill_energy_list.append(prefill_energy)
        token_energy_list.append(token_energy)
        total_energy_list.append(total_energy)

    outputs = (
        (data_list, np.float32, torch.float32, ".global.pt"),
        (prefill_energy_list, np.float32, torch.float32, ".prefill_energy.pt"),
        (token_energy_list, np.float32, torch.float32, ".token_energy.pt"),
        (total_energy_list, np.float32, torch.float32, ".total_energy.pt"),
        (edge_list, np.int32, torch.int32, ".edge.pt"),
        (node_list, np.float32, torch.float32, ".node.pt"),
    )
    for values, np_dtype, torch_dtype, suffix in outputs:
        # np.array stacking requires every sample to have the same shape.
        stacked = np.array(values, dtype=np_dtype)
        print(stacked.shape)
        torch.save(torch.from_numpy(stacked).type(torch_dtype), output_file + suffix)

    
    
def main():
    """Featurize every per-model CSV under ./raw_data/backup/.

    For each dataset name, ``<name>.csv`` is read by ``process_train_data``.
    The resulting tensor files are written with the prefix ``<name>`` in the
    same directory.
    """
    # Earlier pipeline stages, enabled by hand when regenerating inputs:
    # process_csv_files(data_root="./raw_data/", csv_files=["prefill.lei", "decode.lei"], output_root="./output/")
    # merge_text([every *.txt file in out_dir], out_dir + "big_text.lei")
    # process_bigfile_org(out_dir + "kaba0000.csv", out_dir + "first")
    # process_xx_file(out_dir + "kaba0000.csv", out_dir + "first.csv")
    # process_x_file(out_dir + "result_all.csv", "new_all.csv")
    backup_dir = "./raw_data/backup/"
    dataset_names = (
        "bloom",
        "gemma",
        "gemma2",
        "llama",
        "mixtral",
        "qwen2",
        "old",
        "new",
    )
    for name in dataset_names:
        process_train_data(f"{backup_dir}{name}.csv", backup_dir + name)


if __name__ == "__main__":
    main()
