"""Dataclasses for optimization / prediction configuration, and optimization metrics."""

from copy import deepcopy
import warnings
import json
import random
from dataclasses import dataclass, asdict, field
import torch
import numpy as np
from torch import Tensor
import numpy as np
from typing import List, Optional, ClassVar, Dict
from utils.config import (
    MODEL_DICT,
    SPLITS,
    TASKS,
    get_train_x_range,
    get_train_y_range,
)
from utils.log import get_timestamp


@dataclass
class DataConfig:
    """Settings that select and shape the data / objective function.

    Plain field container: no validation happens here; the consumer of this
    config (not shown in this module) interprets the values.
    """

    # Name of the function to use; how it is resolved is up to the caller.
    function_name: str
    # Name of the dimension-scattering strategy; valid values are defined by
    # the consumer — TODO confirm the accepted set.
    dim_scatter_mode: str = "random_k"
    # presumably the observation-noise standard deviation (0.0 = noiseless);
    # TODO confirm against the consumer.
    sigma: float = 0.0
    # Optional dimensionality; None means "not specified here".
    dim: Optional[int] = None


@dataclass
class TrainConfig:
    """Optimizer, scheduler, epoch-budget and data-loading settings for training.

    After construction, ``prefetch_factor`` is forced to ``None`` whenever
    ``num_workers`` is 0.
    """

    optimizer_type: str = "adamw"
    lr1: float = 1e-4
    lr2: float = 4e-5
    scheduler_type: str = "cosine"
    weight_decay: float = 1e-2
    num_warmup_steps: Optional[int] = None
    num_total_epochs: int = 100000
    num_burnin_epochs: int = 60000
    burnin_ratio: float = 0.8
    num_repeat_data: int = 1
    num_workers: int = 0
    prefetch_factor: Optional[int] = None

    def __post_init__(self):
        # With worker processes the user-supplied prefetch factor is kept as-is.
        if self.num_workers != 0:
            return
        # No workers: drop the prefetch factor. Presumably these values are fed
        # to torch.utils.data.DataLoader, which only accepts a prefetch_factor
        # when num_workers > 0.
        self.prefetch_factor = None


@dataclass
class LossConfig:
    """Options controlling how rewards are turned into the training loss.

    The field semantics are defined by the loss implementation (not in this
    module). Construction emits a ``UserWarning`` when
    ``sum_over_trajectories`` is enabled.
    """

    use_cumulative_rewards: bool = False
    batch_standardize: bool = True
    clip_rewards: bool = True
    loss_weight: float = 1.0
    discount_factor: float = 0.98
    pred_ratio: float = 1.0
    sum_over_trajectories: bool = False
    entropy_coeff: float = 0.0

    def __post_init__(self):
        if self.sum_over_trajectories:
            # stacklevel=3 skips __post_init__ and the dataclass-generated
            # __init__, so the warning points at the caller's LossConfig(...) line.
            warnings.warn(
                "'sum_over_trajectories' is set to True; "
                "This may lead to incorrect reward computation if model is trained "
                "on trajectories with varying length. "
                "Consider setting it to False.",
                stacklevel=3,
            )


@dataclass
class ExperimentConfig:
    """Top-level experiment settings, validated at construction time.

    - ``seed=None`` is replaced by a random seed drawn from [1, 10000].
    - ``mode``, ``task`` and ``model_name`` must be members of ``SPLITS``,
      ``TASKS`` and ``MODEL_DICT`` respectively.
    - Any CUDA device string (``"cuda"``, ``"cuda:0"``, ...) requires
      ``torch.cuda.is_available()``.
    - ``expid=None`` is replaced by ``"{model_name}_{seed}_{timestamp}"``.

    Raises:
        AssertionError: if any of the checks above fails.
    """

    # Required argument, but None is accepted and replaced by a random seed.
    seed: Optional[int]
    mode: str
    task: str
    model_name: str = "TAMO"
    expid: Optional[str] = None
    device: str = "cuda"
    override: bool = False
    resume: bool = False
    log_to_wandb: bool = True

    def __post_init__(self):
        if self.seed is None:
            self.seed = random.randint(1, 10000)
        assert isinstance(self.seed, int), f"Seed must be an int, got {type(self.seed)}"
        assert self.mode in SPLITS, f"Invalid mode:\t{self.mode}"
        assert self.task in TASKS, f"Invalid task: {self.task}"
        assert self.model_name in MODEL_DICT, f"Invalid model name:\t{self.model_name}"
        # Check indexed CUDA devices ("cuda:1") too, not only the bare "cuda"
        # string, so a missing GPU is reported here rather than at first use.
        if self.device == "cuda" or self.device.startswith("cuda:"):
            assert torch.cuda.is_available(), "CUDA is not available."

        if self.expid is None:
            self.expid = f"{self.model_name}_{self.seed}_{get_timestamp()}"


