import time 
import numpy as np
import logging
import importlib 
import pickle
import torch
from Backbones.model_factory import get_classifier
from Backbones.neighborhood_handling import NeighborhoodProcessor, MixedNeighborhoodProcessor
from utils.dataset import IncrementalGraph
from utils.misc import shuffle_tensor
import importlib

def get_pipeline(args):
    """Return the pipeline class matching ``args.IL_stream`` and ``args.method``.

    Args:
        args: Configuration object exposing ``IL_stream`` ('classIL' or
            'timeIL') and ``method``. ``method`` is only read for a known
            stream type.

    Returns:
        The pipeline *class* (not an instance). The joint variant is chosen
        when ``args.method == 'joint'``, the online variant otherwise.

    Raises:
        ValueError: If ``args.IL_stream`` is neither 'classIL' nor 'timeIL'.
    """
    stream = args.IL_stream
    if stream not in ('classIL', 'timeIL'):
        raise ValueError(f"Unknown IL stream type: {args.IL_stream}")

    use_joint = args.method == 'joint'
    if stream == 'classIL':
        return ClassIncremental_jointPipeline if use_joint else ClassIncremental_OCGLPipeline
    return TimeIncremental_jointPipeline if use_joint else TimeIncremental_OCGLPipeline

class BaseOCGLPipeline:
    """Shared train/evaluate loop for the incremental graph-learning pipelines.

    Subclasses define the stream by overriding ``initialize_environment``
    (which has to set ``args.n_tasks`` and ``args.task_seq``, both of which
    are read here) and ``get_task_batches``.

    When ``valid`` is True the pipeline evaluates on nodes whose ``split``
    value is 1, evaluates only after the last task, never runs anytime
    evaluation and does not compute forgetting. Otherwise the end-of-task
    evaluation uses ``split == 2``.
    """

    def __init__(self, args, graph_dataset, valid=True):
        """Set up the dataset wrapper, the learner and the result buffers.

        Args:
            args: Run configuration. It is stored as-is and is expected to be
                extended in place by ``initialize_environment``.
            graph_dataset: Raw dataset forwarded to ``initialize_environment``.
            valid: Whether this is a validation run (see class docstring).
        """
        self.args = args
        self.valid = valid
        self.anytime_eval = not self.valid and self.args.anytime_eval
        self.dataset = self.initialize_environment(graph_dataset)
        self.ocglearner = self.initialize_model()
        # perf_matrix[i, j]: performance (in percent) on task j after task i.
        self.perf_matrix = np.zeros((args.n_tasks, args.n_tasks))
        self.all_perf_batch = []
        self.batch_count = 0
        # Make the timers exist even when process_task / evaluate_model are
        # called without going through run(); run() resets them anyway.
        self.initialize_run_timers()

    def initialize_environment(self, graph_dataset):
        """Build the dataset wrapper and the task sequence (subclass hook)."""
        raise NotImplementedError("This method should be overridden in subclasses")

    def initialize_model(self):
        """Instantiate the classifier, neighborhood processor and learner.

        The learner class is ``OCGLearner`` from the module
        ``Baselines.<args.method>``, imported dynamically.
        """
        classifier = get_classifier(self.args).cuda(self.args.gpu)
        if self.args.backbone == 'UMIXED':
            neighborhood_processor = MixedNeighborhoodProcessor(self.args)
        else:
            neighborhood_processor = NeighborhoodProcessor(self.args)
        baseline = importlib.import_module(f'Baselines.{self.args.method}')
        ocglearner = baseline.OCGLearner(classifier, neighborhood_processor, self.args)
        return ocglearner

    def get_task_batches(self, task_info):
        """Return the iterable of node-id mini-batches for a task (subclass hook)."""
        raise NotImplementedError("This method should be overridden in subclasses")

    def process_task(self, task, task_info):
        """Stream one task's mini-batches through the learner, then evaluate.

        Args:
            task: Zero-based index of the task.
            task_info: Task description understood by ``get_task_batches``.
        """
        # Stay None if the task yields no mini-batch at all, so the final
        # evaluation below can be skipped instead of hitting unbound locals.
        subgraph = None
        labels = None

        for mini_batch in self.get_task_batches(task_info):
            subgraph, train_ids_batch = self.dataset.update_subgraph(node_ids=mini_batch, device=f'cuda:{self.args.gpu}', task=task)
            labels = subgraph.dstdata['label'].view(-1)

            if len(train_ids_batch) == 0:
                continue
            start = time.time()
            self.ocglearner.observe(subgraph, train_ids_batch, labels[train_ids_batch])
            self.time_tr += time.time() - start
            self.empty_gpu_cache()

            self.batch_count += 1
            if self.anytime_eval and self.batch_count % self.args.anytime_eval_freq == 0:
                self.evaluate_model(subgraph, labels, task, end_batch=True)

        if subgraph is None:
            logging.warning(f'Task {task} produced no mini-batches; skipping evaluation.')
            return

        if not self.valid or task == self.args.n_tasks - 1:
            self.evaluate_model(subgraph, labels, task, end_batch=False)

    def empty_gpu_cache(self, threshold=0.66):
        """Release cached CUDA memory once the reserved share exceeds ``threshold``.

        For the 'Reddit' and 'RomanEmpire' datasets the threshold is always
        0.8, regardless of the argument.
        """
        if self.args.dataset in ['Reddit', 'RomanEmpire']:
            threshold = 0.8
        gpu_memory = torch.cuda.memory_reserved(self.args.gpu)
        total_memory = torch.cuda.get_device_properties(self.args.gpu).total_memory
        if gpu_memory > total_memory * threshold:
            torch.cuda.empty_cache()

    def evaluate_model(self, subgraph, labels, task, end_batch=False):
        """Score the learner on every task seen so far.

        Args:
            subgraph: Current graph; its ``ndata`` must hold 'split' and 'task'.
            labels: Flat label tensor indexed by node id.
            task: Index of the task currently being learned.
            end_batch: True for an anytime (mid-task) evaluation. The result
                is then appended to ``all_perf_batch`` instead of being
                written to ``perf_matrix`` and logged.
        """
        start = time.time()
        # Split 1 for validation runs and anytime snapshots, split 2 otherwise.
        splt = 1 if self.valid or end_batch else 2
        test_ids = self.get_test_ids(subgraph, splt)
        predictions = self.get_model_predictions(subgraph, test_ids)
        perf_per_task = np.zeros(self.args.n_tasks + 1) # +1 for the tasks column
        perf_per_task[-1] = task
        for t in range(task + 1):
            if self.args.dataset == 'Elliptic':
                perf = self.compute_task_f1_score(subgraph, labels, test_ids, predictions, t)
            else:
                perf = self.compute_task_accuracy(subgraph, labels, test_ids, predictions, t)
            perf_per_task[t] = perf
        self.time_te += time.time() - start
        if end_batch:
            self.all_perf_batch.append(perf_per_task)
        else:
            self.log_evaluation_results(perf_per_task, task)

    def get_test_ids(self, subgraph, splt):
        """Return, in random order, the ids of nodes whose 'split' equals ``splt``."""
        test_ids = torch.nonzero(subgraph.ndata['split'] == splt, as_tuple=True)[0]
        shuffled_indices = torch.randperm(test_ids.size(0))
        test_ids = test_ids[shuffled_indices]
        return test_ids

    def get_model_predictions(self, subgraph, test_ids):
        """Predict labels for ``test_ids`` in chunks of ``args.batch_size``.

        Returns:
            A 1-D tensor aligned with ``test_ids``. It keeps the dtype that
            ``predict_labels`` produces. For empty ``test_ids`` an empty
            tensor on ``subgraph.device`` is returned and the learner is not
            called.
        """
        if test_ids.numel() == 0:
            return torch.tensor([], device=subgraph.device)

        # Collect the chunks and concatenate once: growing a tensor with
        # torch.cat inside the loop re-copies all earlier predictions each time.
        chunks = []
        for batch in torch.split(test_ids, self.args.batch_size):
            chunks.append(self.ocglearner.predict_labels(subgraph, batch))
            self.empty_gpu_cache()
        return torch.cat(chunks, dim=0)

    @staticmethod
    def _select_task(subgraph, labels, test_ids, predictions, t):
        """Return (labels, predictions) restricted to test nodes of task ``t``.

        Returns None when no test node belongs to task ``t``.
        """
        task_mask = subgraph.ndata['task'][test_ids] == t
        test_ids_t = test_ids[task_mask]
        if len(test_ids_t) == 0:
            return None
        return labels[test_ids_t], predictions[task_mask]

    def compute_task_accuracy(self, subgraph, labels, test_ids, predictions, t):
        """Return the accuracy on task ``t``'s test nodes, or 0 if there are none."""
        selected = self._select_task(subgraph, labels, test_ids, predictions, t)
        if selected is None:
            return 0
        labels_t, task_predictions = selected
        acc = torch.sum(task_predictions == labels_t).item() / len(labels_t)
        return acc

    def compute_task_f1_score(self, subgraph, labels, test_ids, predictions, t):
        """Return the F1 score of class 1 on task ``t``'s test nodes.

        Only labels/predictions equal to 0 or 1 are counted. Returns 0 when
        the task has no test nodes or when precision + recall is zero.
        """
        selected = self._select_task(subgraph, labels, test_ids, predictions, t)
        if selected is None:
            return 0
        labels_t, task_predictions = selected
        tp = torch.sum((task_predictions == 1) & (labels_t == 1)).item()
        fp = torch.sum((task_predictions == 1) & (labels_t == 0)).item()
        fn = torch.sum((task_predictions == 0) & (labels_t == 1)).item()
        precision = tp / (tp + fp) if tp + fp > 0 else 0
        recall = tp / (tp + fn) if tp + fn > 0 else 0
        f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0
        return f1_score

    def log_evaluation_results(self, perf_per_task, task):
        """Store row ``task`` of ``perf_matrix`` (in percent) and log a summary line."""
        for t in range(task + 1):
            self.perf_matrix[task, t] = round(perf_per_task[t] * 100, 2)
        results = "".join([f"T{t:02d} {self.perf_matrix[task, t]:.2f}|" for t in range(task + 1)])
        perf_mean = round(np.mean(perf_per_task[:task + 1]) * 100, 2)
        results += f" perf_mean: {perf_mean} "
        results += f"time: {round(time.time() - self.start, 2)}s "
        results += f"tr: {round(self.time_tr, 2)}s "
        results += f"te: {round(self.time_te, 2)}s"
        logging.info(results)

    def save_task_model(self, task):
        """Pickle the learner to ``<args.current_model_save_path>_<task>.pkl``."""
        save_model_path = f'{self.args.current_model_save_path}_{task}.pkl'
        with open(save_model_path, 'wb') as f:
            pickle.dump(self.ocglearner, f)

    def calculate_metrics(self):
        """Return ``(average performance, average forgetting)``.

        Forgetting is None for validation runs; the metrics are logged only
        when forgetting was computed.
        """
        avg_perf = self.compute_average_performance()
        avg_forg = self.compute_average_forgetting() if not self.valid else None
        if avg_forg is not None:
            self.log_metrics_and_time(avg_perf, avg_forg)
        return avg_perf, avg_forg

    def compute_average_performance(self):
        """Return the mean of the last row of ``perf_matrix``, rounded to 2 decimals."""
        return round(np.mean(self.perf_matrix[self.args.n_tasks - 1, :]), 2)

    def compute_average_forgetting(self):
        """Return the mean of (final - just-learned) performance over all but the last task.

        With a single task there is nothing that could have been forgotten,
        so 0.0 is returned instead of the NaN that a mean over an empty list
        would give.
        """
        last = self.args.n_tasks - 1
        forgetting = [
            self.perf_matrix[last][t] - self.perf_matrix[t][t]
            for t in range(last)
        ]
        if not forgetting:
            return 0.0
        return round(np.mean(forgetting), 2)

    def log_metrics_and_time(self, avg_perf, avg_forg):
        """Log average performance, average forgetting and total wall time."""
        logging.info(f'AP: {avg_perf}')
        logging.info(f'AF: {avg_forg}')
        logging.info(f'Total time: {round(time.time() - self.start, 2)}s\n')

    def run(self):
        """Process every task in ``args.task_seq`` and return the final metrics.

        Returns:
            ``(avg_perf, avg_forg, perf_matrix)``; ``avg_forg`` is None for
            validation runs.
        """
        self.initialize_run_timers()
        for task, task_info in enumerate(self.args.task_seq):
            self.process_task(task, task_info)
            if self.args.save_models:
                self.save_task_model(task)
        avg_perf, avg_forg = self.calculate_metrics()
        if self.anytime_eval:
            self.save_anytime_evaluation_results()
        return avg_perf, avg_forg, self.perf_matrix

    def initialize_run_timers(self):
        """Reset the wall-clock start and the train/test time accumulators."""
        self.start = time.time()
        self.time_tr = 0
        self.time_te = 0

    def save_anytime_evaluation_results(self):
        """Save the stacked anytime-evaluation rows as a ``.npy`` file.

        The target path is ``args.current_model_save_path`` with "models"
        replaced by "batch_perf". Nothing is written when no anytime
        evaluation was recorded, since ``np.vstack`` rejects an empty list.
        """
        if not self.all_perf_batch:
            logging.warning('No anytime evaluation results to save.')
            return
        batch_perf = np.vstack(self.all_perf_batch)
        np.save(f'{self.args.current_model_save_path.replace("models", "batch_perf")}.npy', batch_perf)

