from torch.utils.data import DataLoader

from oucl.scenarios.samplers import DefaultSampler
from oucl.scenarios.transforms import DefaultTransform
from oucl.scenarios.collate_fns import DefaultCollate

from sklearn.metrics.cluster import contingency_matrix
from sklearn.metrics import normalized_mutual_info_score, adjusted_mutual_info_score, accuracy_score
import numpy as np

import wandb

class IncEvaluator:
    """Evaluate a model cumulatively on every task seen so far.

    Each call to :meth:`evaluate` looks up the current task id in
    ``self.tasks`` (indexed by the evaluation step). The model is then
    supervised on the union of supervision indices of tasks ``0..task`` and
    evaluated on the union of evaluation indices of the same tasks. When
    ``config.validate`` is set, validation indices are used instead.
    Classification and clustering results are logged to wandb and appended to
    ``class_perf`` / ``clust_perf``.
    """

    def __init__(self, dataset, super_inds, eval_inds, val_inds, config):

        # Per-task index collections, indexed by task id (0..num_tasks-1).
        self.super_inds = super_inds
        self.eval_inds = eval_inds
        self.val_inds = val_inds
        self.dataset = dataset

        self.validate = config.validate
        self.n_workers = config.scenario.num_workers

        # History of result dicts, one entry per evaluate() call.
        self.class_perf = []
        self.clust_perf = []

        self.collate = DefaultCollate(DefaultTransform(config.dataset.img_size,
                                                       config.dataset.mean,
                                                       config.dataset.std))

        self.num_tasks = len(config.scenario.eval_task_classes)

        # Maps evaluation step -> task id. np.repeat repeats each task id
        # ``eval_freq`` times (assuming ``eval_freq`` is a scalar), so the
        # evaluator moves on to the next task after that many evaluations.
        # NOTE(review): evaluate() raises IndexError once ``self.step``
        # reaches ``len(self.tasks)``.
        self.tasks = np.repeat(np.arange(self.num_tasks), config.scenario.eval_freq)

        self.step = 0

    @staticmethod
    def _cumulative_inds(per_task_inds, task):
        """Return ``per_task_inds[0]`` .. ``per_task_inds[task]`` concatenated into one list."""
        return [i for t in range(task + 1) for i in per_task_inds[t]]

    def _get_super_loader(self, task):
        """Build the shuffled supervision loader over tasks ``0..task``.

        Side effect: sets ``self.n_classes`` to the number of distinct labels
        among the selected samples.
        """
        inds = self._cumulative_inds(self.super_inds, task)
        # assumes dataset.labels supports list fancy-indexing (e.g. a numpy array)
        self.n_classes = len(np.unique(self.dataset.labels[inds]))
        return DataLoader(self.dataset,
                          batch_size=256,
                          collate_fn=self.collate,
                          num_workers=self.n_workers,
                          sampler=DefaultSampler(inds, shuffle=True))

    def _get_eval_loader(self, task):
        """Build the evaluation-set loader over tasks ``0..task``."""
        inds = self._cumulative_inds(self.eval_inds, task)
        return DataLoader(self.dataset,
                          batch_size=128,
                          collate_fn=self.collate,
                          num_workers=self.n_workers,
                          sampler=DefaultSampler(inds))

    def _get_val_loader(self, task):
        """Build the validation-set loader over tasks ``0..task``."""
        inds = self._cumulative_inds(self.val_inds, task)
        return DataLoader(self.dataset,
                          batch_size=128,
                          collate_fn=self.collate,
                          num_workers=self.n_workers,
                          sampler=DefaultSampler(inds))

    def evaluate(self, model):
        """Supervise ``model`` and score it on all tasks seen so far.

        ``model`` must provide ``supervise``, ``classify`` and ``cluster``.
        ``supervise_linear`` is optional. ``classify`` and ``cluster`` must
        each return ``(y_preds, y_true)`` where ``y_preds`` maps a model name
        to predictions.

        Returns:
            ``(class_results, clust_results)`` as produced by
            :meth:`calculate_classification_accuracy` and
            :meth:`calculate_clustering_metrics`.
        """

        task = self.tasks[self.step]

        super_loader = self._get_super_loader(task)
        if self.validate:
            print('Evaluating with Validation Set')
            eval_loader = self._get_val_loader(task)
        else:
            print('Evaluating with Evaluation Set')
            eval_loader = self._get_eval_loader(task)

        print(self.n_classes)
        print('Supervising KNN')
        model.supervise(super_loader)
        if hasattr(model, 'supervise_linear'):
            print('Supervising Linear')
            model.supervise_linear(super_loader, self.step)

        y_preds, y_true = model.classify(eval_loader)
        class_results = self.calculate_classification_accuracy(y_preds,
                                                               y_true)

        y_preds, y_true = model.cluster(eval_loader)
        clust_results = self.calculate_clustering_metrics(y_preds,
                                                          y_true)

        self.step += 1

        return class_results, clust_results

    @staticmethod
    def _per_class_accuracy(y_true, preds):
        """Return ``{int(label): accuracy}`` computed over the samples whose true label is ``label``.

        Inputs are coerced with ``np.asarray`` so plain lists work too.
        """
        y_true = np.asarray(y_true)
        preds = np.asarray(preds)
        return {int(label): float(np.mean(preds[y_true == label] == label))
                for label in np.unique(y_true)}

    def calculate_classification_accuracy(self, y_preds, y_true):
        """Compute overall and per-class accuracy for each model in ``y_preds``.

        Returns a dict with ``'eval_step'`` plus, per model name,
        ``{'overall_accuracy': float, 'per_class_accuracy': {label: float}}``.
        The dict is also logged to wandb and appended to ``self.class_perf``.
        """
        y_true = np.asarray(y_true)
        results = {'eval_step': self.step}
        for model_name, preds in y_preds.items():
            preds = np.asarray(preds)
            results[str(model_name)] = {
                'overall_accuracy': accuracy_score(y_true, preds),
                'per_class_accuracy': self._per_class_accuracy(y_true, preds),
            }

        wandb.log(results)
        self.class_perf.append(results)

        return results

    def calculate_clustering_metrics(self, y_preds, y_true):
        """Compute purity, NMI and AMI for each clustering in ``y_preds``.

        Returns a dict with ``'eval_step'`` plus, per model name,
        ``{'overall_purity', 'overall_nmi', 'overall_ami'}``. The dict is
        appended to ``self.clust_perf`` and logged to wandb.
        """
        def purity_score(y_true, y_pred):
            # Contingency matrix: rows are true classes, columns are clusters.
            # Purity sums each cluster's majority-class count over all samples.
            conting_matrix = contingency_matrix(y_true, y_pred)
            return np.sum(np.amax(conting_matrix, axis=0)) / np.sum(conting_matrix)

        results = {'eval_step': self.step}
        for model_name, preds in y_preds.items():
            results[str(model_name)] = {
                'overall_purity': purity_score(y_true, preds),
                'overall_nmi': normalized_mutual_info_score(y_true, preds),
                'overall_ami': adjusted_mutual_info_score(y_true, preds),
            }

        self.clust_perf.append(results)
        wandb.log(results)

        return results


