from abc import ABC, abstractmethod
import git
import json
from pathlib import Path
from typing import Optional
import wandb
from libs.logger import Logger_classify, Logger_detect, Logger_ood, Logger_misclassification

try:
    from ray import tune, air
    from ray.air import session
except ImportError:
    print("Ray isn't installed")


# Format string for a one-line per-step progress report. Every placeholder must
# be supplied to str.format (PrintReporter passes the step counter as `i`).
# NOTE(review): train fields use ".4f" while val/test fields use " .4f" (space
# reserved for the sign), so val/test values are preceded by an extra space.
# Kept verbatim to preserve existing log output.
SIMPLE_TEMPLATE = (
    "Step: {i}, "
    "train_loss: {train_loss:.4f}, "
    "train_acc: {train_acc:.4f}, "
    "val_loss: {val_loss: .4f}, "
    "val_acc: {val_acc: .4f}, "
    "test_loss: {test_loss: .4f}, "
    "test_acc: {test_acc: .4f}"
)

# Multi-line banner summarising mean/std of validation and test accuracy.
# Placeholders: mean_val_acc, val_std, mean_test_acc, test_std.
SUMMARY_TEMPLATE = (
    "====================================================\n"
    "FINAL RESULTS AVERAGED OVER ALL SPLITS: "
    "avg_val_acc: {mean_val_acc: .4f}, "
    "val_std: {val_std: .4f}, "
    "avg_test_acc: {mean_test_acc: .4f}, "
    "test_std: {test_std: .4f}\n"
    "====================================================\n"
)


class ReporterInterface(ABC):
    """Abstract base class for metric reporters.

    Concrete reporters must override :meth:`report`; the base class itself
    cannot be instantiated because that method is abstract.
    """

    @abstractmethod
    def report(self, metrics):
        """Publish ``metrics`` to the reporter's destination."""


class PrintReporter(ReporterInterface):
    """Reporter that writes metrics to stdout.

    With a ``template``, each report prints
    ``template.format(i=<step>, **metrics)`` and then advances the 1-based
    step counter ``i``. Without a template, the raw ``metrics`` dict is
    printed and the counter is left untouched.
    """

    def __init__(self, template: Optional[str] = None):
        # Step counter; only advanced after a successful templated print.
        self.i = 1
        self.template = template

    def __call__(self, metrics: dict, *args, **kwargs):
        # Forward extra arguments (e.g. a ``primary_metric`` as accepted by
        # WandBReporter) so reporters can be called interchangeably;
        # ``report`` already accepts and ignores them.
        self.report(metrics, *args, **kwargs)

    def report(self, metrics: dict, *args, **kwargs):
        """Print ``metrics``; any extra positional/keyword arguments are ignored."""
        if self.template is None:
            print(metrics)
            return
        print(self.template.format(i=self.i, **metrics))
        self.i += 1


class WandBReporter(ReporterInterface):
    """Reporter that logs metrics to a Weights & Biases run.

    It also owns the task-specific result logger.

    ``cfg_reporter`` selects the logger via ``task_name`` and identifies the
    W&B run via ``project``, ``team`` (passed as ``entity``) and ``name``.
    ``cfg`` is the experiment configuration. Its ``dataset``, ``optimizer``,
    ``lossfn``, ``model`` and ``exp`` sections are flattened into the run
    config.
    """

    def __init__(
        self, cfg_reporter: dict, cfg: dict,
    ):
        # NOTE(review): not populated anywhere in this class; kept in case
        # external code reads it.
        self.out = {}

        self.logger = self._make_logger(cfg_reporter, cfg)
        run_config = self._build_run_config(cfg, self._git_sha())
        self.wandb = wandb.init(
            # set the wandb project where this run will be logged
            project=cfg_reporter["project"],
            entity=cfg_reporter["team"],
            # track hyperparameters and run metadata
            config=run_config,
            name=cfg_reporter["name"],
        )

    @staticmethod
    def _make_logger(cfg_reporter: dict, cfg: dict):
        """Instantiate the logger class selected by ``cfg_reporter["task_name"]``.

        Raises:
            ValueError: if the task name has no registered logger.
        """
        loggers = {
            "classify": Logger_classify,
            "detect": Logger_detect,
            "ood": Logger_ood,
            "misclassification": Logger_misclassification,
        }
        task_name = cfg_reporter["task_name"]
        try:
            logger_cls = loggers[task_name]
        except KeyError:
            raise ValueError(f"Reporter {task_name} not found") from None
        return logger_cls(cfg)

    @staticmethod
    def _git_sha() -> str:
        """Return the HEAD commit SHA of the enclosing git repo, or "" if unavailable."""
        try:
            repo = git.Repo(search_parent_directories=True)
            return repo.head.object.hexsha
        except Exception:
            # Best effort (e.g. running outside a git checkout). Unlike a bare
            # ``except:``, this does not swallow KeyboardInterrupt/SystemExit.
            return ""

    @staticmethod
    def _build_run_config(cfg: dict, sha: str) -> dict:
        """Flatten ``cfg`` into the prefixed key/value dict used as the W&B run config."""
        return {
            "dataset": cfg["dataset"]["name"],
            **{f"opt_{k}": v for k, v in cfg["optimizer"].items()},
            **{f"loss_fn_{k}": v for k, v in cfg["lossfn"].items()},
            **{f"model_{k}": v for k, v in cfg["model"].items()},
            **{f"exp_{k}": v for k, v in cfg["exp"].items()},
            **{f"dataset_{k}": v for k, v in cfg["dataset"].items()},
            "git_sha": sha,
            "checkpoint": [],
        }

    def __call__(self, metrics: dict, primary_metric: Optional[tuple[str, str]] = None):
        self.report(metrics, primary_metric)

    def report(self, metrics: dict, primary_metric: Optional[tuple[str, str]] = None):
        """Log ``metrics`` to the W&B run; ``primary_metric`` is currently unused."""
        self.wandb.log(metrics)

    def close(self):
        """Finish the underlying W&B run."""
        self.wandb.finish()

    def add_result(self, result):
        """Forward a single result to the task logger."""
        self.logger.add_result(result)

    def save_result(self, results, cfg):
        """Delegate result persistence to the task logger."""
        self.logger.save_result(results, cfg)

    def print_statistics(self, run=None):
        """Get statistics from the task logger and log the dict part to W&B.

        With ``run=None``, the logger's aggregate statistics are used and
        ``(report_dict, best_result)`` is returned. For a specific ``run``,
        only that run's dict is logged and ``None`` is returned.
        """
        # The human-readable string returned by the logger is not logged.
        # Previously it was passed as ``primary_metric``, which ``report``
        # ignores and which is typed as a tuple.
        if run is None:
            _, report_dict, best_result = self.logger.print_statistics()
            self.report(report_dict)
            return report_dict, best_result
        _, report_dict = self.logger.print_statistics(run)
        self.report(report_dict)