import logging
import os
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Tuple, Type, Union

import torch

from scripts.default_configs import *
from gnnboundary import *

@dataclass
class SamplingTrainerConfig:
    """
    Configuration consumed by ``SamplingTrainer``.

    Attributes:
        cls_pair: Class pairs to handle. ``SamplingTrainer`` builds one sampler and one trainer per pair and
            uses each pair as a dictionary key.
        ckpt_path: Path handed to ``torch.load`` to restore the discriminator's state dict.
        model_architecture: Discriminator class. It is instantiated with the keyword arguments
            ``node_features``, ``num_classes``, ``hidden_channels`` and ``num_layers``.
        hidden_channels: Forwarded to ``model_architecture`` as ``hidden_channels``.
        num_layers: Forwarded to ``model_architecture`` as ``num_layers``.
        max_nodes: Forwarded to ``GraphSampler`` as ``max_nodes``.
        temperature: Forwarded to ``GraphSampler`` as ``temperature``.
        learn_node_feat: Forwarded to ``GraphSampler`` as ``learn_node_feat``.
        optimizer: Optimizer class (not an instance). It is called as ``optimizer(params, lr=lr)`` once per
            class pair.
        lr: Learning rate passed to ``optimizer``.
    """
    cls_pair: List[Tuple[int, int]]
    ckpt_path: str
    model_architecture: Type[torch.nn.Module] = GCNClassifier
    hidden_channels: int = 32
    num_layers: int = 3
    max_nodes: int = 25
    temperature: float = 0.2
    learn_node_feat: bool = True
    optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD
    lr: float = 1.0




