import os
import random

import numpy as np
import pandas as pd
import json
from util import seed_everything
import msgpack
seed_everything(2023)
from tqdm import tqdm


def split_data(path):
    """Shuffle ``<path>/new.json`` and write 80/10/10 train/validation/test text files.

    Each record is read via ``item["text"]`` and ``item["event"]`` and becomes one
    line of the form ``"<text items joined by spaces> <SEP> <event items
    concatenated>"`` in ``<path>/{split}.txt``. (Both fields are presumably lists
    of strings — confirm against the producer of ``new.json``.)

    The shuffle uses the module-level ``random`` state, so the split depends on
    whatever seed was set before the call.

    Args:
        path: Directory containing ``new.json``; the split files are written here.
    """
    # `with` closes the input file instead of leaking the handle.
    with open(path + "/new.json", "r", encoding="utf-8") as src:
        data = json.load(src)
    random.shuffle(data)
    step = [0, 0.8, 0.9, 1.0]
    for i, split in enumerate(["train", "validation", "test"]):
        s, e = int(len(data) * step[i]), int(len(data) * step[i + 1])
        split_data_raw = data[s:e]
        # `with` guarantees the output is flushed and closed even if a record
        # is missing a key and raises mid-write.
        with open(path + f"/{split}.txt", "w", encoding="utf-8") as split_txt:
            for item in split_data_raw:
                text = " ".join(item["text"])
                remi_token = "".join(item["event"])
                split_txt.write(text + " <SEP> " + remi_token + "\n")
        print(split, len(split_data_raw))

def split_data_v3(path, save_path, truncated_length=2560, max_idx=947658):
    """Stream paired msgpack records and write shuffled 95/2.5/2.5 train/valid/test splits.

    ``<path>/RID.bin`` and ``<path>/TOKEN.bin`` are consumed in lockstep: record
    ``i`` of each file is treated as the same sample. Each ``rid_info["pieces"]``
    entry supplies ``token_begin``/``token_end`` offsets into that sample's token
    sequence and a ``values`` payload. (This schema is inferred from the key
    accesses below — confirm against whatever produces these files.)

    For every piece spanning at most ``truncated_length`` tokens, one line of
    space-joined tokens is written to ``<save_path>/{split}.txt`` and
    ``[piece["values"]]`` is appended to that split's command list. Longer pieces
    are skipped.

    Files written to ``save_path``:
        * ``total_index.npy`` - shuffled sample order used to assign splits.
        * ``{train,valid,test}.txt`` - one token line per kept piece.
        * ``{train,valid,test}_command.npy`` - the kept ``values``, in the same
          order as the lines of the matching ``.txt`` file.

    Args:
        path: Directory containing ``RID.bin`` and ``TOKEN.bin``.
        save_path: Output directory; created if it does not exist.
        truncated_length: Maximum ``token_end - token_begin`` for a piece to be kept.
        max_idx: Number of records to read from each input file (default is the
            value previously hard-coded here).

    Raises:
        ValueError: If either input file contains fewer than ``max_idx`` records.
    """
    from contextlib import ExitStack

    os.makedirs(save_path, exist_ok=True)

    splits = ["train", "valid", "test"]
    step = [0, 0.95, 0.975, 1.0]
    flush_every = 300000

    # ExitStack closes both inputs and all three outputs on success or error.
    with ExitStack() as stack:
        rid_file = stack.enter_context(open(path + "/RID.bin", "rb"))
        token_file = stack.enter_context(open(path + "/TOKEN.bin", "rb"))

        # Random sample-to-split assignment; depends on the global `random` seed.
        total_index = list(range(max_idx))
        random.shuffle(total_index)
        np.save(save_path + "/total_index.npy", total_index)

        target_split = {}
        txt_files = {}
        line_counts = {}
        command_pool = {}
        for idx, split in enumerate(splits):
            s, e = int(step[idx] * max_idx), int(step[idx + 1] * max_idx)
            for i in total_index[s:e]:
                target_split[i] = split
            txt_files[split] = stack.enter_context(
                open(save_path + f"/{split}.txt", "w", encoding="utf-8")
            )
            line_counts[split] = 0
            command_pool[split] = []

        rid_unpacker = msgpack.Unpacker(rid_file, use_list=False)
        token_unpacker = msgpack.Unpacker(token_file, use_list=False)

        for i in tqdm(range(max_idx)):
            try:
                rid_info = next(rid_unpacker)
                tokens = next(token_unpacker)
            except StopIteration:
                raise ValueError(
                    f"expected {max_idx} records in RID.bin and TOKEN.bin, "
                    f"but input ran out after {i}"
                ) from None
            split = target_split[i]
            out = txt_files[split]
            for piece in rid_info["pieces"]:
                token_s, token_e = piece["token_begin"], piece["token_end"]
                if token_e - token_s > truncated_length:
                    continue
                out.write(" ".join(tokens[token_s:token_e]) + "\n")
                line_counts[split] += 1
                command_pool[split].append([piece["values"]])
                # Periodically push buffered lines to the OS so a long run
                # leaves partial output on disk.
                if line_counts[split] % flush_every == 0:
                    out.flush()

    for split in splits:
        np.save(save_path + f"/{split}_command.npy", command_pool[split])
        print(split, len(command_pool[split]))
if __name__ == "__main__":
    # Script entry: build the v3.1 splits, keeping only pieces of at most 1600
    # tokens, and write them into a "truncated_1600" subdirectory.
    DATA_DIR = "../../Text2Music_data/v3.1"
    TRUNCATE_AT = 1600
    split_data_v3(
        DATA_DIR,
        DATA_DIR + "/truncated_1600",
        truncated_length=TRUNCATE_AT,
    )