class TaskEvaluator:
    """Evaluate a model separately on each task.

    For every task ``t`` in ``range(num_tasks)``, the model is supervised on
    that task's supervision indices only. It is then scored on that task's
    evaluation indices, or validation indices when ``config.validate`` is set.
    Each :meth:`evaluate` call produces one wandb log entry holding the merged
    classification/clustering metrics under ``task_{t}`` keys.
    """

    def __init__(self, dataset, super_inds, eval_inds, val_inds, config):

        # Per-task index collections, indexed by task id (0..num_tasks-1).
        self.super_inds = super_inds
        self.eval_inds = eval_inds
        self.val_inds = val_inds
        self.dataset = dataset

        self.validate = config.validate
        self.n_workers = config.scenario.eval_workers

        # History: one list of per-task result dicts per evaluate() call.
        self.class_perf = []
        self.clust_perf = []

        self.collate = DefaultCollate(DefaultTransform(config.dataset.img_size,
                                                       config.dataset.mean,
                                                       config.dataset.std))

        self.num_tasks = len(config.scenario.eval_task_classes)

        self.step = 0

    def _get_super_loader(self):
        """Return one shuffled supervision loader per task."""
        return [DataLoader(self.dataset,
                           batch_size=256,
                           collate_fn=self.collate,
                           num_workers=self.n_workers,
                           sampler=DefaultSampler(self.super_inds[t], shuffle=True))
                for t in range(self.num_tasks)]

    def _get_eval_loader(self):
        """Return one evaluation-set loader per task."""
        return [DataLoader(self.dataset,
                           batch_size=128,
                           collate_fn=self.collate,
                           num_workers=self.n_workers,
                           sampler=DefaultSampler(self.eval_inds[t]))
                for t in range(self.num_tasks)]

    def _get_val_loader(self):
        """Return one validation-set loader per task."""
        return [DataLoader(self.dataset,
                           batch_size=128,
                           collate_fn=self.collate,
                           num_workers=self.n_workers,
                           sampler=DefaultSampler(self.val_inds[t]))
                for t in range(self.num_tasks)]

    @staticmethod
    def _merge_results(class_res, clust_res):
        """Return a new dict combining classification and clustering results.

        Neither input is modified. If a clustering key already exists among
        the classification keys, it is stored as ``'<key>_purity'`` instead
        of overwriting the accuracy value.
        """
        merged = dict(class_res)
        for key, value in clust_res.items():
            if key in merged:
                key = f'{key}_purity'
            merged[key] = value
        return merged

    def evaluate(self, model):
        """Supervise and score ``model`` on each task independently.

        Returns:
            ``(class_results, clust_results)``: lists with one dict per task,
            as produced by :meth:`calculate_classification_accuracy` and
            :meth:`calculate_clustering_metrics`.
        """
        super_loaders = self._get_super_loader()
        if self.validate:
            print('Evaluating with Validation Set')
            eval_loaders = self._get_val_loader()
        else:
            print('Evaluating with Evaluation Set')
            eval_loaders = self._get_eval_loader()

        class_results = []
        clust_results = []
        results = {'eval_step': self.step}
        for t, (super_loader, eval_loader) in enumerate(zip(super_loaders, eval_loaders)):
            print(f'Evaluating Task: {t}')
            model.supervise(super_loader)
            if hasattr(model, 'supervise_linear'):
                model.supervise_linear(super_loader, self.step)

            y_preds, y_true = model.classify(eval_loader)
            class_res = self.calculate_classification_accuracy(y_preds, y_true)

            y_preds, y_true = model.cluster(eval_loader)
            clust_res = self.calculate_clustering_metrics(y_preds, y_true)

            class_results.append(class_res)
            clust_results.append(clust_res)
            results[f'task_{t}'] = self._merge_results(class_res, clust_res)

        wandb.log(results)
        self.class_perf.append(class_results)
        self.clust_perf.append(clust_results)

        self.step += 1
        return class_results, clust_results

    def calculate_classification_accuracy(self, y_preds, y_true):
        """Return ``{str(model_name): overall accuracy}`` for each model in ``y_preds``."""
        return {str(model_name): accuracy_score(y_true, preds)
                for model_name, preds in y_preds.items()}

    def calculate_clustering_metrics(self, y_preds, y_true):
        """Return ``{str(model_name): purity}`` for each clustering in ``y_preds``."""
        def purity_score(y_true, y_pred):
            # Contingency matrix: rows are true classes, columns are clusters.
            # Purity sums each cluster's majority-class count over all samples.
            conting_matrix = contingency_matrix(y_true, y_pred)
            return np.sum(np.amax(conting_matrix, axis=0)) / np.sum(conting_matrix)

        return {str(model_name): purity_score(y_true, preds)
                for model_name, preds in y_preds.items()}
    