@dataclass
class PredictionConfig:
    """Settings for prediction runs.

    ``nc`` (presumably the number of context points — TODO confirm) may be
    fixed explicitly; otherwise only the ``[min_nc, max_nc]`` bounds are
    given. Constraints checked at construction:
    ``0 <= min_nc <= max_nc`` and, if set, ``min_nc <= nc <= max_nc``.

    Raises:
        AssertionError: if a constraint is violated.
    """

    batch_size: int
    min_nc: int = 2
    max_nc: int = 50
    nc: Optional[int] = None
    read_cache: bool = False
    write_cache: bool = False

    def __post_init__(self):
        assert self.min_nc >= 0, f"min_nc {self.min_nc} < 0"
        # Name both bounds in the message, like the nc checks below.
        assert self.max_nc >= self.min_nc, f"max_nc {self.max_nc} < min_nc {self.min_nc}"

        if self.nc is not None:
            assert self.nc >= self.min_nc, f"nc {self.nc} < min_nc {self.min_nc}"
            assert self.nc <= self.max_nc, f"nc {self.nc} > max_nc {self.max_nc}"

    def to_dict(self) -> Dict[str, object]:
        """Return all fields as a plain dict (via ``dataclasses.asdict``)."""
        return asdict(self)


@dataclass
class OptimizationConfig:
    """Configuration for an optimization run.

    ``T`` is the horizon. When it is left unset (``None``), every read of
    ``config.T`` draws a new integer uniformly from ``[min_T, max_T]`` (both
    ends inclusive). When it is a positive ``int``, that value is returned
    unchanged. Assigning ``None`` to ``T`` is ignored (any stored value is
    kept). Assigning anything other than a positive int raises
    ``AssertionError``.

    Implementation note: the ``T`` property is attached *after* ``@dataclass``
    has processed the class (see the assignment right below the class).
    Defining it inside the class body would make the property object itself
    the default of the ``T`` field. Constructing the config without ``T``
    would then hand that object to the setter and fail its assertion.

    ``T`` is excluded from ``repr`` and ``==`` because reading it may return a
    random sample; the stored value ``_T`` is shown and compared instead.
    ``dataclasses.asdict`` / ``dataclasses.replace`` still read ``T`` through
    the property, so use :meth:`to_dict` for serialization.
    """

    use_grid_sampling: bool
    use_fixed_query_set: bool
    use_factorized_policy: bool
    use_time_budget: bool
    batch_size: int
    # Constructor argument only; reads go through the property attached below.
    T: Optional[int] = field(default=None, repr=False, compare=False)
    # Explicitly configured horizon, or None to sample on every access.
    _T: Optional[int] = field(init=False, default=None)
    min_T: int = 10
    max_T: int = 100
    regret_type: str = "norm_ratio"
    num_initial_points: int = 1
    num_samples: int = 1
    dim_mask_gen_mode: str = "full"
    single_obs_x_dim: Optional[int] = None
    single_obs_y_dim: Optional[int] = None
    read_cache: bool = False
    write_cache: bool = False
    epsilon: float = 1.0

    # Short tags for selected parameters (not a dataclass field).
    params_map: ClassVar[Dict[str, str]] = {
        "use_grid_sampling": "Grid",
        "use_fixed_query_set": "Fixq",
        "use_factorized_policy": "Fact",
        "use_time_budget": "Tbud",
        "batch_size": "B",
        "T": "T",
        "min_T": "MinT",
        "max_T": "MaxT",
        "regret_type": "Regr",
        "num_initial_points": "Nspt",
        "num_samples": "Nsmp",
    }

    def _get_T(self) -> int:
        """Return the configured horizon, or sample one from [min_T, max_T]."""
        if self._T is None:
            # NOTE Sample T from [min_T, max_T] if not specified
            return random.randint(self.min_T, self.max_T)
        return self._T

    def _set_T(self, new_T: Optional[int]) -> None:
        """Store a fixed horizon; ``None`` leaves the current setting untouched.

        Raises:
            AssertionError: if ``new_T`` is not a positive int.
        """
        if new_T is None:
            return

        assert isinstance(new_T, int) and new_T > 0, "T must be a positive integer."
        self._T = new_T

    def to_dict(self) -> Dict[str, object]:
        """Return the init fields as a dict, with ``T`` as configured (maybe None).

        Unlike ``dataclasses.asdict``, this does not sample ``T`` and omits the
        private ``_T`` storage, so ``OptimizationConfig(**cfg.to_dict())``
        reconstructs an equal config.
        """
        from dataclasses import fields  # only dataclass/asdict/field are imported at module level

        out: Dict[str, object] = {}
        for f in fields(self):
            if f.name == "_T":
                continue  # private storage; its value is reported under "T"
            # Read the stored horizon directly: going through the property would
            # report a random sample (and consume global RNG state) when T is unset.
            value = self._T if f.name == "T" else getattr(self, f.name)
            out[f.name] = deepcopy(value)
        return out


