import logging
import sys
import time
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Optional
import torch
import torch.backends.cudnn as cudnn
import torch.distributed
from omegaconf import MISSING
from torch import nn
from torch.utils.data import TensorDataset
from torchmetrics import MetricTracker
from dinov3.data import SamplerType, make_data_loader, make_dataset
from dinov3.data.adapters import DatasetWithEnumeratedTargets
from dinov3.data.transforms import CROP_DEFAULT_SIZE, get_target_transform, make_classification_eval_transform
from dinov3.distributed import get_rank, get_world_size
from dinov3.eval.data import (
    create_train_dataset_dict,
    extract_features_for_dataset_dict,
    get_num_classes,
    split_train_val_datasets,
)
from dinov3.eval.helpers import args_dict_to_dataclass, cli_parser, write_results
from dinov3.eval.metrics import ClassificationMetricType, build_classification_metric
from dinov3.eval.setup import ModelConfig, load_model_and_context
from dinov3.eval.utils import average_metrics, evaluate, extract_features
from dinov3.eval.utils import save_results as default_save_results_func
from dinov3.run.init import job_context
from dinov3.utils.dtype import as_torch_dtype
logger = logging.getLogger("dinov3")
RESULTS_FILENAME = "results-log-regression.csv"
MAIN_METRICS = ["top-1(_mean)?"]
try:
    from sklearnex import patch_sklearn
    patch_sklearn()
except ImportError:
    logger.warning("Can't import sklearnex. If installed, that speeds up scikit-learn 10-100x")
try:
    from sklearn.linear_model import LogisticRegression as sklearnLogisticRegression
    from sklearn.multiclass import OneVsRestClassifier
except ImportError:
    logger.warning("Can't import scikit-learn. This is necessary for evaluating log regression")
    raise ImportError
C_POWER_RANGE = torch.linspace(-6, 5, 45)
_CPU_DEVICE = torch.device("cpu")
@dataclass
class TrainConfig:
    dataset: str = MISSING
    val_dataset: Optional[str] = None
    val_metric_type: ClassificationMetricType = ClassificationMetricType.MEAN_ACCURACY
    batch_size: int = 256
    num_workers: int = 5
    tol: float = 1e-12
    train_features_device: str = "cpu"  # device to gather train features (cpu, cuda, cuda:0, etc.)
    train_dtype: str = "float64"  # data type to convert the train features to
    max_train_iters: int = 1_000
@dataclass
class EvalConfig:
    test_dataset: str = MISSING
    batch_size: int | None = None
    num_workers: int = 5
    test_metric_type: Optional[ClassificationMetricType] = None
@dataclass
class TransformConfig:
    resize_size: int = CROP_DEFAULT_SIZE
    crop_size: int = CROP_DEFAULT_SIZE
@dataclass
class FewShotConfig:
    enable: bool = False
    k_or_percent: Optional[float] = None
    n_tries: int = 1
@dataclass
class LogregEvalConfig:
    model: ModelConfig
    train: TrainConfig = field(default_factory=TrainConfig)
    eval: EvalConfig = field(default_factory=EvalConfig)
    transform: TransformConfig = field(default_factory=TransformConfig)
    few_shot: FewShotConfig = field(default_factory=FewShotConfig)
    save_results: bool = False
    output_dir: str = ""
class LogRegModule(nn.Module):
    def __init__(self, C, multi_label=False, logreg_config=TrainConfig):
        super().__init__()
        self.dtype = as_torch_dtype(logreg_config.train_dtype)
        self.device = torch.device(logreg_config.train_features_device)
        assert self.device == _CPU_DEVICE, f"SKLearn can only work on CPU device, got {self.device}"
        self.estimator = sklearnLogisticRegression(
            penalty="l2",
            solver="lbfgs",
            C=C,
            max_iter=logreg_config.max_train_iters,
            n_jobs=-1,
            tol=logreg_config.tol,
        )
        if multi_label:
            self.estimator = OneVsRestClassifier(self.estimator, n_jobs=-1)
    def forward(self, samples, targets):
        samples_device = samples.device
        samples = samples.to(dtype=self.dtype, device=self.device)
        if self.device == _CPU_DEVICE:
            samples = samples.numpy()
        probas = self.estimator.predict_proba(samples)
        return {"preds": torch.from_numpy(probas).to(samples_device), "target": targets}
    def fit(self, train_features, train_labels):
        train_features = train_features.to(dtype=self.dtype, device=self.device)
        train_labels = train_labels.to(dtype=self.dtype, device=self.device)
        if self.device == _CPU_DEVICE:
            train_features = train_features.numpy()
            train_labels = train_labels.numpy()
        self.estimator.fit(train_features, train_labels)
def evaluate_logreg_model(*, logreg_model, test_metric, test_data_loader, save_results_func=None):
    key = "metrics"  # We need only one key as we have only one metric
    postprocessors, metrics = {key: logreg_model}, {key: test_metric}
    _, eval_metrics, accumulated_results = evaluate(
        nn.Identity(),
        test_data_loader,
        postprocessors,
        metrics,
        torch.cuda.current_device(),
        accumulate_results=save_results_func is not None,
    )
    if save_results_func is not None:
        save_results_func(**accumulated_results[key])
    return eval_metrics
