import torch
import numpy as np
import argparse
import os
from tqdm import tqdm

from utils.rf_dataset import RFDataset


def mix_val(soi, interference, target, sinr_lo, sinr_hi, signal_length, sync_soi_by, use_rand_phase):
    """Export every item of an ``RFDataset`` into ``target`` as ``.npy`` files.

    An ``RFDataset`` is built from the given arguments (always with
    ``load_to_ram=False``), the ``target`` directory is created if it does
    not exist, and item ``i`` of the dataset is written to
    ``<target>/sig_<i>.npy``.

    Each file holds a dict with three entries:

    * ``"sample_mix"`` -- ``elem["mixture"]`` viewed as a complex tensor and
      converted to a NumPy array,
    * ``"sample_soi"`` -- ``elem["target"]`` converted the same way,
    * ``"offset"`` -- ``elem["offset"]`` stored unchanged.

    Args:
        soi: Forwarded to ``RFDataset`` as its first positional argument.
        interference: Forwarded to ``RFDataset`` (second positional argument).
        target: Output directory for the generated files.
        sinr_lo: Forwarded to ``RFDataset``; presumably the lower SINR
            bound -- confirm against ``RFDataset``.
        sinr_hi: Forwarded to ``RFDataset``; presumably the upper SINR
            bound -- confirm against ``RFDataset``.
        signal_length: Forwarded to ``RFDataset``.
        sync_soi_by: Forwarded to ``RFDataset``.
        use_rand_phase: Forwarded to ``RFDataset``.
    """
    mixtures = RFDataset(
        soi,
        interference,
        sinr_lo,
        sinr_hi,
        signal_length,
        sync_soi_by,
        use_rand_phase,
        load_to_ram=False,
    )
    os.makedirs(target, exist_ok=True)

    total = len(mixtures)
    for idx in tqdm(range(total)):
        item = mixtures[idx]
        # view_as_complex reinterprets the tensor as complex before the
        # conversion to NumPy.
        mix_np = torch.view_as_complex(item["mixture"]).numpy()
        soi_np = torch.view_as_complex(item["target"]).numpy()
        record = {
            "sample_mix": mix_np,
            "sample_soi": soi_np,
            "offset": item["offset"],
        }
        # NOTE: np.save is handed a dict, so the file contains a pickled
        # object array; readers need np.load(..., allow_pickle=True).
        out_path = os.path.join(target, f"sig_{idx}.npy")
        np.save(out_path, record)


def str_to_bool(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including ``"False"`` -- becomes ``True``. This
    converter interprets the text instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``/``"false"``, ``"yes"``/``"no"``, ``"1"``/``"0"``
            (case-insensitive, surrounding whitespace ignored). The empty
            string maps to ``False``, as it did under ``bool("")``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in {"true", "t", "yes", "y", "1"}:
        return True
    if text in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_arg_parser():
    """Create the command-line parser used when this file runs as a script.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--soi``, ``--interference``
        (one or more values), ``--target``, ``--synchronized`` and
        ``--signal_length``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--soi", type=str)
    parser.add_argument("--interference", nargs="+", type=str)
    parser.add_argument("--target", type=str)
    # nargs="?" + const=True lets a bare ``--synchronized`` mean True while
    # still accepting the explicit ``--synchronized True/False`` form.
    parser.add_argument("--synchronized",
                        type=str_to_bool,
                        nargs="?",
                        const=True,
                        default=False,
                        help="if true, mix with sync_soi_by=16, otherwise 1")
    parser.add_argument("--signal_length", type=int)

    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    mix_val(args.soi,
            args.interference,
            args.target,
            sinr_lo=-33.0,
            sinr_hi=3.0,
            signal_length=args.signal_length,
            sync_soi_by=16 if args.synchronized else 1,
            use_rand_phase=True)
