import os
import torch
import argparse
import json
from ml_collections import ConfigDict
from tqdm import tqdm
import shutil

from utils.rf_dataset import DeterministicDataset
from utils.class_builder import ClassBuilder
from models.transformer import Transformer
from models.transformer_decoder import Transformer as TransformerDecoder
from models.transformer_quant import QuantOutputTransformer
from models.wavenet import Wave
from torch.utils.data import DataLoader
from utils.inference import generate
from rfcutils2.ber import compute_ber


# Registry passed to ClassBuilder in init_model: maps a model-class name (the
# value stored at config.model_config[0], e.g. "QuantOutputTransformer") to the
# class to instantiate.
MODELS_REGISTER = dict(
    Transformer=Transformer,
    TransformerDecoder=TransformerDecoder,
    QuantOutputTransformer=QuantOutputTransformer,
    WaveNet=Wave,
)


def load_test_dataset(path):
    """Open every per-SINR split of the test set stored under ``path``.

    ``path/meta.json`` must contain a ``"sinr"`` sequence; its i-th entry is
    paired with the dataset stored in the directory ``path/sinr{i}`` (opened
    with ``load_to_ram=False``).

    Returns:
        ``(pairs, meta)`` where ``pairs`` is a list of
        ``(sinr, DeterministicDataset)`` tuples in ``meta["sinr"]`` order and
        ``meta`` is the full parsed metadata.
    """
    meta_file = os.path.join(path, "meta.json")
    with open(meta_file) as handle:
        meta = json.load(handle)

    pairs = []
    for split_idx, sinr in enumerate(meta["sinr"]):
        split_dir = os.path.join(path, f"sinr{split_idx}")
        pairs.append((sinr, DeterministicDataset(split_dir, load_to_ram=False)))
    return pairs, meta


def load_meta(path):
    """Return the parsed contents of the ``meta.json`` file inside directory ``path``."""
    meta_file = os.path.join(path, "meta.json")
    with open(meta_file) as handle:
        return json.load(handle)


def init_model(state_dict, config, device):
    """Build the configured model, load checkpoint weights and switch to eval mode.

    Args:
        state_dict: Checkpoint dict whose ``"model"`` entry holds the weights.
            Keys starting with ``"module."`` have that prefix removed
            (presumably written by a DataParallel/DDP-wrapped model -- verify).
        config: Config whose ``model_config`` is handed to ``ClassBuilder``.
        device: Device the model is moved to before the weights are loaded.

    Returns:
        The model, in eval mode.
    """
    builder = ClassBuilder(MODELS_REGISTER)
    model, _ = builder.build(config.model_config)
    model.to(device)

    wrapper_prefix = "module."
    weights = {}
    for name, tensor in state_dict["model"].items():
        if name.startswith(wrapper_prefix):
            name = name[len(wrapper_prefix):]
        weights[name] = tensor

    model.load_state_dict(weights)
    model.eval()
    return model