OptimizationConfig.T = property(
    OptimizationConfig._get_T,
    OptimizationConfig._set_T,
    doc="Horizon: the configured value, or a fresh sample from [min_T, max_T] if unset.",
)


@dataclass
class SamplerConfig:
    """Settings for the data samplers.

    By default these are the multi-task and multi-output GP-prior samplers.
    Every list-valued default is produced by a ``default_factory``, so
    instances never share mutable lists.
    """

    x_dim_list: List[int]
    y_dim_list: List[int]
    x_range: List[float] = field(default_factory=lambda: get_train_x_range().copy())
    y_range: List[float] = field(default_factory=lambda: get_train_y_range().copy())
    sampler_list: List[str] = field(
        default_factory=lambda: [
            "multi_task_gp_prior_sampler",
            "multi_output_gp_prior_sampler",
        ]
    )
    sampler_weights: List[float] = field(default_factory=lambda: [1.0, 1.0])
    data_kernel_type_list: List[str] = field(
        default_factory=lambda: ["rbf", "matern32", "matern52"]
    )
    sample_kernel_weights: List[float] = field(default_factory=lambda: [1.0, 1.0, 1.0])
    lengthscale_range: List[float] = field(default_factory=lambda: [0.1, 2.0])
    std_range: List[float] = field(default_factory=lambda: [0.1, 1.0])
    min_rank: int = 1
    max_rank: Optional[int] = None
    p_iso: float = 0.5
    standardize: bool = False
    jitter: float = 1e-3
    max_tries: int = 6

    def to_dict(self) -> Dict[str, object]:
        """Return every field as a plain dict (via ``dataclasses.asdict``)."""
        return asdict(self)

    def assert_dims_within_limits(self, max_x_dim, max_y_dim):
        """Assert that no configured x/y dimension exceeds the given maximum.

        All x dimensions are checked before any y dimension. The first
        violation raises ``AssertionError`` naming the offending value and
        its limit.
        """
        checks = (
            ("x", self.x_dim_list, max_x_dim),
            ("y", self.y_dim_list, max_y_dim),
        )
        for axis, dims, ceiling in checks:
            for d in dims:
                assert d <= ceiling, f"{axis}_dim {d} exceeds max_{axis}_dim {ceiling}"

