import argparse
import os
import random

import numpy as np
import torch
from torch import cuda
from torch.backends import cudnn

import losses
import metrics
import models
import scalers


def _optional_str(value: str):
    """Return ``None`` for the literal ``"none"`` (any casing), else *value*."""
    return None if value.strip().lower() == "none" else value


def _optional_int(value: str):
    """Parse *value* as an int, mapping the literal ``"none"`` to ``None``."""
    if value.strip().lower() == "none":
        return None
    try:
        return int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected an integer or 'None', got {value!r}"
        ) from None


def parse_args(argv=None):
    """Parse the command-line options and derive dependent settings.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behavior.

    Returns:
        The parsed ``argparse.Namespace`` with these adjustments applied:

        * when ``regularizer`` is ``None``: ``temperature`` is forced to 0 and
          ``scaler`` to ``None``;
        * ``reduce_type`` ``"sum"`` is rewritten to ``"add"``;
        * ``lookback_horizon`` is
          ``int(forecast_horizon * lookback_multiplier)``;
        * ``pred_learning_rate`` equals ``learning_rate`` and
          ``align_learning_rate`` equals ``learning_rate * temperature``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--source_domains", type=str, required=True, nargs="+")
    parser.add_argument("--target_domain", type=str, required=True)
    parser.add_argument("--forecast_horizon", type=int, default=10)
    parser.add_argument("--lookback_multiplier", type=float, default=5)
    parser.add_argument("--model", type=str, default="NHiTS", choices=models.__all__)
    parser.add_argument("--loss", type=str, default="SMAPE", choices=losses.__all__)
    parser.add_argument(
        "--regularizer",
        # A plain ``str`` type can never produce ``None``, which made the
        # ``None`` choice (and the branch below) unreachable from the CLI.
        type=_optional_str,
        default="Sinkhorn",
        choices=[
            None,
            "Wasserstein",
            "Sinkhorn",
            "EnergyMMD",
            "GaussianMMD",
            "LaplacianMMD",
        ],
        help="alignment regularizer; pass 'None' to disable it",
    )
    parser.add_argument("--temperature", type=float, default=1)
    parser.add_argument(
        "--reduce_type", type=str, default="max", choices=["max", "sum"]
    )
    parser.add_argument(
        "--scaler", type=str, default="softmax", choices=scalers.__all__
    )
    parser.add_argument("--metric", type=str, default="SMAPE", choices=metrics.__all__)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--num_lr_cycle", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=2**12)
    parser.add_argument("--num_iters", type=int, default=1_000)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--dtype",
        type=str,
        default="float32",
        choices=["float16", "float32", "float64"],
    )
    parser.add_argument(
        "--data_size",
        # Accept the literal 'None' so the documented "use all data" mode is
        # actually selectable from the command line.
        # NOTE(review): consumers of ``data_size`` are outside this block;
        # confirm they handle ``None`` as the help text promises.
        type=_optional_int,
        default=75_000,
        help="fix the data size to this number, if None, use all data",
    )
    args = parser.parse_args(argv)

    # Without a regularizer there is nothing to align: zero the temperature
    # (which also zeroes ``align_learning_rate`` below) and drop the scaler.
    if args.regularizer is None:
        args.temperature = 0
        args.scaler = None
    # The CLI exposes "sum" but the rest of the namespace carries "add".
    if args.reduce_type == "sum":
        args.reduce_type = "add"

    args.lookback_horizon = int(args.forecast_horizon * args.lookback_multiplier)
    args.pred_learning_rate = args.learning_rate
    args.align_learning_rate = args.learning_rate * args.temperature
    return args


def seed_everything(seed: int):
    """Seed every random number generator used here and favor determinism.

    Sets ``PYTHONHASHSEED`` and seeds Python's ``random`` module, NumPy's
    global generator, and PyTorch on the CPU and on all CUDA devices. cuDNN is
    switched to deterministic mode with autotuning disabled.

    Args:
        seed: Value used to seed all generators.
    """
    # Hash randomization is fixed at interpreter start-up, so this only
    # influences Python processes launched after this point.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cuda.manual_seed_all(seed)
    # The cuDNN autotuner may select different algorithms from run to run,
    # which defeats the purpose of seeding; keep it off for reproducibility.
    cudnn.benchmark = False
    cudnn.deterministic = True