def eval_model(model, config, dataset_path, expansion="id", multidiff_step=None, batch_size=10, beam_k=1, device="cuda", world_size=1, rank=0, silent=False):
    """Compute per-SINR MSE and BER of ``model`` on the test set at ``dataset_path``.

    Evaluation can be sharded across processes: only SINR splits whose index
    ``i`` satisfies ``i % world_size == rank`` are evaluated. Keep the defaults
    ``world_size=1`` and ``rank=0`` to evaluate every split.

    Args:
        model: Model passed to ``generate``; expected to already be in eval mode.
        config: Training config; ``dataset_config[1]["signal_length"]`` and
            ``soi_type`` are read.
        dataset_path: Directory readable by ``load_test_dataset``.
        expansion, multidiff_step, beam_k: Forwarded to ``generate``.
        batch_size: DataLoader batch size.
        device: Device the mixture/target/offset tensors are moved to.
        world_size, rank: Sharding parameters (see above).
        silent: If True, suppress the per-split prints and the tqdm bar.

    Returns:
        ``(mse_list, ber_list)`` with one entry per evaluated split, in split
        order. Each entry is the sample-weighted mean of the per-batch metric.
    """
    datasets, meta = load_test_dataset(dataset_path)
    ber_sync = meta["ber_sync"]
    base_signal_length = config.dataset_config[1]["signal_length"]
    mse_list = []
    ber_list = []

    for sinr_id, (sinr, dataset) in enumerate(datasets):
        if sinr_id % world_size != rank:
            continue
        if not silent:
            print(f"Evaluating sinr {sinr}")
        num_samples = len(dataset)
        dataloader = DataLoader(dataset, batch_size=batch_size)
        mse = 0.0
        ber = 0.0
        batches = dataloader if silent else tqdm(dataloader)
        # Inference only: disable autograd so no computation graph is built or
        # kept alive during (possibly multi-step) generation.
        with torch.no_grad():
            for batch in batches:
                cur_size = batch["mixture"].shape[0]
                mixture = batch["mixture"].to(device)
                target = batch["target"].to(device)
                offsets = batch["offset"].to(device)
                preds = generate(model, mixture, config,
                                 base_signal_length, expansion,
                                 multidiff_step=multidiff_step, beam_k=beam_k)
                # Weight each batch mean by its size so a short final batch
                # contributes proportionally to the split average.
                mse += ((preds - target) ** 2).mean().item() / num_samples * cur_size
                ber += compute_ber(preds, target, offsets, ber_sync, config.soi_type) / num_samples * cur_size

        if not silent:
            print("MSE:", mse)
            print("BER:", ber)
        mse_list.append(mse)
        ber_list.append(ber)

    return mse_list, ber_list


def eval_ckpt(ckpt_path, dataset_path, config_path=None, expansion="id", multidiff_step=None,
              batch_size=10, beam_k=1, base_signal_length=None, device="cuda"):
    """Load a checkpoint and evaluate it on every SINR split of a test set.

    Args:
        ckpt_path: Checkpoint containing ``"model"`` weights and, unless
            ``config_path`` is given, the ``"cfg"`` config.
        dataset_path: Test-set directory passed to ``eval_model``.
        config_path: Optional file whose ``"cfg"`` entry replaces the config
            stored in the checkpoint.
        expansion, multidiff_step, batch_size, beam_k: Forwarded to ``eval_model``.
        base_signal_length: If given, rescale ``max_seq_len`` and ``block_size``
            of a ``QuantOutputTransformer`` by ``base_signal_length / signal_length``
            (truncated to int) and evaluate at that signal length.
        device: Device the model is built on and evaluated with.

    Returns:
        ``(mse_list, ber_list)`` as returned by ``eval_model``.

    Raises:
        ValueError: If ``base_signal_length`` is given for a model class other
            than ``QuantOutputTransformer``.
    """
    # Load tensors onto CPU: load_state_dict copies them into the model on
    # `device` anyway, and this avoids failing on CPU-only hosts or duplicating
    # checkpoint contents in GPU memory.
    # NOTE: weights_only=False unpickles arbitrary objects -- only load trusted files.
    state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    if config_path is None:
        config = ConfigDict(state_dict["cfg"])
    else:
        config = ConfigDict(torch.load(config_path, map_location="cpu", weights_only=False)["cfg"])

    if base_signal_length is not None:
        cur_signal_length = config.dataset_config[1]["signal_length"]
        ratio = base_signal_length / cur_signal_length
        if config.model_config[0] == "QuantOutputTransformer":
            transformer_config = config.model_config[1].transformer_config
            transformer_config["max_seq_len"] = int(transformer_config["max_seq_len"] * ratio)
            transformer_config["block_size"] = int(transformer_config["block_size"] * ratio)
            config.dataset_config[1]["signal_length"] = base_signal_length
        else:
            raise ValueError("Unsupported model class for base_signal_length modification")

    model = init_model(state_dict, config, device)
    return eval_model(model, config, dataset_path,
                      expansion=expansion,
                      multidiff_step=multidiff_step,
                      batch_size=batch_size,
                      beam_k=beam_k,
                      device=device)
            

