import json
import os
import random
from dataclasses import asdict
from datetime import datetime

import numpy as np
import torch

from se.configs import PROJECT_ROOT, TrainConfig


def setup_experiment(cfg: TrainConfig) -> str:
    """Seed the RNGs for reproducibility and create a fresh log directory.

    Seeds torch, NumPy and Python's ``random`` with ``cfg.seed`` and puts cuDNN
    in deterministic (non-benchmark) mode. Then creates
    ``PROJECT_ROOT/logs/<YYYYmmdd_HHMMSS>``.

    The timestamp only has one-second resolution. If that directory already
    exists (e.g. two runs launched within the same second), a numeric suffix
    (``_1``, ``_2``, ...) is appended. This keeps concurrent runs from silently
    sharing, and overwriting, each other's logs.

    Args:
        cfg: Training configuration; only ``cfg.seed`` is read here.

    Returns:
        Path of the newly created log directory.
    """
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    base_dir = os.path.join(PROJECT_ROOT, "logs", timestamp)
    save_dir = base_dir
    attempt = 0
    while True:
        try:
            # exist_ok=False makes creation the "claim": it fails if another
            # run already owns this directory.
            os.makedirs(save_dir)
            return save_dir
        except FileExistsError:
            attempt += 1
            save_dir = f"{base_dir}_{attempt}"


def save_config(cfg: TrainConfig, save_dir: str) -> dict:
    """Write ``cfg`` as pretty-printed JSON to ``<save_dir>/config.json``.

    The config is serialized to a string *before* the file is opened. A config
    holding a value JSON cannot encode therefore raises ``TypeError`` without
    leaving a truncated ``config.json`` behind.

    Args:
        cfg: Dataclass instance holding the training configuration.
        save_dir: Directory to write into; created if missing.

    Returns:
        The plain-dict form of ``cfg`` as produced by ``dataclasses.asdict``.

    Raises:
        TypeError: If ``cfg`` is not a dataclass instance or contains a value
            that ``json`` cannot serialize.
    """
    os.makedirs(save_dir, exist_ok=True)
    config_dict = asdict(cfg)
    # Serialize fully first: json.dump streams chunks and would leave a
    # half-written file if encoding fails midway.
    payload = json.dumps(config_dict, indent=2)
    config_path = os.path.join(save_dir, "config.json")
    with open(config_path, "w", encoding="utf-8") as fp:
        fp.write(payload)
    return config_dict


def save_model_weights(model: torch.nn.Module, save_dir: str, name: str) -> str:
    """Persist ``model.state_dict()`` to ``<save_dir>/weights_<name>.pt``.

    Args:
        model: Module whose parameters/buffers are saved.
        save_dir: Destination directory; created if it does not exist.
        name: Tag embedded in the file name (formatted with ``str.format``).

    Returns:
        Full path of the written checkpoint file.
    """
    target_path = os.path.join(save_dir, "weights_{}.pt".format(name))
    os.makedirs(save_dir, exist_ok=True)
    weights = model.state_dict()
    torch.save(weights, target_path)
    return target_path


# Short tags used in run names. Wrapper modes take precedence over model modes.
_WRAPPER_MODE_TAGS = {
    "norm-equiv": "wne",
    "norm-equiv-input": "wnei",
    "scale-equiv": "wse",
}
_MODEL_MODE_TAGS = {
    "norm-equiv": "ne",
    "scale-equiv": "se",
}
# Gaussian (and any unlisted noise type) gets no suffix.
_NOISE_TYPE_SUFFIXES = {"laplace": "_l", "uniform": "_u", "rayleigh": "_r"}


def _format_noise_level(value) -> str:
    """Render a noise level compactly without dropping fractional parts.

    Integer-valued levels print as plain integers (``25.0 -> "25"``), matching
    the historical ``int()`` formatting. Fractional levels keep their digits
    (``0.5 -> "0.5"``) instead of being truncated to ``"0"``.
    """
    level = float(value)
    return str(int(level)) if level.is_integer() else str(level)


def run_name(cfg: TrainConfig) -> str:
    """Build a compact, human-readable run identifier from ``cfg``.

    Format::

        <mode>_<model>_<pred>_<loss>_<noise>[<noise-suffix>]_<dataset>_<patch>

    - mode: wrapper tag (wne/wnei/wse) if ``model_cfg.wrapper_mode`` matches,
      else model tag (ne/se) from ``model_cfg.model_mode``, else ``"b"``.
    - pred: ``"res"`` for ``pred_mode == "residual"``, otherwise ``"dir"``.
    - noise: ``"<min>"`` when min and max levels are equal, else
      ``"<min>-<max>"``. The suffix is ``_l``/``_u``/``_r`` for
      laplace/uniform/rayleigh.
    - dataset: ``"m"`` when ``train_dataset_type`` is ``m``/``M``, otherwise
      ``"h"``.

    Args:
        cfg: Training configuration. ``noise_type`` is optional and defaults
            to gaussian when missing or None.

    Returns:
        The run name string.
    """
    noise_type = (getattr(cfg, "noise_type", None) or "gaussian").lower()
    noise_suffix = _NOISE_TYPE_SUFFIXES.get(noise_type, "")

    model_cfg = cfg.model_cfg
    # `or` keeps the original precedence: model_mode is only consulted when
    # wrapper_mode is not one of the known wrapper modes.
    mode_tag = _WRAPPER_MODE_TAGS.get(model_cfg.wrapper_mode) or _MODEL_MODE_TAGS.get(
        model_cfg.model_mode, "b"
    )

    pred_tag = "res" if model_cfg.pred_mode == "residual" else "dir"
    dataset_tag = "m" if cfg.train_dataset_type.lower() == "m" else "h"

    # Compare the formatted levels, not truncated ints, so e.g. [0.1, 0.9]
    # is not mistaken for a single fixed level "0".
    min_noise = _format_noise_level(cfg.min_noise)
    max_noise = _format_noise_level(cfg.max_noise)
    noise_level = min_noise if min_noise == max_noise else f"{min_noise}-{max_noise}"

    return (
        f"{mode_tag}_{cfg.model.lower()}_{pred_tag}_{cfg.loss_type.lower()}_"
        f"{noise_level}{noise_suffix}_{dataset_tag}_{cfg.s_patch_size}"
    )