def train_for_C(*, C, train_features, train_labels, logreg_config: TrainConfig):
    logreg_model = LogRegModule(C, multi_label=len(train_labels.shape) > 1, logreg_config=logreg_config)
    logreg_model.fit(train_features, train_labels)
    return logreg_model
def sweep_C_values(
    *,
    train_features,
    train_labels,
    val_data_loader,
    val_metric,
    logreg_config: TrainConfig,
):
    metric_tracker = MetricTracker(val_metric, maximize=True)
    ALL_C = 10**C_POWER_RANGE
    logreg_models: Dict[float, Any] = {}
    train_features_device = torch.device(logreg_config.train_features_device)
    train_dtype = as_torch_dtype(logreg_config.train_dtype)
    train_features = train_features.to(dtype=train_dtype, device=train_features_device)
    train_labels = train_labels.to(device=train_features_device)
    for i in range(get_rank(), len(ALL_C), get_world_size()):
        C = ALL_C[i].item()
        logger.info(
            f"Training for C = {C:.4g}, dtype={train_dtype}, "
            f"features: {train_features.shape}, {train_features.dtype}, "
            f"labels: {train_labels.shape}, {train_labels.dtype}"
        )
        logreg_models[C] = train_for_C(
            C=C,
            train_features=train_features,
            train_labels=train_labels,
            logreg_config=logreg_config,
        )
    gather_list: List[Dict[float, Any]] = [{} for _ in range(get_world_size())]
    torch.distributed.all_gather_object(gather_list, logreg_models)
    for logreg_dict in gather_list:
        logreg_models.update(logreg_dict)
    gather_list.clear()
    for i in range(len(ALL_C)):
        metric_tracker.increment()
        C = ALL_C[i].item()
        evals = evaluate_logreg_model(
            logreg_model=logreg_models.pop(C),
            test_metric=metric_tracker,
            test_data_loader=val_data_loader,
        )
        logger.info(f"Trained for C = {C:.4g}, accuracies = {evals}")
        best_stats, which_epoch = metric_tracker.best_metric(return_step=True)
        best_stats_100 = {k: 100.0 * v for k, v in best_stats.items()}
        if which_epoch["top-1"] == i:
            best_C = C
    logger.info(f"Sweep best {best_stats_100}, best C = {best_C:.4g}")
    return best_stats, best_C
def make_logreg_data_loader(batch_size: int, num_workers: int, features: torch.Tensor, labels: torch.Tensor):
    return make_data_loader(
        dataset=DatasetWithEnumeratedTargets(
            TensorDataset(features, labels), pad_dataset=True, num_replicas=get_world_size()
        ),
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
    )
def get_best_logreg_with_features(
    *,
    train_features: torch.Tensor,
    train_labels: torch.Tensor,
    val_features: torch.Tensor,
    val_labels: torch.Tensor,
    val_metric,
    concatenate_train_val: bool,
    train_config: TrainConfig,
):
    val_data_loader = make_logreg_data_loader(
        train_config.batch_size, train_config.num_workers, val_features, val_labels
    )
    _, best_C_t = sweep_C_values(
        train_features=train_features,
        train_labels=train_labels,
        val_data_loader=val_data_loader,
        val_metric=val_metric,
        logreg_config=train_config,
    )
    if concatenate_train_val:
        logger.info("Best parameter found, concatenating features")
        train_features = torch.cat((train_features, val_features))
        train_labels = torch.cat((train_labels, val_labels))
    logger.info("Training final model")
    logreg_model = train_for_C(
        C=best_C_t,
        logreg_config=train_config,
        train_features=train_features,
        train_labels=train_labels,
    )
    return logreg_model
def make_transform(config: TransformConfig):
    if config.resize_size / config.crop_size != 1:
        logger.warning(f"Default resize / crop ratio is 1, here we have {config.resize_size} / {config.crop_size}")
    transform = make_classification_eval_transform(resize_size=config.resize_size, crop_size=config.crop_size)
    return transform
def make_train_val_datasets(train_config: TrainConfig, few_shot_config: FewShotConfig, transform):
    train_dataset = make_dataset(
        dataset_str=train_config.dataset,
        transform=transform,
        target_transform=get_target_transform(train_config.dataset),
    )
    if train_config.val_dataset is not None:
        val_dataset = make_dataset(
            dataset_str=train_config.val_dataset,
            transform=transform,
            target_transform=get_target_transform(train_config.val_dataset),
        )
    else:
        split_percentage = 0.01 if few_shot_config.enable else 0.1
        train_dataset, val_dataset = split_train_val_datasets(train_dataset, split_percentage=split_percentage)
    train_dataset_dict = create_train_dataset_dict(
        train_dataset,
        few_shot_eval=few_shot_config.enable,
        few_shot_k_or_percent=few_shot_config.k_or_percent,
        few_shot_n_tries=few_shot_config.n_tries,
    )
    num_classes = get_num_classes(train_dataset)
    return train_dataset_dict, val_dataset, num_classes