def save_eval_results(dataset_path, kwargs):
    """Append one evaluation record to ``<dataset_path>/evals.json`` and return its id.

    The file holds a JSON list of records (it is created if missing). The new
    record is ``kwargs`` plus an ``"id"`` equal to its index in that list.

    The updated list is first written to ``evals_copy.json`` and then renamed
    over ``evals.json`` with ``os.replace``. The rename is atomic on POSIX, so
    an interrupted run cannot leave a half-written ``evals.json``. The previous
    copy-then-delete sequence did not have that guarantee.
    """
    evals_path = os.path.join(dataset_path, "evals.json")
    evals_copy_path = os.path.join(dataset_path, "evals_copy.json")

    results = []
    if os.path.isfile(evals_path):
        with open(evals_path) as f:
            results = json.load(f)

    eval_id = len(results)
    results.append({**kwargs, "id": eval_id})

    with open(evals_copy_path, "w") as f:
        json.dump(results, f, indent=4)
        f.write("\n")
    # Same directory => same filesystem, so this is a rename, not a copy.
    os.replace(evals_copy_path, evals_path)
    return eval_id


def main():
    """CLI entry point: evaluate a checkpoint and append the results to ``evals.json``.

    Only non-default options are recorded in the saved entry, together with the
    per-SINR ``mse``/``ber`` lists and the checkpoint's training ``step``.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate a checkpoint on a per-SINR test set and record the results.")

    parser.add_argument("ckpt_path", type=str, help="Checkpoint to evaluate.")
    parser.add_argument("dataset_path", type=str,
                        help="Test-set directory (contains meta.json); results are appended to its evals.json.")
    parser.add_argument("--config_path", type=str, default=None,
                        help="File whose 'cfg' entry replaces the config stored in the checkpoint.")
    parser.add_argument("--expansion", type=str, default="id", help="Forwarded to generate().")
    parser.add_argument("--multidiff_step", type=int, default=None, help="Forwarded to generate().")
    parser.add_argument("--batch_size", type=int, default=10, help="Evaluation batch size.")
    parser.add_argument("--beam_k", type=int, default=1, help="Forwarded to generate().")
    parser.add_argument("--device", type=str, default="cuda", help="Torch device used for evaluation.")
    parser.add_argument("--base_signal_length", type=int, default=None,
                        help="Evaluate at this signal length (QuantOutputTransformer only).")

    args = parser.parse_args()

    mse_list, ber_list = eval_ckpt(
        ckpt_path=args.ckpt_path,
        dataset_path=args.dataset_path,
        config_path=args.config_path,
        expansion=args.expansion,
        multidiff_step=args.multidiff_step,
        batch_size=args.batch_size,
        beam_k=args.beam_k,
        base_signal_length=args.base_signal_length,
        device=args.device
    )

    # Record only options that differ from their defaults.
    exp_args = {
        "ckpt_path": args.ckpt_path
    }
    if args.config_path is not None:
        exp_args["config_path"] = args.config_path
    if args.expansion != "id":
        exp_args["expansion"] = args.expansion
    if args.multidiff_step is not None:
        exp_args["multidiff_step"] = args.multidiff_step
    if args.beam_k != 1:
        exp_args["beam_k"] = args.beam_k
    if args.base_signal_length is not None:
        exp_args["base_signal_length"] = args.base_signal_length

    exp_args["mse"] = mse_list
    exp_args["ber"] = ber_list
    # Only "step" is needed here. Load onto CPU so re-reading the checkpoint
    # neither allocates GPU memory nor fails on a CPU-only host.
    state_dict = torch.load(args.ckpt_path, map_location="cpu", weights_only=False)
    exp_args["step"] = state_dict["step"]
    eval_id = save_eval_results(args.dataset_path, exp_args)
    print("Saved as evaluation", eval_id)

# Run the command-line evaluation only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
