from torch.backends.opt_einsum import strategy
from tqdm import tqdm
from typing import Union, Tuple, Literal
import os
import torch
from gnnboundary import *
from gnn_xai_common import GraphSampler, Trainer
from scripts.experiments_logging import initialize_logger
import warnings

# Module-level logger (presumably writing to "graph_saver.log" — confirm in
# initialize_logger). GraphRetrainer logs through it; GraphGenerator methods take
# their own optional `logger` argument, which shadows this name.
logger = initialize_logger("graph_saver.log")
# NOTE(review): this silences *every* warning process-wide as an import side
# effect (presumably to hide deprecation noise from re-created torch schedulers).
# Consider narrowing it with a category/module filter.
warnings.filterwarnings("ignore")

class GraphGenerator:
    """Generate graphs by repeatedly training freshly initialised samplers.

    Every attempt rebuilds the sampler, optimizer, scheduler and trainer from the
    current ones (see ``_reset_sampler_and_trainer``), then trains. On success it
    samples one discrete graph from the trained sampler and optionally saves the
    sampler to disk.
    """

    def __init__(self, sampler, dataset, trainer, model,):
        """Store the template sampler/trainer and the dataset/model used for rebuilds.

        Args:
            sampler: Initial graph sampler. Its hyper-parameters seed every rebuilt
                sampler.
            dataset: Dataset handed to every rebuilt ``Trainer``.
            trainer: Initial ``Trainer``. Its optimizer, scheduler, criterion and
                budget penalty are copied into every rebuilt trainer.
            model: Discriminator handed to every rebuilt ``Trainer``.
        """
        self.sampler = sampler
        self.dataset = dataset
        self.trainer = trainer
        self.model = model
        # Root directory for saved samplers; a per-strategy subdirectory is appended.
        self.save_dir = "graphs"

    @staticmethod
    def _emit(logger, message):
        """Send ``message`` to ``logger.info`` when a logger is given, else ``print`` it."""
        if logger is not None:
            logger.info(message)
        else:
            print(message)

    def _get_save_dir(self, strategy, cls):
        """Create (if needed) and return the directory ``<save_dir>/<strategy>``.

        ``cls`` is accepted for interface compatibility but is unused. The dataset
        and class subdirectories are left to ``Trainer.save_sampler``.
        """
        save_dir = os.path.join(self.save_dir, strategy)
        os.makedirs(save_dir, exist_ok=True)
        return save_dir

    def __call__(self,
                 num_runs=1,
                 num_graphs=100,
                 logger=None,
                 show_progress=True,
                 save_graphs: bool = True,
                 add_non_successful: bool = False,
                 cls: Union[int, Tuple[int,int]] = None,
                 strategy: Literal['dynamic_boundary', 'cross_entropy', 'interpreter'] = None,
                 **kwargs):
        """Try to generate ``num_graphs`` graphs, allowing ``num_runs`` attempts each.

        Args:
            num_runs: Maximum number of fresh training attempts per graph.
            num_graphs: Number of graphs to attempt.
            logger: Optional logger; messages are printed when it is ``None``.
            show_progress: Show a tqdm bar over the attempts for each graph.
            save_graphs: Save the sampler of every successful attempt under
                ``<save_dir>/<strategy>``.
            add_non_successful: Also keep a graph sampled after each failed attempt.
            cls: Class index passed to ``Trainer.save_sampler``.
            strategy: Name of the subdirectory that saved samplers go into.
            **kwargs: Forwarded to ``Trainer.train``.

        Returns:
            ``(success_rate, conv_iter, graphs_all_runs)``:
            - ``success_rate`` is the fraction of graphs with a successful attempt
              (``0`` if none succeeded).
            - ``conv_iter`` holds ``trainer.iteration`` for each success.
            - ``graphs_all_runs`` holds one list of sampled graphs per attempted
              graph.

        Raises:
            ValueError: If ``save_graphs`` is true and ``cls`` or ``strategy`` is None.
        """
        if save_graphs and (cls is None or strategy is None):
            raise ValueError("cls must be provided if save_graphs is True and strategy must be provided if save_graphs is True")

        save_dir = self._get_save_dir(strategy, cls) if save_graphs else None

        if logger is not None:
            logger.info("Generating boundary graphs....")

        graphs_all_runs = []
        success_count = 0
        conv_iter = []
        for g in range(num_graphs):
            graphs = []
            # A fresh bar per graph: a tqdm instance closes (and thereafter disables
            # itself) once its first iteration ends, so a shared one would only ever
            # show progress for the first graph.
            for i in tqdm(range(num_runs), disable=not show_progress):
                self._reset_sampler_and_trainer()
                if self.trainer.train(**kwargs):
                    self._emit(logger, f"Succesful generation of graph {g} at iteration {i}")
                    graphs.append(self.sampler(k=1, mode='discrete', expected=True))
                    success_count += 1
                    conv_iter.append(self.trainer.iteration)
                    if save_graphs:
                        self.trainer.save_sampler(cls, root=save_dir)
                    break

                self._emit(logger, f"Failed to converge at {i}-th try. Retrying...")
                if add_non_successful:
                    graphs.append(self.sampler(k=1, mode='discrete', expected=True))

            graphs_all_runs.append(graphs)

        if success_count:
            # Mean trainer iteration at which successful runs converged.
            mean_conv_iter = sum(conv_iter) / len(conv_iter)
            success_rate = success_count / num_graphs
            self._emit(logger, f"Average convergence rate: {mean_conv_iter}")
            self._emit(logger, f"Success rate: {success_rate}")
        else:
            success_rate = 0
            self._emit(logger, "No successful generations")

        return success_rate, conv_iter, graphs_all_runs

    def success_counter(self, num_runs:int, cls: Tuple[int, int], save_dir: str = None, logger=None, **kwargs):
        """Count how many of ``num_runs`` fresh training runs succeed.

        Args:
            num_runs: Number of independent fresh training runs.
            cls: Class index passed to ``Trainer.save_sampler`` when saving.
            save_dir: If given, the sampler of every successful run is saved there.
            logger: Optional logger; messages are printed when it is ``None``.
            **kwargs: Forwarded to ``Trainer.train``.

        Returns:
            ``(success_count, iteration_count)``, where ``iteration_count`` holds
            ``trainer.iteration`` for every successful run.
        """
        success_count = 0
        iteration_count = []
        for i in tqdm(range(num_runs)):
            self._reset_sampler_and_trainer()
            if not self.trainer.train(**kwargs):
                continue
            success_count += 1
            iteration_count.append(self.trainer.iteration)
            self._emit(logger, f"Successful generation of graph at iteration {i}")
            if save_dir is not None:
                self.trainer.save_sampler(cls, root=save_dir)
        return success_count, iteration_count

    def _reset_sampler_and_trainer(self):
        """Replace ``self.sampler`` and ``self.trainer`` with freshly initialised copies.

        The new sampler reuses the old one's size, temperature, node-class count
        and whether node features are learned. A new optimizer of the same class,
        with the same default hyper-parameters, is bound to the new sampler's
        parameters. The scheduler (if any) is re-created on that optimizer with a
        fresh schedule.
        """
        self.sampler = GraphSampler(
            max_nodes=self.sampler.n,
            temperature=self.sampler.tau,
            num_node_cls=self.sampler.k,
            learn_node_feat=self.sampler.xi is not None,
        )

        # NOTE(review): only the first optimizer is carried over; any further
        # entries in ``self.trainer.optimizer`` are dropped.
        old_optimizer = self.trainer.optimizer[0]
        new_optimizer = type(old_optimizer)(self.sampler.parameters(), **old_optimizer.defaults)
        # Keep the old schedule's base LR. The scheduler only set-defaults
        # 'initial_lr', so a value placed here wins.
        for group, old_group in zip(new_optimizer.param_groups, old_optimizer.param_groups):
            group['initial_lr'] = old_group.get('initial_lr', old_group['lr'])

        old_scheduler = self.trainer.scheduler
        if old_scheduler is None:
            # Assumes Trainer tolerates scheduler=None — TODO confirm.
            new_scheduler = None
        else:
            # Re-create the scheduler from its public attributes. ``last_epoch`` is
            # dropped so the new scheduler starts from scratch rather than resuming
            # the previous run's step count.
            # NOTE(review): assumes every remaining public attribute is a
            # constructor keyword (true for e.g. ExponentialLR/StepLR, not for
            # ReduceLROnPlateau or LambdaLR) — confirm for the schedulers in use.
            scheduler_hyperparams = {
                k: v for k, v in old_scheduler.__dict__.items()
                if k not in ('optimizer', 'base_lrs', 'last_epoch') and not k.startswith('_')
            }
            new_scheduler = type(old_scheduler)(new_optimizer, **scheduler_hyperparams)

        self.trainer = Trainer(sampler=self.sampler, discriminator=self.model,
                               criterion=self.trainer.criterion, optimizer=[new_optimizer],
                               scheduler=new_scheduler, dataset=self.dataset,
                               budget_penalty=self.trainer.budget_penalty)