class SamplingTrainer:
    """
    Trainer class for the sampling based training approach. It builds all the necessary components from scratch
    (one ``GraphSampler`` and one ``Trainer`` per configured class pair, plus a single discriminator restored from
    ``config.ckpt_path``) and should only be used if the user wants to train a model in a very simple way.

    Args:
        dataset: The dataset to be used for training
        config: A dataclass with the configuration of the trainer
        strategy: The strategy to be used for the training. Currently only "dynamic_boundary" is implemented;
            any other value raises ``NotImplementedError`` while the criteria are built
        budget_penalty: The budget penalty to be used for the training. When the argument is omitted a fresh
            ``BudgetPenalty(budget=10, order=2, beta=1)`` is created for this instance; an explicit ``None``
            is forwarded to the trainers unchanged
        logger: A logger object to be used for logging. When omitted, ``logging.basicConfig`` is called (console
            plus ``training_sampler.log`` opened in write mode) and a module logger is used

    Attributes set by the constructor:
        sampler: dict mapping class pair -> ``GraphSampler``
        trainer: dict mapping class pair -> ``Trainer``
        weighted_criterion: dict mapping class pair -> criterion
        model: the discriminator shared by every trainer
        mean_embeds: per-class mean embeddings computed from ``dataset`` with ``model``
    """

    # Sentinel that distinguishes "budget_penalty omitted" from an explicit ``None``.
    _UNSET = object()

    def __init__(self, dataset,
                 config: "SamplingTrainerConfig",
                 strategy: Literal["dynamic_boundary", "cross_entropy", "interpreter"] = "dynamic_boundary",
                 budget_penalty: "BudgetPenalty" = _UNSET,
                 logger: Optional[logging.Logger] = None
                 ):
        self.dataset = dataset

        if logger is None:
            logging.basicConfig(
                level=logging.INFO,
                format="%(asctime)s - %(levelname)s - %(message)s",
                handlers=[
                    logging.StreamHandler(),
                    logging.FileHandler("training_sampler.log", mode="w")
                ]
            )
            logger = logging.getLogger(__name__)

        self.logger = logger

        # Build the default penalty per instance instead of once at definition time,
        # so separate trainers never share the same default object.
        if budget_penalty is self._UNSET:
            budget_penalty = BudgetPenalty(budget=10, order=2, beta=1)

        self._budget_penalty = budget_penalty
        self._config = config
        self.trainer = None
        self.model = None
        self.sampler = {}
        self.mean_embeds = None
        self.weighted_criterion = {}

        self._build_sampler()
        self._build_model()

        self.strategy = strategy

        # The criteria need the model (for the mean embeddings); the trainers need the criteria.
        self._build_criterion()
        self._build_trainer()

    def _build_criterion(self):
        """
        Compute the per-class mean embeddings and build one criterion per configured class pair.

        Raises:
            NotImplementedError: if ``self.strategy`` is not "dynamic_boundary" and at least one class pair
                is configured
        """
        self.mean_embeds = self._mean_embeddings()
        for cls_pair in self._config.cls_pair:
            if self.strategy == "dynamic_boundary":
                self.weighted_criterion[cls_pair] = get_default_criteria_dynamic_boundary(
                    self.dataset.name,
                    self.mean_embeds,
                    cls_pair[0],
                    cls_pair[1],
                )
            else:
                raise NotImplementedError("Other strategies not implemented yet")

    def _build_trainer(self, budget_penalty=None):
        """
        (Re)create one ``Trainer`` per configured class pair and store them in ``self.trainer``.

        Every trainer gets a new optimizer (``config.optimizer`` with ``config.lr``) over its sampler's
        parameters and an ``ExponentialLR`` scheduler with ``gamma=1``.

        Args:
            budget_penalty: Penalty to hand to the trainers; ``None`` means "use the stored penalty"
        """
        self.logger.info("Building trainer")
        penalty = self._budget_penalty if budget_penalty is None else budget_penalty

        trainers = {}
        for cls_pair in self._config.cls_pair:
            sampler = self.sampler[cls_pair]
            optimizer = self._config.optimizer(sampler.parameters(), lr=self._config.lr)
            trainers[cls_pair] = Trainer(
                sampler=sampler,
                discriminator=self.model,
                criterion=self.weighted_criterion[cls_pair],
                optimizer=optimizer,
                scheduler=torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=1),
                dataset=self.dataset,
                budget_penalty=penalty,
            )

        self.trainer = trainers

    def _build_model(self):
        """Instantiate ``config.model_architecture`` and load its weights from ``config.ckpt_path``."""
        self.logger.info("Building model")
        self.model = self._config.model_architecture(node_features=len(self.dataset.NODE_CLS),
                                                     num_classes=len(self.dataset.GRAPH_CLS),
                                                     hidden_channels=self._config.hidden_channels,
                                                     num_layers=self._config.num_layers)
        # NOTE(review): torch.load unpickles the file; only point ckpt_path at trusted checkpoints.
        self.model.load_state_dict(torch.load(self._config.ckpt_path))

    def _build_sampler(self):
        """Create a new ``GraphSampler`` for every configured class pair (existing entries are overwritten)."""
        self.logger.info("Building sampler")

        for cls_pair in self._config.cls_pair:
            self.sampler[cls_pair] = GraphSampler(
                max_nodes=self._config.max_nodes,
                temperature=self._config.temperature,
                num_node_cls=len(self.dataset.NODE_CLS),
                learn_node_feat=self._config.learn_node_feat
            )

    def inject_criterion(self, criterion):
        """
        Replace the criteria and rebuild the trainers so they pick them up.

        Args:
            criterion: Mapping indexed by class pair; it must contain an entry for every configured pair
        """
        self.weighted_criterion = criterion
        self._build_trainer()

    @property
    def config(self):
        """The active ``SamplingTrainerConfig``."""
        return self._config

    @config.setter
    def config(self, new_config):
        # Swapping the config rebuilds the model, the samplers and the trainers.
        # NOTE(review): the criteria (and mean embeddings) are NOT rebuilt here, so a config that introduces
        # new class pairs needs matching criteria to be injected first - confirm this is intended.
        self._config = new_config
        self._build_model()
        self._build_sampler()
        self._build_trainer()

    @property
    def criterion(self):
        """The criteria currently in use, indexed by class pair."""
        return self.weighted_criterion

    @criterion.setter
    def criterion(self, new_criterion):
        self.inject_criterion(new_criterion)

    @property
    def budget_penalty(self):
        """The budget penalty handed to newly built trainers."""
        # Read the backing attribute; returning ``self.budget_penalty`` here would recurse forever.
        return self._budget_penalty

    @budget_penalty.setter
    def budget_penalty(self, new_budget_penalty):
        # Store first, then rebuild: ``_build_trainer`` falls back to the stored value.
        self._budget_penalty = new_budget_penalty
        self._build_trainer()

    def _mean_embeddings(self) -> List[torch.Tensor]:
        """Return, for each class split of the dataset, the mean over dim 0 of the model's "embeds" output."""
        dataset_list_gt = self.dataset.split_by_class()
        return [d.model_transform(self.model, key="embeds").mean(dim=0) for d in dataset_list_gt]

    def train(self, cls_pair: Tuple[int, int],
              iterations: int, target_probs: Optional[Dict[int, Tuple[float, float]]] = None,
              target_size: int = 30, w_budget_init: int = 1, w_budget_inc: float = 1.1,
              w_budget_dec: float = 0.95, k_samples: int = 16):
        """
        Train the sampler of one class pair.

        Args:
            cls_pair: The class pair whose sampler/trainer should be trained
            iterations: Forwarded to ``Trainer.train``
            target_probs: Probability ranges. Only the first two values are used, in insertion order, and they
                are re-keyed to ``cls_pair[0]`` and ``cls_pair[1]``. Defaults to ``(0.35, 0.65)`` for both
            target_size, w_budget_init, w_budget_inc, w_budget_dec, k_samples: Forwarded to ``Trainer.train``

        Returns:
            Whatever ``Trainer.train`` returns (``train_evaluate_all`` treats a truthy value as success)
        """
        self.model.eval()
        self.sampler[cls_pair].train()

        if target_probs is None:
            target_probs = {cls_pair[0]: (0.35, 0.65), cls_pair[1]: (0.35, 0.65)}
        else:
            ranges = list(target_probs.values())
            target_probs = {cls_pair[0]: ranges[0], cls_pair[1]: ranges[1]}

        return self.trainer[cls_pair].train(iterations=iterations, target_probs=target_probs,
                                            target_size=target_size, w_budget_init=w_budget_init,
                                            w_budget_inc=w_budget_inc, w_budget_dec=w_budget_dec,
                                            k_samples=k_samples)

    def __call__(self, *args, **kwargs):
        """Shorthand for :meth:`train`; ``self.trainer`` is a dict and cannot be called directly."""
        return self.train(*args, **kwargs)

    def evaluate(self, cls_pair: Tuple[int, int], threshold: float = 0.5):
        """
        Sample a graph from the sampler of ``cls_pair`` and run the trainer's prediction on it.

        Sampling is attempted up to 10 times because it may raise (the log message attributes this to an
        empty graph being generated).

        Args:
            cls_pair: The class pair to evaluate
            threshold: Forwarded to the sampler's ``sample``

        Returns:
            Tuple ``(logits, probs)`` of numpy arrays, each averaged over dim 0 of the prediction

        Raises:
            RuntimeError: if all 10 sampling attempts fail
        """
        trainer = self.trainer[cls_pair]
        max_attempts = 10
        last_error = None

        for attempt in range(1, max_attempts + 1):
            try:
                graph = trainer.sampler.sample(threshold=threshold)
            except Exception as err:  # the concrete exception type raised by the sampler is not known here
                last_error = err
                self.logger.debug("Sampling attempt %d/%d failed: %s", attempt, max_attempts, err)
                continue
            break
        else:
            # Without a graph there is nothing to predict on; fail loudly instead of hitting an unbound name.
            self.logger.error("An empty graph was generated 10 times in a row!")
            raise RuntimeError(
                f"Could not sample a graph for class pair {cls_pair} after {max_attempts} attempts"
            ) from last_error

        pred = trainer.predict(graph)
        logits = pred["logits"].mean(dim=0)
        probs = pred["probs"].mean(dim=0)

        return logits.cpu().numpy(), probs.cpu().numpy()

    def train_evaluate_all(self, iterations: int,
                           target_probs: Optional[Dict[int, Tuple[float, float]]] = None,
                           target_size: int = 30, w_budget_init: int = 1,
                           w_budget_inc: float = 1.1,
                           w_budget_dec: float = 0.95,
                           k_samples: int = 16,
                           eval_threshold: float = 0.5,
                           max_repeats: int = 1000,):
        """
        Train and evaluate every configured class pair, restarting from a fresh sampler after each failure.

        Args:
            iterations, target_probs, target_size, w_budget_init, w_budget_inc, w_budget_dec, k_samples:
                Forwarded to :meth:`train`
            eval_threshold: Threshold used for both :meth:`evaluate` and ``Trainer.evaluate``
            max_repeats: Maximum number of training attempts for EACH class pair

        Returns:
            Tuple ``(eval_results, graphs)`` of dicts indexed by class pair: the result of :meth:`evaluate`
            and the result of ``Trainer.evaluate(threshold=eval_threshold, show=False)``
        """
        eval_results = {}
        graphs = {}
        for cls_pair in self._config.cls_pair:
            # Each pair gets its own retry budget, and only its own sampler/trainer is reset,
            # so pairs that were already trained keep their state.
            for _ in range(max_repeats):
                self._reset_sampler_and_trainer(cls_pairs=[cls_pair])
                if self.train(cls_pair, iterations=iterations, target_probs=target_probs,
                              target_size=target_size, w_budget_init=w_budget_init,
                              w_budget_inc=w_budget_inc, w_budget_dec=w_budget_dec,
                              k_samples=k_samples):
                    self.logger.info("Training successful for cls pair: %s", cls_pair)
                    break
                self.logger.info("Retrying")
            else:
                self.logger.warning("Training did not succeed for cls pair %s within %d attempts; "
                                    "evaluating the last state anyway", cls_pair, max_repeats)

            eval_results[cls_pair] = self.evaluate(cls_pair, eval_threshold)
            graphs[cls_pair] = self.trainer[cls_pair].evaluate(threshold=eval_threshold, show=False)

        return eval_results, graphs

    def _reset_sampler_and_trainer(self, cls_pairs=None):
        """
        Replace the sampler, optimizer, scheduler and trainer of the given class pairs with fresh ones.

        The new objects are configured from the ones they replace: the sampler from the old sampler's
        ``n``/``tau``/``k``/``xi`` attributes, the optimizer from the old optimizer's class and ``defaults``,
        the scheduler from the old scheduler's class and public instance attributes. Criterion and budget
        penalty are taken over from the old trainer.

        Args:
            cls_pairs: Class pairs to reset; ``None`` resets every configured pair
        """
        pairs = self._config.cls_pair if cls_pairs is None else cls_pairs

        for cls_pair in pairs:
            old_sampler = self.sampler[cls_pair]
            old_trainer = self.trainer[cls_pair]
            # assumes Trainer exposes its optimizers as a list - the first entry is the one rebuilt here
            old_optimizer = old_trainer.optimizer[0]
            old_scheduler = old_trainer.scheduler

            new_sampler = GraphSampler(
                max_nodes=old_sampler.n,
                temperature=old_sampler.tau,
                num_node_cls=old_sampler.k,
                learn_node_feat=old_sampler.xi is not None
            )

            new_optimizer = type(old_optimizer)(new_sampler.parameters(), **old_optimizer.defaults)

            # Seed 'initial_lr' so a scheduler can be constructed from carried-over state below.
            for new_group, old_group in zip(new_optimizer.param_groups, old_optimizer.param_groups):
                new_group['initial_lr'] = old_group.get('initial_lr', old_group['lr'])

            # NOTE(review): for torch schedulers the public attributes appear to include ``last_epoch``, so the
            # schedule position is carried over rather than restarted - confirm this is intended.
            scheduler_kwargs = {name: value for name, value in vars(old_scheduler).items()
                                if name not in ('optimizer', 'base_lrs') and not name.startswith('_')}
            new_scheduler = type(old_scheduler)(new_optimizer, **scheduler_kwargs)

            self.sampler[cls_pair] = new_sampler
            self.trainer[cls_pair] = Trainer(sampler=new_sampler, discriminator=self.model,
                                             criterion=old_trainer.criterion, optimizer=[new_optimizer],
                                             scheduler=new_scheduler, dataset=self.dataset,
                                             budget_penalty=old_trainer.budget_penalty)


def _run_enzymes_demo():
    """Train and evaluate the class pairs (0, 3) and (0, 4) on ENZYMES for 2000 iterations and print the result."""
    import warnings
    warnings.filterwarnings("ignore")

    dataset = ENZYMESDataset(seed=12345)
    model_kwargs = get_model_kwargs(dataset, dataset.name)

    # The checkpoint path is relative to the current working directory.
    config = SamplingTrainerConfig(
        cls_pair=[(0, 3), (0, 4)],
        ckpt_path=f"../ckpts/{dataset.name.lower()}.pt",
        model_architecture=GCNClassifier,
        hidden_channels=model_kwargs["hidden_channels"],
        num_layers=model_kwargs["num_layers"],
        lr=0.5,
    )

    trainer = SamplingTrainer(dataset, config)
    results = trainer.train_evaluate_all(2000)
    print(results)


if __name__ == "__main__":
    _run_enzymes_demo()
