from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional

import torch

from data_generator import TASKS, NUM_CONSTANTS


@dataclass
class Args:
    """
    Arguments.
    """

    task: str = field(
        metadata={"help": 'Task name.'}
    )
    objective: str = field(
        metadata={
            "help": "REINFORCE, RLOO, RLOO-topk",
            "choices": ["REINFORCE", "RLOO", "RLOO-topk"],
        }
    )
    train_epochs: int = field(
        metadata={"help": "The number of training epochs."},
    )
    max_variables: int = field(
        metadata={"help": "Maximum number of unique variables in one rule. Must be in [[1,52]], due to the limitation of einsum."},
    )
    num_rules: int = field(
        metadata={"help": "Number of rules to be learned."},
    )
    cwa: bool = field(
        default=False,
        metadata={"help": "Whether closed world assumption. Otherwise, open world assumption."},
    )
    max_aux_arity: int = field(
        default=0,
        metadata={"help": "Maximum arity for auxiliary predicates. `0` means no auxiliary predicate."},
    )
    num_aux_predicates: int = field(
        default=0,
        metadata={"help": "Number of auxiliary predicates to be invented."},
    )
    max_occurrence_in_body: int = field(
        default=1,
        metadata={"help": "Maximum number of occurrences, for any predicate, in one rule's body."},
    )
    
    eval_every_epoch: int = field(
        default=1,
        metadata={"help": "Evaluate every `eval_every_epoch` epochs."},
    )
    max_train_inference_steps: int = field(
        default=10,
        metadata={"help": "Maximum inference steps in training. Should not be too small, otherwise the eventual truth values is not correctly computed. Recommended: number of optimal steps + 1, if number of optimal steps is known."},
    )
    max_eval_inference_steps: int = field(
        default=100,
        metadata={"help": "Maximum inference steps in soft evaluation."},
    )
    
    max_body_atoms: int = field(
        default=2,
        metadata={"help": "Maximum number of body atoms."},
    )
    predicate_embed_dim: int = field(
        default=10,
        metadata={"help": "Predicate's embedding dimension."},
    )
    variable_embed_dim: int = field(
        default=10,
        metadata={"help": "Variable's embedding dimension."},
    )
    rule_head_atom_embed_dim: int = field(
        default=10,
        metadata={"help": "Rule's head atom embedding dimension."},
    )
    rule_body_atom_embed_dim: int = field(
        default=10,
        metadata={"help": "Rule's body atom embedding dimension."},
    )
    num_sample_vars: int = field(
        default=32,
        metadata={"help": "Number of samples for variable selection."},
    )
    num_sample_atoms: int = field(
        default=32,
        metadata={"help": "Number of samples for atom selection per each variables sample."},
    )
    entropy_coeff_init: float = field(
        default=0.1,
        metadata={"help": "Initialized coefficient for entropy regularization."},
    )
    entropy_coeff_final: float = field(
        default=0,
        metadata={"help": "Final coefficient for entropy regularization."},
    )
    entropy_coeff_anneal_epochs: int = field(
        default=200,
        metadata={"help": "Annealing steps for entropy regularization."}
    )
    # entropy_coeff: float = field(
    #     default=0.1,
    #     metadata={"help": "Coefficient for entropy regularization."},
    # )
    # loss_masked_fill_value: int = field(
    #     default=int(-1e3),
    #     metadata={"help": "Value filled for `-inf` in eventually predicted truth values."}
    # )
    lr: float = field(
        default=0.001,
        metadata={"help": "Learning rate of AdamW optimizer."},
    )
    beta1: float = field(
        default=0.9,
        metadata={"help": "Beta 1 of AdamW optimizer."},
    )
    beta2: float = field(
        default=0.999,
        metadata={"help": "Beta 2 of AdamW optimizer."},
    )
    weight_decay: float = field(
        default=0.01,
        metadata={"help": "Weight decay of AdamW optimizer."},
    )
    max_grad_norm: Optional[float] = field(
        default=None,
        metadata={"help": "Maximum of gradient norm."},
    )
    rloo_topk: Optional[int] = field(
        default=None,
        metadata={"help": "K largest value selected for baseline."},
    )
    remove_irrelevant_vars: bool = field(
        default=False,
        metadata={"help": "Whether remove the irrelevant variables."},
    )
    stop_per_sampling_rules: bool = field(
        default=False,
        metadata={"help": "Whether stop if a sampling rule achieves 1.0 balanced accuracy."},
    )
    sampling_bacc_non_parallel_compute: bool = field(
        default=False,
        metadata={"help": "Whether use non parallel compute of balanced accuracy during training sampling. This can save lots of memory."},
    )
    # patience: int = field(
    #     default=100,
    #     metadata={"help": "Patience of `ReduceLROnPlateau`. Note that since `ReduceLROnPlateau` monitors `train_soft_mse`, `eval_every_epoch` also affect the true patience."},
    # )

    train_num_constants: Optional[int] = field(
        default=None,
        metadata={"help": "The number of training constants."},
    )
    eval_num_constants: Optional[int] = field(
        default=None,
        metadata={"help": "The number of evaluation constants."},
    )

    seed: int = field(
        default=42,
        metadata={"help": "Random seed."},
    )
    use_gpu: bool = field(
        default=True,
        metadata={"help": "Use gpu or not. If `use_gpu` but no cuda device is found, an error will be raised."},
    )
    num_intraop_threads: int = field(
        default=2,
        metadata={"help": "Number of intraop threads."},
    )
    num_interop_threads: int = field(
        default=4,
        metadata={"help": "Number of interop threads. This is also the number of CUDA streams."},
    )
    log_dir: str | Path = field(
        default="./log",
        metadata={"help": "Log directory, which stores result, log, and tensorboard data."},
    )
    log_file: Optional[str | Path] = field(
        default=None,
        metadata={"help": "Log file name. If not given, __post_init__ will figure out a default path."},
    )
    log_period_to_disk: Optional[int] = field(
        default=None,
        metadata={"help": "Period of flushing log to disk. In seconds."},
    )
    log_buffer_capacity: Optional[int] = field(
        default=None,
        metadata={"help": "Buffer capacity before flushing log to disk. In number of records."},
    )
    result_file: Optional[str | Path] = field(
        default=None,
        metadata={"help": "Result file name (.json). If not given, __post_init__ will figure out a default path."},
    )
    debug: bool = field(
        default=False,
        metadata={"help": "In debug mode or not."},
    )

    def __post_init__(self):
        if not ("GeoILP/basic/" in self.task):
            raise ValueError(f"Unknown task: {self.task}")


        now = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        self.log_dir = Path(self.log_dir)
        log_name = f"{self.task.replace("/", "-")}_{"CWA" if self.cwa else "OWA"}_seed{self.seed}_{now}"
        if self.log_file is None:
            self.log_file: Path = self.log_dir / f"{log_name}.log"
        else:
            self.log_file: Path = self.log_dir / self.log_file
        if self.result_file is None:
            self.result_file: Path = self.log_dir / f"{log_name}.json"
        else:
            self.result_file: Path = self.log_dir / self.result_file
        self.tensorboard_dir = self.log_dir / "tensorboard" / log_name

        if self.log_file.exists():
            raise FileExistsError(f"Log file already exists: {self.log_file}")
        if self.result_file.exists():
            raise FileExistsError(f"Output file exists: {self.result_file}")
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.tensorboard_dir.mkdir(parents=True, exist_ok=False)
        
        assert 1 <= self.max_variables and self.max_variables <= 52
        assert self.max_occurrence_in_body >= 1
        assert self.num_sample_vars >= 1
        assert self.train_epochs % self.eval_every_epoch == 0

        if self.num_aux_predicates > 0:
            assert self.max_aux_arity > 0

        if "GeoILP" not in self.task:
            self.train_num_constants = self.train_num_constants or NUM_CONSTANTS[self.task]["train"]
            self.eval_num_constants = self.eval_num_constants or NUM_CONSTANTS[self.task]["eval"]

        if self.rloo_topk:
            assert self.objective == "RLOO-topk"

        if self.use_gpu:
            assert torch.cuda.is_available(), f"Cuda device not found. Try to remove `--use_gpu`."
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")


    def to_json(self):
        d = {}
        for k, v in vars(self).items():
            d[k] = str(v)
        return d