class GraphRetrainer:
    """Reload previously saved samplers from a directory and train them further.

    On construction the saved state dicts are loaded, and every file in the
    directory is moved to ``<sampler_dir>_backup``. Calling the instance then
    retrains each loaded sampler and saves the successful ones back into
    ``sampler_dir``.
    """

    def __init__(self, sampler_dir: str, dataset, trainer, model, num_graphs: int = 100,
                 max_nodes: int = 25, temperature: float = 0.2, ):
        """Load up to ``num_graphs`` samplers from ``sampler_dir`` and back the files up.

        Args:
            sampler_dir: Directory holding saved sampler state dicts.
            dataset: Dataset; ``len(dataset.NODE_CLS)`` sizes the rebuilt samplers.
            trainer: Template ``Trainer``. Its optimizer, scheduler, criterion and
                budget penalty are copied for every loaded sampler.
            model: Discriminator handed to the per-sampler trainers.
            num_graphs: Maximum number of samplers to load (clamped to the file count).
            max_nodes: ``max_nodes`` for the rebuilt samplers (presumably must match
                the saved samplers for ``load_state_dict`` to succeed — confirm).
            temperature: Sampling temperature for the rebuilt samplers.
        """
        self.sampler_dir = sampler_dir
        self.dataset = dataset
        self.trainer = trainer
        self.model = model
        self.max_nodes = max_nodes
        self.temperature = temperature
        self.samplers = []

        # Sort so that the subset loaded when num_graphs < len(files) is
        # deterministic (os.listdir order is arbitrary).
        files = sorted(os.listdir(self.sampler_dir))
        # An empty directory has nothing to retrain.
        # NOTE(review): a directory with exactly one file is also skipped (kept from
        # the original); the intent behind that is unclear — confirm.
        self.not_run = len(files) <= 1
        if self.not_run:
            return

        if num_graphs > len(files):
            num_graphs = len(files)
            logger.warning(
                "Number of graphs to load is greater than the number of files in the directory. Loading %s graphs",
                num_graphs)

        for name in files[:max(num_graphs, 0)]:
            sampler = GraphSampler(max_nodes=self.max_nodes, temperature=self.temperature,
                                   num_node_cls=len(self.dataset.NODE_CLS), learn_node_feat=True)
            # map_location="cpu": the fresh sampler lives on the CPU, and this lets
            # CUDA-saved checkpoints load on CPU-only hosts.
            # NOTE: torch.load unpickles its input — only load trusted files.
            state = torch.load(os.path.join(self.sampler_dir, name), map_location="cpu")
            sampler.load_state_dict(state)
            self.samplers.append(sampler)

        # Move every file (loaded or not) out of sampler_dir so retrained samplers
        # are saved into a clean directory. normpath strips a trailing separator,
        # which would otherwise place the backup *inside* sampler_dir.
        backup_dir = os.path.normpath(self.sampler_dir) + "_backup"
        os.makedirs(backup_dir, exist_ok=True)
        for name in files:
            os.rename(os.path.join(self.sampler_dir, name), os.path.join(backup_dir, name))

    def _trainer_for(self, sampler):
        """Return a new ``Trainer`` whose optimizer and scheduler drive ``sampler``.

        Assigning ``trainer.sampler`` alone is not enough: the template trainer's
        optimizer still references the parameters of the sampler it was built
        with, so the loaded sampler would never be updated. This mirrors
        ``GraphGenerator._reset_sampler_and_trainer``.
        """
        # NOTE(review): only the first optimizer of the template is carried over.
        old_optimizer = self.trainer.optimizer[0]
        new_optimizer = type(old_optimizer)(sampler.parameters(), **old_optimizer.defaults)
        for group, old_group in zip(new_optimizer.param_groups, old_optimizer.param_groups):
            group['initial_lr'] = old_group.get('initial_lr', old_group['lr'])

        old_scheduler = self.trainer.scheduler
        if old_scheduler is None:
            # Assumes Trainer tolerates scheduler=None — TODO confirm.
            new_scheduler = None
        else:
            # Fresh schedule: ``last_epoch`` is not carried over.
            # NOTE(review): assumes every remaining public attribute is a
            # constructor keyword of the scheduler class — confirm.
            scheduler_hyperparams = {
                k: v for k, v in old_scheduler.__dict__.items()
                if k not in ('optimizer', 'base_lrs', 'last_epoch') and not k.startswith('_')
            }
            new_scheduler = type(old_scheduler)(new_optimizer, **scheduler_hyperparams)

        return Trainer(sampler=sampler, discriminator=self.model,
                       criterion=self.trainer.criterion, optimizer=[new_optimizer],
                       scheduler=new_scheduler, dataset=self.dataset,
                       budget_penalty=self.trainer.budget_penalty)

    def __call__(self, num_runs: int = 1, **kwargs):
        """Retrain every loaded sampler, giving each up to ``num_runs`` attempts.

        A sampler is saved into ``sampler_dir`` on its first successful attempt.

        Args:
            num_runs: Maximum number of ``train`` calls per sampler.
            **kwargs: Forwarded to ``Trainer.train``. Must contain ``target_probs``
                (a mapping); its keys are used as ``cls_idx`` when saving.
        """
        if self.not_run:
            logger.info("No graphs to retrain")
            return
        logger.info("Retraining boundary graphs....")
        for i, sampler in enumerate(self.samplers):
            trainer = self._trainer_for(sampler)
            for _ in range(num_runs):
                success = trainer.train(**kwargs)
                logger.info(f"Retrained graph {i}")

                if success:
                    logger.info(f"Successful retraining of graph {i}")
                    trainer.save_sampler(root=self.sampler_dir, cls_idx=tuple(kwargs["target_probs"].keys()))
                    logger.info(f"Saved retrained graph {i} into {self.sampler_dir}")
                    break

                logger.info(f"Failed to retrain graph {i}")
                logger.info("Retrying...")

        logger.info("Retraining complete")