class TaskEvaluator2:
    """Cumulative evaluator that logs results keyed by task.

    Each call to :meth:`evaluate` looks up the current task id in
    ``self.tasks`` (indexed by the evaluation step). It supervises the model
    on the union of supervision indices of tasks ``0..task`` and scores it on
    the union of evaluation indices of the same tasks, or validation indices
    when ``config.validate`` is set. The merged classification/clustering
    metrics are logged to wandb under a ``task_{task}`` key.
    """

    def __init__(self, dataset, super_inds, eval_inds, val_inds, config):

        # Per-task index collections, indexed by task id (0..num_tasks-1).
        self.super_inds = super_inds
        self.eval_inds = eval_inds
        self.val_inds = val_inds
        self.dataset = dataset

        self.validate = config.validate
        self.n_workers = config.scenario.num_workers

        # History of result dicts, one entry per evaluate() call.
        self.class_perf = []
        self.clust_perf = []

        self.collate = DefaultCollate(DefaultTransform(config.dataset.img_size,
                                                       config.dataset.mean,
                                                       config.dataset.std))

        self.num_tasks = len(config.scenario.eval_task_classes)
        # Maps evaluation step -> task id (each id repeated ``eval_freq``
        # times, assuming ``eval_freq`` is a scalar).
        # NOTE(review): evaluate() raises IndexError once ``self.step``
        # reaches ``len(self.tasks)``.
        self.tasks = np.repeat(np.arange(self.num_tasks), config.scenario.eval_freq)

        self.step = 0

    @staticmethod
    def _cumulative_inds(per_task_inds, task):
        """Return ``per_task_inds[0]`` .. ``per_task_inds[task]`` concatenated into one list."""
        return [i for t in range(task + 1) for i in per_task_inds[t]]

    def _get_super_loader(self, task):
        """Build the shuffled supervision loader over tasks ``0..task``.

        Side effect: sets ``self.n_classes`` to the number of distinct labels
        among the selected samples.
        """
        inds = self._cumulative_inds(self.super_inds, task)
        # assumes dataset.labels supports list fancy-indexing (e.g. a numpy array)
        self.n_classes = len(np.unique(self.dataset.labels[inds]))
        return DataLoader(self.dataset,
                          batch_size=256,
                          collate_fn=self.collate,
                          num_workers=self.n_workers,
                          # DataLoader raises ValueError for persistent
                          # workers when num_workers == 0.
                          persistent_workers=self.n_workers > 0,
                          pin_memory=True,
                          sampler=DefaultSampler(inds, shuffle=True))

    def _get_eval_loader(self, task):
        """Build the evaluation-set loader over tasks ``0..task``."""
        inds = self._cumulative_inds(self.eval_inds, task)
        return DataLoader(self.dataset,
                          batch_size=128,
                          collate_fn=self.collate,
                          num_workers=self.n_workers,
                          sampler=DefaultSampler(inds))

    def _get_val_loader(self, task):
        """Build the validation-set loader over tasks ``0..task``."""
        inds = self._cumulative_inds(self.val_inds, task)
        return DataLoader(self.dataset,
                          batch_size=128,
                          collate_fn=self.collate,
                          num_workers=self.n_workers,
                          sampler=DefaultSampler(inds))

    @staticmethod
    def _merge_results(class_res, clust_res):
        """Return a new dict combining classification and clustering results.

        Neither input is modified. If a clustering key already exists among
        the classification keys, it is stored as ``'<key>_purity'`` instead
        of overwriting the accuracy value.
        """
        merged = dict(class_res)
        for key, value in clust_res.items():
            if key in merged:
                key = f'{key}_purity'
            merged[key] = value
        return merged

    def evaluate(self, model):
        """Supervise and score ``model`` on all tasks seen so far.

        The ``_get_*_loader`` helpers each return a single DataLoader covering
        tasks ``0..task``, so there is exactly one supervise/score pass.

        Returns:
            ``(class_results, clust_results)``: single-element lists holding
            the classification and clustering result dicts.
        """
        task = self.tasks[self.step]
        super_loader = self._get_super_loader(task)
        if self.validate:
            print('Evaluating with Validation Set')
            eval_loader = self._get_val_loader(task)
        else:
            print('Evaluating with Evaluation Set')
            eval_loader = self._get_eval_loader(task)

        print(f'Evaluating Task: {task}')
        model.supervise(super_loader)
        if hasattr(model, 'supervise_linear'):
            model.supervise_linear(super_loader, self.step)

        y_preds, y_true = model.classify(eval_loader)
        class_res = self.calculate_classification_accuracy(y_preds, y_true)

        y_preds, y_true = model.cluster(eval_loader)
        clust_res = self.calculate_clustering_metrics(y_preds, y_true)

        results = {'eval_step': self.step,
                   f'task_{task}': self._merge_results(class_res, clust_res)}
        wandb.log(results)

        self.class_perf.append(class_res)
        self.clust_perf.append(clust_res)

        self.step += 1
        return [class_res], [clust_res]

    def calculate_classification_accuracy(self, y_preds, y_true):
        """Return ``{str(model_name): overall accuracy}`` for each model in ``y_preds``."""
        return {str(model_name): accuracy_score(y_true, preds)
                for model_name, preds in y_preds.items()}

    def calculate_clustering_metrics(self, y_preds, y_true):
        """Return ``{str(model_name): purity}`` for each clustering in ``y_preds``."""
        def purity_score(y_true, y_pred):
            # Contingency matrix: rows are true classes, columns are clusters.
            # Purity sums each cluster's majority-class count over all samples.
            conting_matrix = contingency_matrix(y_true, y_pred)
            return np.sum(np.amax(conting_matrix, axis=0)) / np.sum(conting_matrix)

        return {str(model_name): purity_score(y_true, preds)
                for model_name, preds in y_preds.items()}

def load_evaluator(dataset, super_inds, eval_inds, val_inds, config):
    """Instantiate the evaluator selected by ``config.scenario.split_tasks``.

    A truthy flag selects :class:`TaskEvaluator` (per-task evaluation).
    Otherwise :class:`IncEvaluator` (cumulative evaluation over tasks seen so
    far) is used. All arguments are forwarded unchanged to the constructor.
    """
    evaluator_cls = TaskEvaluator if config.scenario.split_tasks else IncEvaluator
    return evaluator_cls(dataset, super_inds, eval_inds, val_inds, config)
