import logging
import time

import numpy as np
import torch
from scipy.stats import stats
from sklearn.metrics import mean_absolute_error, mean_squared_error, root_mean_squared_error, r2_score
from torch_geometric.graphgym import get_current_gpu_usage
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.logger import infer_task, Logger
from torch_geometric.graphgym.utils.io import dict_to_json, dict_to_tb

from graphgym.metric_ocb import is_valid_DAG, is_valid_Circuit, is_graph_valid, is_valid_circuit_graph, our_is_valid_circuit
import warnings

from graphgym.utils import remove_edges_with_attribute_value, torch_geometric_to_igraph
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

class CustomLogger(Logger):
    """GraphGym ``Logger`` extended with task-specific epoch metrics.

    On top of the running loss / learning-rate bookkeeping, this logger can
    report regression scores, validity ratios of sampled (generated) graphs,
    and binary classification scores for pin prediction.

    Attributes such as ``name``, ``task_type``, ``out_dir``, ``tb_writer`` and
    the ``_true`` / ``_pred`` / ``_loss`` accumulators are not defined here;
    they are presumably initialised by the parent ``Logger`` (its
    ``__init__`` / ``reset``) -- verify against the installed GraphGym.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent ``Logger`` constructor."""
        super().__init__(*args, **kwargs)
        # Whether to run comparison tests of alternative score implementations.
        self.test_scores = False

    # basic properties
    def basic(self):
        """Return the statistics reported for every task.

        Returns:
            dict with ``loss`` (accumulated loss divided by the accumulated
            sample count), ``lr``, ``params``, ``time_iter`` and, when the
            reported usage is positive, ``gpu_memory``.
        """
        # NOTE: the loss is averaged over `_size_current`, which
        # `update_stats` advances by `cfg.train.batch_size` on every call.
        precision = max(8, cfg.round)
        basic_stats = {
            'loss': round(self._loss / self._size_current, precision),
            'lr': round(self._lr, precision),
            'params': self._params,
            'time_iter': round(self.time_iter(), cfg.round),
        }
        gpu_memory = get_current_gpu_usage()
        if gpu_memory > 0:
            basic_stats['gpu_memory'] = gpu_memory
        return basic_stats

    def regression(self):
        """Return regression scores over all accumulated targets/predictions.

        Returns:
            dict with ``mae``, ``r2``, ``spearmanr``, ``mse`` and ``rmse``,
            each rounded to ``cfg.round`` digits.
        """
        def reformat(x):
            return round(float(x), cfg.round)

        true, pred = torch.cat(self._true), torch.cat(self._pred)
        return {
            'mae': reformat(mean_absolute_error(true, pred)),
            'r2': reformat(r2_score(true, pred, multioutput='uniform_average')),
            'spearmanr': reformat(eval_spearmanr(true.numpy(),
                                                 pred.numpy())['spearmanr']),
            'mse': reformat(mean_squared_error(true, pred)),
            'rmse': reformat(root_mean_squared_error(true, pred)),
        }

    def generative(self):
        """Return validity ratios of graphs sampled from the predictions.

        Calls :meth:`process_prediction` first, which replaces ``self._pred``
        with a flat list of igraph objects. Directed datasets report
        ``valid_DAG`` / ``valid_circuit``; undirected ones report
        ``valid_graph`` / ``valid_circuit``.
        """
        def reformat(x):
            return round(float(x), cfg.round)

        self.process_prediction()

        if cfg.dataset.directed:
            return {
                "valid_DAG": reformat(self.valid_DAG(self._pred)),
                "valid_circuit": reformat(self.valid_circuit(self._pred)),
            }
        return {
            "valid_graph": reformat(self.valid_graph(self._pred)),
            "valid_circuit": reformat(self.valid_circuit_graph(self._pred)),
        }

    def pin_prediction(self):
        """Compute pin prediction metrics: accuracy, precision, recall, F1.

        For every accumulated batch, each column of
        ``batch.learnable_edge_index`` is located in ``batch.edge_index`` and
        the matching row of ``batch.edge_attr`` is used as the logits for that
        edge. The argmax class is compared against ``batch.labels``.

        Returns:
            dict with ``accuracy``, ``precision``, ``recall`` and ``f1``.
            All four are ``0.0`` when no predictions were collected.

        Raises:
            IndexError: if a learnable edge does not occur in
                ``batch.edge_index``.
        """
        from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

        def reformat(x):
            return round(float(x), cfg.round)

        # Collect all predictions and true labels
        all_predictions = []
        all_true_labels = []

        for batch in self._pred:
            # One row per edge so a single endpoint pair can be broadcast
            # against it (assumes edge_index is laid out as (2, num_edges)).
            edge_pairs = batch.edge_index.T
            learnable = batch.learnable_edge_index

            # Extract predictions and labels using the same logic as pin_loss
            pred_logits_idx = []
            for col in range(learnable.size(1)):
                # The column is already 1-D; transposing a 1-D tensor is a
                # deprecated no-op in PyTorch, so it is indexed directly.
                edge = learnable[:, col]
                found = edge_pairs == edge
                found = torch.nonzero(found[:, 0] & found[:, 1])
                # The first matching position is used if the edge is duplicated.
                pred_logits_idx.append(found[0].item())

            pred_logits = batch.edge_attr[pred_logits_idx]

            # Get predicted classes (argmax of logits)
            predicted_classes = torch.argmax(pred_logits, dim=1)

            # Add to collections
            all_predictions.extend(predicted_classes.cpu().numpy())
            all_true_labels.extend(batch.labels.long().cpu().numpy())

        if not all_predictions:
            return {
                'accuracy': 0.0,
                'precision': 0.0,
                'recall': 0.0,
                'f1': 0.0
            }

        # Calculate metrics
        accuracy = accuracy_score(all_true_labels, all_predictions)
        precision = precision_score(all_true_labels, all_predictions, average='binary', zero_division=0)
        recall = recall_score(all_true_labels, all_predictions, average='binary', zero_division=0)
        f1 = f1_score(all_true_labels, all_predictions, average='binary', zero_division=0)

        return {
            'accuracy': reformat(accuracy),
            'precision': reformat(precision),
            'recall': reformat(recall),
            'f1': reformat(f1)
        }

    def custom(self):
        """Return custom statistics accumulated during training.

        ``update_stats`` stores every extra keyword value weighted by the
        batch size, so each total is divided by the accumulated sample count
        here to obtain a per-sample average.
        """
        precision = max(8, cfg.round)
        # Add any custom statistics that were passed via kwargs to update_stats
        return {
            key: round(float(val) / self._size_current, precision)
            for key, val in self._custom_stats.items()
        }

    def process_prediction(self):
        """Sample discrete graphs from the accumulated predictions.

        For every graph of every batch in ``self._pred``, node and edge
        attributes are drawn from the categorical distributions given by the
        softmax over the last dimension of ``graph.x`` and
        ``graph.edge_attr``. The graph is then passed through
        ``remove_edges_with_attribute_value(graph, 0)`` (presumably dropping
        edges whose sampled attribute is 0 -- confirm in ``graphgym.utils``)
        and converted with ``torch_geometric_to_igraph``.

        ``self._pred`` is overwritten with the resulting flat list, so this
        must not be called twice on the same accumulated predictions.
        """
        processed_pred = []
        for batch in self._pred:
            for idx in range(batch.num_graphs):  # Number of graphs in the batch
                graph = batch.get_example(idx)   # Get the idx-th graph
                node_probs = F.softmax(graph.x, dim=-1)
                edge_probs = F.softmax(graph.edge_attr, dim=-1)
                graph.x = Categorical(node_probs).sample()
                graph.edge_attr = Categorical(edge_probs).sample()
                graph = remove_edges_with_attribute_value(graph, 0)
                processed_pred.append(
                    torch_geometric_to_igraph(graph, cfg.dataset.directed))
        self._pred = processed_pred

    @staticmethod
    def _valid_fraction(prediction, is_valid):
        """Return the share of graphs in ``prediction`` accepted by ``is_valid``.

        An empty ``prediction`` yields ``0.0`` instead of dividing by zero,
        matching the empty-input result of :meth:`pin_prediction`.
        """
        total = len(prediction)
        if total == 0:
            return 0.0
        return sum(1 for graph in prediction if is_valid(graph)) / total

    def valid_graph(self, prediction):
        """Fraction of ``prediction`` accepted by ``is_graph_valid``."""
        return self._valid_fraction(
            prediction,
            lambda graph: is_graph_valid(graph, cfg.dataset.subcircuit))

    def valid_circuit_graph(self, prediction):
        """Fraction of ``prediction`` accepted by ``our_is_valid_circuit``."""
        return self._valid_fraction(prediction, our_is_valid_circuit)

    def valid_DAG(self, prediction):
        """Fraction of ``prediction`` accepted by ``is_valid_DAG``."""
        return self._valid_fraction(
            prediction,
            lambda graph: is_valid_DAG(graph, cfg.dataset.subcircuit))

    def valid_circuit(self, prediction):
        """Fraction of ``prediction`` accepted by ``is_valid_Circuit``."""
        return self._valid_fraction(
            prediction,
            lambda graph: is_valid_Circuit(graph, cfg.dataset.subcircuit))

    def update_stats(self, true, pred, loss, lr=0, time_used=0, params=None,
                     **kwargs):
        """Accumulate the results of one iteration.

        Args:
            true: Ground-truth object for the batch; appended to ``_true``.
            pred: Prediction object for the batch; appended to ``_pred``.
            loss: Batch loss; accumulated weighted by the batch size.
            lr: Current learning rate.
            time_used: Time spent on this iteration.
            params: Parameter count reported by :meth:`basic`.
            **kwargs: Extra per-batch statistics; each is accumulated
                weighted by the batch size and averaged in :meth:`custom`.
        """
        # NOTE: the configured batch size is used as the weight, not the
        # actual number of samples in this batch.
        batch_size = cfg.train.batch_size

        self._iter += 1
        self._true.append(true)
        self._pred.append(pred)
        self._size_current += batch_size
        self._loss += loss * batch_size
        self._lr = lr
        self._params = params
        self._time_used += time_used
        self._time_total += time_used
        for key, val in kwargs.items():
            if key not in self._custom_stats:
                self._custom_stats[key] = val * batch_size
            else:
                self._custom_stats[key] += val * batch_size

    def write_epoch(self, cur_epoch, custom_metrics=False):
        """Compute, log and persist the statistics of the finished epoch.

        Args:
            cur_epoch: Index of the epoch being reported.
            custom_metrics: When true, the (more expensive) generative or
                pin-prediction metrics are computed; regression metrics are
                always computed.

        Returns:
            The merged statistics dict that was logged.

        Raises:
            ValueError: if ``self.task_type`` is not one of ``generative``,
                ``regression`` or ``pin_prediction``.
        """
        start_time = time.perf_counter()
        basic_stats = self.basic()

        if self.task_type == 'generative':
            task_stats = self.generative() if custom_metrics else {}
        elif self.task_type == 'regression':
            task_stats = self.regression()
        elif self.task_type == 'pin_prediction':
            task_stats = self.pin_prediction() if custom_metrics else {}
        else:
            raise ValueError('Task has to be generative, regression, or pin_prediction')

        epoch_stats = {'epoch': cur_epoch,
                       'time_epoch': round(self._time_used, cfg.round)}
        # Evaluate the estimate once and derive both representations from it.
        eta = self.eta(cur_epoch)
        eta_stats = {'eta': round(eta, cfg.round),
                     'eta_hours': round(eta / 3600, cfg.round)}
        custom_stats = self.custom()

        # The ETA is only reported by the training logger.
        epoch_summary = {
            **epoch_stats,
            **(eta_stats if self.name == 'train' else {}),
            **basic_stats,
            **task_stats,
            **custom_stats,
        }

        # print
        logging.info('{}: {}'.format(self.name, epoch_summary))
        # json
        dict_to_json(epoch_summary, '{}/stats.json'.format(self.out_dir))
        # tensorboard
        if cfg.tensorboard_each_run:
            dict_to_tb(epoch_summary, self.tb_writer, cur_epoch)
        self.reset()
        if cur_epoch < 3:
            logging.info(f"...computing epoch stats took: "
                         f"{time.perf_counter() - start_time:.2f}s")
        return epoch_summary

    def save_extra_info(self, stats: dict, cur_epoch: int):
        """Write ``stats`` to TensorBoard (if enabled) and return it unchanged."""
        if cfg.tensorboard_each_run:
            dict_to_tb(stats, self.tb_writer, cur_epoch)

        return stats

