import configs.base_configs as base_configs
from configs.base_configs import ModelType
import ml_collections
from quant.fsq import fsq_level_book
import numpy as np


def get_config():
    """Build the experiment config for the quantized-output transformer.

    Starts from ``base_configs.get_config()`` and overrides the following:
    the model type, the train/validation/test datasets, the model, the
    optimizer, the LR scheduler and the trainer settings. Training mixes a
    QPSK signal of interest with cs5g1 interference. Every dataset and the
    tokenizer use 2560 samples per example.

    Returns:
        The config object obtained from ``base_configs.get_config()``,
        updated in place.
    """
    # Single source for the per-example length used by the datasets and the
    # tokenizer below (the literal "2560" in path names is left untouched).
    signal_length = 2560

    cfg = base_configs.get_config()
    cfg.model_type = ModelType.QUANTIZED_LLM

    # --- Data -------------------------------------------------------------
    train_kwargs = {
        "soi": "$DATASETS/qpsk/train",
        "interference": "$DATASETS/cs5g1/interference",
        # SINR bounds handed to RFDataset.
        "sinr_lo": -33.0,
        "sinr_hi": 3.0,
        "signal_length": signal_length,
        "load_to_ram": False,
    }
    cfg.dataset_config = ["RFDataset", train_kwargs]

    cfg.val_dataset = [
        "DeterministicDataset",
        {
            "mixtures": "$DATASETS/cs5g1/mix_val_2560",
            "signal_length": signal_length,
            "load_to_ram": False,
        },
    ]

    # Both test sets share the same expansion settings; only label/path differ.
    test_sets = (
        ("testset1", "$DATASETS/cs5g1/mix_test1"),
        ("val_like", "$DATASETS/cs5g1/mix_test_val"),
    )
    cfg.test_datasets = [
        {
            "label": label,
            "dataset_path": path,
            "expansion": "multidiff",
            "multidiff_step": 1280,
        }
        for label, path in test_sets
    ]
    cfg.test_every_epochs = 10

    # --- Model ------------------------------------------------------------
    tokenizer_kwargs = dict(
        channels=[128, 256, 256],
        fsq_bits=6,
        num_transformer_blocks=4,
        patch_channels=8,
        resnet_count=3,
        use_fsq=True,
        signal_length=signal_length,
    )
    transformer_kwargs = dict(
        n_encoder_layers=14,
        n_decoder_layers=14,
        embed_dim=768,
        n_head=12,
        bias=True,
        dropout=0.0,
        # NOTE(review): 160 equals signal_length / window_size (2560 / 16);
        # presumably the token sequence length — confirm before changing either.
        block_size=160,
        encoder_causality=(None, None),
        decoder_causality=(None, 0),
        cross_causality=(None, None),
        max_seq_len=160,
        quantized_io=True,
        input_vocab_size=72,
        cond_dim=96,
    )
    cfg.model_config = [
        "QuantOutputTransformer",
        ml_collections.ConfigDict(
            dict(
                tokenizer_path="$CKPTS/token025.ckpt",
                tokenizer_config=tokenizer_kwargs,
                transformer_config=transformer_kwargs,
                llm_style=True,
                tokenize_input=False,
                tokenizer_type="autoencoder",
            )
        ),
    ]

    # --- Optimization -----------------------------------------------------
    cfg.optimizer_config = [
        "AdamW",
        ml_collections.ConfigDict({"lr": 1e-4, "weight_decay": 0.01}),
    ]
    # Empty ConfigDict: no scheduler arguments are overridden here.
    cfg.lr_scheduler_config = ["ReduceLROnPlateau", ml_collections.ConfigDict()]

    # --- Trainer ----------------------------------------------------------
    trainer = cfg.trainer_config
    trainer.model_dir = "$CKPTS/cs5g1_qllm_2560_new_fp16"
    trainer.batch_size = 130
    trainer.fp16 = True
    trainer.distributed = True
    trainer.world_size = 2
    trainer.save_every = 100_000

    cfg.window_size = 16
    cfg.context_size = (16, 16)
    cfg.soi_type = "old"

    return cfg
