import torch
import os
import numpy as np
import time

from tqdm import tqdm
from scipy.sparse import csgraph

from .MILDataset import MILDataset

class ProcessedMILDataset(MILDataset):
    """
    A MIL dataset that loads pre-processed bag features from disk.

    - If the processed features are not found in `processed_data_path`, every bag
      is loaded through the parent class and saved there as `<bag_name>.npy`.
      In this case, `data_path` is needed.
    - The adjacency data of a bag (edge index and edge weights) is built the first
      time the bag is requested and then cached in `self.data_dict`.
    """

    def __init__(self,
                 processed_data_path : str,
                 data_path : str = None,
                 keep_in_memory : bool = False,
                 **kwargs,
                 ):
        """
        input:
            processed_data_path: directory holding one `<bag_name>.npy` file per bag
            data_path: location of the raw data, only required when the processed
                files are missing
            keep_in_memory: if True, bag features are cached in `self.data_dict`
                after their first load
            **kwargs: forwarded unchanged to the parent constructor
        raises:
            ValueError: if the data has to be processed and `data_path` is None
        """

        # Subclasses may set their own name before calling this constructor
        if not hasattr(self, 'dataset_name'):
            self.dataset_name = 'ProcessedMILDataset'

        # NOTE: these attributes are assigned *before* the parent constructor runs.
        # Presumably the parent may call overridden methods that read them
        # (e.g. `_load_bag_feat` reads `self.processed`) - keep this order.
        self.processed_data_path = processed_data_path
        self.data_path = data_path
        self.keep_in_memory = keep_in_memory
        self.processed = False

        super().__init__(**kwargs)

        self.processed = self._check_already_processed()

        if not self.processed:
            if self.data_path is None:
                raise ValueError(f'[{self.dataset_name}] data_path needed to process data!')
            self._process_data()
            self.processed = True

        # Tolerate a parent class that does not define `data_shape` at all
        if getattr(self, 'data_shape', None) is None:
            self.data_shape = self._compute_data_shape()

    def _compute_data_shape(self):
        """
        Shape of a single instance, read from the first processed bag.

        output:
            tuple with the bag shape without its first (instance) dimension, or
            None if the data is not processed yet or the dataset has no bags
        """
        if not self.processed or len(self.bag_names) == 0:
            return None
        first_bag_path = os.path.join(self.processed_data_path, self.bag_names[0] + '.npy')
        return self._loader(first_bag_path).shape[1:]

    def _check_already_processed(self):
        """
        Check whether every bag has a `<bag_name>.npy` file in `processed_data_path`.

        output:
            True if all the bags in `self.bag_names` were found, False otherwise
        """
        if not os.path.exists(self.processed_data_path):
            return False

        # Only strip the final '.npy' extension: splitting on the first dot would
        # truncate bag names that contain dots, so they would never be matched.
        processed_names = {
            os.path.splitext(file_name)[0]
            for file_name in os.listdir(self.processed_data_path)
            if file_name.endswith('.npy')
        }
        expected_names = set(self.bag_names)
        found_names = processed_names.intersection(expected_names)
        print(f'[{self.dataset_name}] Found {len(found_names)} already processed bags')
        return len(found_names) == len(expected_names)

    def _process_data(self):
        """
        Load every bag through the parent class and save it as `<bag_name>.npy`
        inside `processed_data_path` (the directory is created if needed).
        """
        os.makedirs(self.processed_data_path, exist_ok=True)
        pbar = tqdm(self.bag_names, total=len(self.bag_names))
        pbar.set_description(f'[{self.dataset_name}] Processing and saving data')
        for bag_name in pbar:
            # `self.processed` is still False here, so this goes through the parent loader
            bag_feat = self._load_bag_feat(bag_name)
            np.save(os.path.join(self.processed_data_path, bag_name + '.npy'), bag_feat)

    def _load_bag_feat(self, bag_name):
        """
        Load the features of a bag.

        input:
            bag_name: name of the bag
        output:
            bag_feat: the processed `.npy` file of the bag if the data is already
                processed, otherwise whatever the parent class loader returns
        """
        if not self.processed:
            # If not processed, load instance by instance
            return super()._load_bag_feat(bag_name)

        return self._loader(os.path.join(self.processed_data_path, bag_name + '.npy'))

    def _degree(self, index, edge_weight, num_nodes):
        """
        Weighted degree of every node.

        input:
            index: array (num_edges,), node index of each edge endpoint
            edge_weight: array (num_edges,)
            num_nodes: number of nodes
        output:
            deg: array (num_nodes,), sum of the edge weights per node
        """

        deg = np.zeros(num_nodes)
        # Unbuffered add: repeated indices are accumulated instead of overwritten
        np.add.at(deg, index, edge_weight)
        return deg

    def _add_self_loops(self, edge_index, num_nodes, edge_weight=None):
        """
        Append one self loop of weight 1 per node.

        input:
            edge_index: array (2, num_edges), or an empty array if there are no edges
            num_nodes: number of nodes
            edge_weight: array (num_edges,), defaults to ones
        output:
            new_edge_index: array (2, num_edges + num_nodes)
            new_edge_weight: array (num_edges + num_nodes,)
        """

        loop_index = np.tile(np.arange(0, num_nodes), (2, 1))
        loop_weight = np.ones(num_nodes)

        if edge_index.shape[0] == 0:
            # No edges at all: the graph only contains the self loops
            return loop_index, loop_weight

        if edge_weight is None:
            edge_weight = np.ones(edge_index.shape[1])
        new_edge_index = np.hstack([edge_index, loop_index])
        new_edge_weight = np.concatenate([edge_weight, loop_weight])
        return new_edge_index, new_edge_weight

    def _remove_self_loops(self, edge_index, edge_weight=None):
        """
        Drop the edges that connect a node with itself.

        input:
            edge_index: array (2, num_edges), or an empty array if there are no edges
            edge_weight: optional array (num_edges,)
        output:
            new_edge_index if `edge_weight` is None, otherwise the tuple
            (new_edge_index, new_edge_weight)
        """

        if edge_index.shape[0] == 0:
            # Nothing to remove; keep the return arity consistent with the general case
            if edge_weight is None:
                return edge_index
            return edge_index, edge_weight

        keep = edge_index[0] != edge_index[1]
        new_edge_index = edge_index[:, keep]
        if edge_weight is None:
            return new_edge_index
        return new_edge_index, edge_weight[keep]

    def _normalize_adj_matrix(self, edge_index, edge_weight, num_nodes):
        """
        Symmetric normalization of the edge weights: w_ij / sqrt(deg_i * deg_j).
        The edge index itself is not modified.

        input:
            edge_index: array (2, num_edges), or an empty array if there are no edges
            edge_weight: array (num_edges,)
            num_nodes: number of nodes
        output:
            new_edge_weight: array (num_edges,)
        """

        if edge_index.shape[0] == 0:
            return np.array([])

        row = edge_index[0]
        col = edge_index[1]
        # The degree is accumulated over the second row of `edge_index`
        deg = self._degree(col, edge_weight, num_nodes).astype(np.float32)
        with np.errstate(divide='ignore'):
            deg_inv_sqrt = np.power(deg, -0.5)
        # Nodes with zero degree would give inf: they contribute nothing instead
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def _build_laplacian_matrix(self, edge_index, num_nodes, edge_weight=None):
        """
        Sparse representation of the normalized Laplacian: the negated normalized
        adjacency weights plus a self loop of weight 1 per node.

        input:
            edge_index: array (2, num_edges), or an empty array if there are no edges
            num_nodes: number of nodes
            edge_weight: optional array (num_edges,), defaults to ones
        output:
            new_edge_index: array (2, num_edges + num_nodes)
            new_edge_weight: array (num_edges + num_nodes,)
        """

        if edge_weight is None:
            # Unweighted graph. An empty edge index may be 1-D, so it has no second axis.
            num_edges = edge_index.shape[1] if edge_index.ndim == 2 else 0
            edge_weight = np.ones(num_edges)

        norm_edge_weight = self._normalize_adj_matrix(edge_index, edge_weight, num_nodes)
        return self._add_self_loops(edge_index, num_nodes, -norm_edge_weight)

    def _get_max_bag_size(self):
        """
        Largest number of instances in a bag, measured by the length of the
        'inst_coords' entry of each bag in `self.data_dict`.

        output:
            max_bag_size: int, 0 if the dataset has no bags
        """
        pbar = tqdm(self.bag_names, total=len(self.bag_names))
        pbar.set_description(f'[{self.dataset_name}] Computing max bag size')
        return max(
            (len(self.data_dict[bag_name]['inst_coords']) for bag_name in pbar),
            default=0,
        )

    def __getitem__(self, index):
        """
        input:
            index: position of the bag in `self.bag_names`
        output:
            bag_feat: float32 tensor with the bag features
            bag_label: tensor with the bag label
            inst_labels: tensor with the instance labels
            adj_mat: coalesced float32 sparse COO tensor (bag_size, bag_size) with
                the normalized adjacency matrix of the bag
        """
        bag_name = self.bag_names[index]
        bag_data = self.data_dict[bag_name]

        if 'bag_feat' in bag_data:
            bag_feat = bag_data['bag_feat']
        else:
            bag_feat = self._load_bag_feat(bag_name)
            if self.keep_in_memory:
                bag_data['bag_feat'] = bag_feat

        bag_label = bag_data['bag_label']
        inst_labels = bag_data['inst_labels']
        bag_size = bag_feat.shape[0]

        if 'edge_index' in bag_data:
            edge_index = bag_data['edge_index']
            norm_edge_weight = bag_data['norm_edge_weight']
        else:
            coords = bag_data['inst_coords']
            edge_index, edge_weight = self._build_edge_index(coords, bag_feat)
            norm_edge_weight = self._normalize_adj_matrix(edge_index, edge_weight, bag_size)
            if bag_size == 1:
                # A single-instance bag gets a self loop as its only edge
                edge_index, norm_edge_weight = self._add_self_loops(edge_index, bag_size, norm_edge_weight)
            # The graph is cached regardless of `keep_in_memory`
            bag_data['edge_index'] = edge_index
            bag_data['edge_weight'] = edge_weight
            bag_data['norm_edge_weight'] = norm_edge_weight

        adj_mat = torch.sparse_coo_tensor(edge_index, norm_edge_weight, (bag_size, bag_size)).coalesce().type(torch.float32)

        return torch.from_numpy(bag_feat).type(torch.float32), torch.as_tensor(bag_label), torch.from_numpy(inst_labels), adj_mat