def create_logger():
    """
    Build one ``CustomLogger`` per dataset split.

    The task type is ``'pin_prediction'`` when ``cfg.train.mode`` exists and
    equals ``'train_pin_prediction'``; otherwise it is whatever
    ``infer_task()`` returns. Splits are named ``train``, ``val``, ``test``
    in that order, and ``cfg.share.num_splits`` loggers are created.

    Returns: List of logger objects

    """
    split_names = ['train', 'val', 'test']

    # Determine task type
    if getattr(cfg.train, 'mode', None) == 'train_pin_prediction':
        task_type = 'pin_prediction'
    else:
        task_type = infer_task()

    return [
        CustomLogger(name=split_names[split], task_type=task_type)
        for split in range(cfg.share.num_splits)
    ]


def eval_spearmanr(y_true, y_pred):
    """Compute Spearman's rho, averaged across tasks.

    Args:
        y_true: Array of targets, either 1-D (single task) or 2-D with one
            column per task.
        y_pred: Array of predictions with the same layout as ``y_true``.

    Returns:
        ``{'spearmanr': rho}`` where ``rho`` is the mean of the per-task
        correlation coefficients. In the 2-D case, rows whose target is NaN
        are excluded from that task's correlation; the 1-D case passes the
        arrays through unfiltered.
    """
    if y_true.ndim == 1:
        rhos = [stats.spearmanr(y_true, y_pred)[0]]
    else:
        rhos = []
        for task in range(y_true.shape[1]):
            # ignore nan values
            labeled = ~np.isnan(y_true[:, task])
            rho = stats.spearmanr(y_true[labeled, task],
                                  y_pred[labeled, task])[0]
            rhos.append(rho)

    return {'spearmanr': sum(rhos) / len(rhos)}