class ClassIncremental_OCGLPipeline(BaseOCGLPipeline):
    """Online pipeline whose tasks are groups of consecutive class ids."""

    def __init__(self, args, graph_dataset, valid=True):
        super().__init__(args, graph_dataset, valid)

    def initialize_environment(self, graph_dataset):
        """Wrap the dataset and derive the class-based task sequence.

        Writes ``d_data``, ``n_cls``, ``task_seq`` and ``n_tasks`` onto
        ``self.args``. Each task is a list of ``n_cls_per_task`` consecutive
        class ids. Validation runs keep only the first
        ``n_validation_tasks`` tasks.

        Returns:
            The ``IncrementalGraph`` wrapping ``graph_dataset``.
        """
        torch.cuda.set_device(self.args.gpu)
        dataset = IncrementalGraph(graph_dataset)
        self.args.d_data, self.args.n_cls = dataset.d_data, dataset.n_cls

        # NOTE(review): the range stops at n_cls - 1, so a trailing single
        # class does not open a task of its own - confirm this is intended.
        class_groups = []
        for first_cls in range(0, self.args.n_cls - 1, self.args.n_cls_per_task):
            class_groups.append(list(range(first_cls, first_cls + self.args.n_cls_per_task)))

        if self.valid:
            class_groups = class_groups[:self.args.n_validation_tasks]
        self.args.task_seq = class_groups
        self.args.n_tasks = len(self.args.task_seq)
        return dataset

    def get_task_batches(self, task_cls):
        """Split the nodes labelled with any class in ``task_cls`` into mini-batches.

        The node ids are passed through ``shuffle_tensor`` with a fixed seed
        (42) and then cut into chunks of ``args.n_nodes_per_batch``.
        """
        node_labels = self.dataset.graph.ndata['label']
        belongs_to_task = torch.isin(node_labels, torch.tensor(task_cls))
        task_node_ids = torch.where(belongs_to_task)[0]
        task_node_ids = shuffle_tensor(task_node_ids, random_seed=42)
        return torch.split(task_node_ids, self.args.n_nodes_per_batch)

