import torch
import numpy as np
import argparse
import os
import json
from tqdm import tqdm

from utils.rf_dataset import mix, NpyStorage, window_start, choose_window


def mix_testset(soi, interference, target, sinrs, signal_length, sync_soi_by, use_rand_phase, ber_sync):
    """Build a fixed test set by mixing SOI recordings with interference at each SINR.

    Recordings are paired index-by-index (``soi_store[i]`` with
    ``interference_store[i]``). Only the first ``min(len(soi), len(interference))``
    pairs are used.

    For every SINR in ``sinrs`` the directory ``<target>/sinr<k>`` is created,
    where ``k`` is the SINR's position in ``sinrs``. It receives one
    ``sig_<i>.npy`` file per pair. Each file stores a dict with keys:

    * ``"sample_soi"``: the cropped SOI window;
    * ``"sample_mix"``: the mixture;
    * ``"offset"``: the crop start used for the SOI.

    Because ``np.save`` is given a dict, each file holds a pickled 0-d object
    array. Load it with ``np.load(path, allow_pickle=True).item()``.

    The generation parameters are written to ``<target>/meta.json``.

    Args:
        soi: Source handed to ``NpyStorage`` for the signal-of-interest recordings.
        interference: Source handed to ``NpyStorage`` for the interference recordings.
        target: Output directory; created if missing.
        sinrs: Iterable of SINR values. It is dumped to ``meta.json``, so it must
            be JSON-serialisable (e.g. a plain list of floats, not an ndarray).
        signal_length: Window length cut from each SOI / interference recording.
        sync_soi_by: Forwarded to ``window_start`` to choose the SOI crop offset.
        use_rand_phase: Forwarded to ``mix``.
        ber_sync: Only recorded in ``meta.json``; not used for mixing here.

    Raises:
        ValueError: If either storage is empty, so there is nothing to mix.
    """
    soi_store = NpyStorage(soi, load_to_ram=False)
    interference_store = NpyStorage(interference, load_to_ram=False)
    print("Soi count:", len(soi_store))
    print("Interference count:", len(interference_store))

    # Surplus entries in the longer store are ignored.
    dataset_len = min(len(soi_store), len(interference_store))
    print("Dataset length:", dataset_len)
    if dataset_len == 0:
        # Fail loudly instead of crashing on soi_store[0] or silently writing an
        # empty test set.
        raise ValueError(
            f"Nothing to mix: soi store has {len(soi_store)} entries, "
            f"interference store has {len(interference_store)} entries"
        )
    print("Element shape:", soi_store[0].shape)

    os.makedirs(target, exist_ok=True)
    meta = {"sinr": sinrs, "signal_length": signal_length,
            "sync_soi_by": sync_soi_by, "use_random_phase": use_rand_phase,
            "ber_sync": ber_sync}
    with open(os.path.join(target, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=4)
        f.write("\n")

    for sinr_id, sinr in enumerate(sinrs):
        print("Processing sinr:", sinr)
        path = os.path.join(target, f"sinr{sinr_id}")
        os.makedirs(path, exist_ok=True)

        for i in tqdm(range(dataset_len)):
            # Use distinct names so the `soi` / `interference` path arguments
            # are not shadowed by per-example tensors.
            soi_sig = torch.from_numpy(soi_store[i])
            interference_sig = torch.from_numpy(interference_store[i])
            offset = window_start(soi_sig.numel(), signal_length, sync_soi_by)
            # NOTE(review): choose_window and mix(use_rand_phase=True) presumably
            # draw random numbers, and no seed is set here. Reruns may then
            # produce a different test set. TODO confirm in utils.rf_dataset.
            interference_sig = choose_window(interference_sig, signal_length)
            soi_sig = soi_sig[offset:offset + signal_length]
            # NOTE(review): the same SINR is passed twice. Presumably mix takes a
            # (min, max) SINR range, and this pins it to one value. Confirm
            # against mix's signature.
            mixture = mix(soi_sig, interference_sig, sinr, sinr, use_rand_phase)

            dct = {
                "sample_soi": soi_sig.numpy(),
                "sample_mix": mixture.numpy(),
                "offset": offset,
            }
            np.save(os.path.join(path, f"sig_{i}.npy"), dct)


def str_to_bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``yes``/``no``/``1``/``0``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string, including ``"False"``. A dedicated parser is
    therefore required.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1", "on"):
        return True
    if lowered in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Mix SOI and interference recordings into a fixed test set.")

    parser.add_argument("--soi", type=str, required=True,
                        help="SOI recordings, passed to NpyStorage.")
    parser.add_argument("--interference", type=str, required=True,
                        help="Interference recordings, passed to NpyStorage.")
    parser.add_argument("--target", type=str, required=True,
                        help="Output directory for the mixed test set.")
    # Accepts a bare `--synchronized`, `--synchronized true` and `--synchronized false`.
    # The previous `type=bool` turned `--synchronized False` into True.
    parser.add_argument("--synchronized", type=str_to_bool, nargs="?",
                        const=True, default=False,
                        help="Pass sync_soi_by=16 instead of 1 to window_start.")
    parser.add_argument("--ber_sync", type=int, default=16,
                        help="Value recorded as 'ber_sync' in meta.json.")

    args = parser.parse_args()
    mix_testset(args.soi,
                args.interference,
                args.target,
                # 11 evenly spaced SINRs from -30 to 0 (step 3); tolist() keeps it JSON-serialisable.
                sinrs=np.linspace(-30.0, 0.0, 11).tolist(),
                signal_length=40960,
                # NOTE(review): presumably aligns the SOI crop offset to a multiple
                # of 16 samples; confirm in window_start.
                sync_soi_by=16 if args.synchronized else 1,
                use_rand_phase=True,
                ber_sync=args.ber_sync)