def make_test_dataset_and_data_loader(model, config: EvalConfig, transform, gather_on_cpu: bool):
    test_dataset = make_dataset(
        dataset_str=config.test_dataset,
        transform=transform,
        target_transform=get_target_transform(config.test_dataset),
    )
    test_features, test_labels = extract_features(
        model, test_dataset, config.batch_size, config.num_workers, gather_on_cpu=gather_on_cpu
    )
    assert isinstance(config.batch_size, int)
    test_data_loader = make_logreg_data_loader(config.batch_size, config.num_workers, test_features, test_labels)
    return test_dataset, test_data_loader
def eval_log_regression_with_model(*, model: torch.nn.Module, autocast_dtype, config: LogregEvalConfig):
    """
    Implements the "standard" process for log regression evaluation:
    The value of C is chosen by training on train_dataset and evaluating on
    val_dataset. Then, the final model is trained on a concatenation of
    train_dataset and val_dataset, and is evaluated on test_dataset.
    If there is no val_dataset, the value of C is the one that yields
    the best results on a random 10% subset of the train dataset
    """
    start = time.time()
    cudnn.benchmark = True
    transform = make_transform(config.transform)
    config.eval.batch_size = config.eval.batch_size or config.train.batch_size
    train_dataset_dict, val_dataset, num_classes = make_train_val_datasets(config.train, config.few_shot, transform)
    with torch.autocast("cuda", dtype=autocast_dtype):
        gather_on_cpu = torch.device(config.train.train_features_device) == _CPU_DEVICE
        train_data_dict = extract_features_for_dataset_dict(
            model, train_dataset_dict, config.train.batch_size, config.train.num_workers, gather_on_cpu=gather_on_cpu
        )
        logger.info("Choosing hyperparameters on the val dataset")
        val_features, val_labels = extract_features(
            model, val_dataset, config.train.batch_size, config.train.num_workers, gather_on_cpu=gather_on_cpu
        )
        test_dataset, test_data_loader = make_test_dataset_and_data_loader(model, config.eval, transform, gather_on_cpu)
    model.cpu()
    torch.cuda.empty_cache()
    val_metric = build_classification_metric(config.train.val_metric_type, num_classes=num_classes, dataset=val_dataset)
    test_metric_type = config.eval.test_metric_type or config.train.val_metric_type
    test_metric = build_classification_metric(test_metric_type, num_classes=num_classes, dataset=test_dataset)
    save_results_func = None
    if config.save_results:
        save_results_func = partial(default_save_results_func, output_dir=config.output_dir)
    results_dict = {}
    for _try in train_data_dict.keys():
        logreg_model = get_best_logreg_with_features(
            train_features=train_data_dict[_try]["train_features"],
            train_labels=train_data_dict[_try]["train_labels"],
            val_features=val_features,
            val_labels=val_labels,
            val_metric=val_metric,
            concatenate_train_val=not config.few_shot.enable,
            train_config=config.train,
        )
        if len(train_data_dict) > 1 and save_results_func is not None:
            split_results_saver = partial(save_results_func, filename_suffix=str(_try))
        else:
            split_results_saver = save_results_func
        eval_metrics = evaluate_logreg_model(
            logreg_model=logreg_model,
            test_metric=test_metric.clone(),
            test_data_loader=test_data_loader,
            save_results_func=split_results_saver,
        )
        results_dict[_try] = {k: v.item() * 100.0 for k, v in eval_metrics["metrics"].items()}
    if len(train_data_dict) > 1:
        results_dict = average_metrics(results_dict)
    else:
        results_dict = {**results_dict[_try]}
    logger.info(f"Log regression evaluation done in {int(time.time() - start)}s")
    logger.info("Training of the supervised logistic regression on frozen features completed.")
    results_string = "\n".join([f"{k}: {results_dict[k]:.4g}" for k in sorted(results_dict.keys())])
    logger.info("Results:\n" + results_string)
    torch.distributed.barrier()
    return results_dict
def benchmark_launcher(eval_args: dict[str, object]) -> dict[str, Any]:
    """Initialization of distributed and logging are preconditions for this method"""
    dataclass_config, output_dir = args_dict_to_dataclass(eval_args=eval_args, config_dataclass=LogregEvalConfig)
    model, model_context = load_model_and_context(dataclass_config.model, output_dir=output_dir)
    results_dict = eval_log_regression_with_model(
        model=model, config=dataclass_config, autocast_dtype=model_context["autocast_dtype"]
    )
    write_results(results_dict, output_dir, RESULTS_FILENAME)
    return results_dict
def main(argv=None):
    if argv is None:
        argv = sys.argv[1:]
    eval_args = cli_parser(argv)
    with job_context(output_dir=eval_args["output_dir"]):
        benchmark_launcher(eval_args=eval_args)
    return 0
if __name__ == "__main__":
    main()