class ClassIncremental_jointPipeline(ClassIncremental_OCGLPipeline):
    """Joint-training variant of the class-incremental pipeline.

    Anytime evaluation is always disabled for this pipeline.
    """

    def __init__(self, args, graph_dataset, valid=True):
        super().__init__(args, graph_dataset, valid)
        self.anytime_eval = False

    def process_task(self, task, task_cls):
        """Reset the network, add the task's nodes and retrain on the split-0 nodes.

        The subgraph is updated for every task, but training and evaluation
        only happen for non-validation runs or for the last task.
        """
        self.ocglearner.net.reset_params()

        # All nodes of the full graph labelled with one of this task's classes.
        all_labels = self.dataset.graph.ndata['label']
        task_node_ids = torch.where(torch.isin(all_labels, torch.tensor(task_cls)))[0]
        subgraph, _ = self.dataset.update_subgraph(node_ids=task_node_ids, device=f'cuda:{self.args.gpu}', task=task)
        train_ids = torch.where(subgraph.ndata['split'] == 0)[0]
        labels = subgraph.dstdata['label'].view(-1)

        # Validation runs only train and evaluate once, after the final task.
        if self.valid and task != self.args.n_tasks - 1:
            return

        tic = time.time()
        self.ocglearner.observe(subgraph, train_ids, labels[train_ids])
        self.time_tr += time.time() - tic
        self.evaluate_model(subgraph, labels, task, end_batch=False)

class TimeIncremental_OCGLPipeline(BaseOCGLPipeline):
    """Online pipeline whose tasks are consecutive chunks of time-ordered nodes."""

    def __init__(self, args, graph_dataset, valid=True):
        super().__init__(args, graph_dataset, valid)

    def initialize_environment(self, graph_dataset):
        """Wrap the dataset and derive the time-based task sequence.

        Writes ``d_data``, ``n_cls``, ``task_seq`` and ``n_tasks`` onto
        ``self.args``. Node indices are ordered by the 'time' node feature
        and cut into ``n_time_tasks`` chunks. Validation runs keep only the
        first ``n_validation_tasks`` chunks.

        Returns:
            The ``IncrementalGraph`` wrapping ``graph_dataset``.
        """
        torch.cuda.set_device(self.args.gpu)
        dataset = IncrementalGraph(graph_dataset)
        self.args.d_data, self.args.n_cls = dataset.d_data, dataset.n_cls

        # Shuffle with a fixed seed first, then sort stably, so the order
        # among equal timestamps comes from the seeded shuffle.
        # Presumably shuffle_tensor returns (shuffled values, permutation
        # indices) when return_indices=True - verify against utils.misc.
        node_times = dataset.graph.ndata['time']
        shuffled_times, permutation = shuffle_tensor(node_times, return_indices=True, random_seed=42)
        time_order = torch.argsort(shuffled_times, stable=True)
        ordered_node_ids = permutation[time_order]
        task_chunks = torch.chunk(ordered_node_ids, self.args.n_time_tasks)

        if self.valid:
            self.args.task_seq = task_chunks[:self.args.n_validation_tasks]
        else:
            self.args.task_seq = task_chunks
        self.args.n_tasks = len(self.args.task_seq)
        return dataset

    def get_task_batches(self, task_node_ids):
        """Split a task's node ids into chunks of ``args.n_nodes_per_batch``."""
        return torch.split(task_node_ids, self.args.n_nodes_per_batch)

class TimeIncremental_jointPipeline(TimeIncremental_OCGLPipeline):
    """Joint-training variant of the time-incremental pipeline.

    Anytime evaluation is always disabled for this pipeline.
    """

    def __init__(self, args, graph_dataset, valid=True):
        super().__init__(args, graph_dataset, valid)
        self.anytime_eval = False

    def process_task(self, task, task_node_ids):
        """Reset the network, add the task's nodes and retrain on the split-0 nodes.

        The subgraph is updated for every task, but training and evaluation
        only happen for non-validation runs or for the last task.
        """
        self.ocglearner.net.reset_params()

        target_device = f'cuda:{self.args.gpu}'
        subgraph, _ = self.dataset.update_subgraph(node_ids=task_node_ids, device=target_device, task=task)
        train_ids = torch.where(subgraph.ndata['split'] == 0)[0]
        labels = subgraph.dstdata['label'].view(-1)

        # Validation runs only train and evaluate once, after the final task.
        if self.valid and task != self.args.n_tasks - 1:
            return

        tic = time.time()
        self.ocglearner.observe(subgraph, train_ids, labels[train_ids])
        self.time_tr += time.time() - tic
        self.evaluate_model(subgraph, labels, task, end_batch=False)
