from torch.utils.data import Dataset
import scanpy as sc
import numpy as np
import scipy
from loguru import logger
from tqdm import tqdm
import random
import torch
import numpy as np
from scipy.spatial import cKDTree
import pandas as pd

def calculate_regionidx(adata, round0=10,round1=10):
    """Build the rectangular neighbourhood of every spot.

    Spot ``j`` is a neighbour of spot ``i`` when
    ``x_i - round0 < x_j < x_i + round0`` and ``y_i - round1 < y_j < y_i + round1``
    (strict inequalities), where ``x``/``y`` are ``obs['array_col']`` /
    ``obs['array_row']``. With positive radii a spot is its own neighbour.

    Args:
        adata: AnnData-like object exposing ``obs`` (DataFrame) and ``uns`` (dict).
        round0: Half-width of the box along ``array_col``.
        round1: Half-height of the box along ``array_row``.

    Returns:
        The same ``adata``, modified in place:
        ``uns['nbordict']`` maps ``"<x>_<y>"`` (coordinates rounded to 5
        decimals) to the ascending list of neighbour row indices, and
        ``obs['goodpt']`` holds each spot's neighbour count.
    """
    rawcoord = adata.obs[['array_col', 'array_row']].values.astype(float)
    xs = rawcoord[:, 0]
    ys = rawcoord[:, 1]
    # Sort once by x so each query scans only the vertical strip of spots whose
    # x falls in the open interval, instead of testing all n spots every time.
    order = np.argsort(xs, kind='stable')
    xs_sorted = xs[order]

    nbordict = {}
    for coord in rawcoord:
        # side='right' / side='left' reproduce the strict inequalities exactly:
        # xs_sorted[lo:hi] are the x values with coord-round0 < x < coord+round0.
        lo = np.searchsorted(xs_sorted, coord[0] - round0, side='right')
        hi = np.searchsorted(xs_sorted, coord[0] + round0, side='left')
        # Ascending row indices, the same order np.nonzero over a full mask gives.
        candidates = np.sort(order[lo:hi])
        cand_y = ys[candidates]
        inside = (cand_y < (coord[1] + round1)) & (cand_y > (coord[1] - round1))
        nbordict[tuple(np.round(coord, 5))] = candidates[inside].tolist()

    goodpt = [len(nbordict[coord_tuple]) for coord_tuple in map(tuple, np.round(rawcoord, 5))]

    # String keys so the mapping can live in adata.uns.
    string_nbordict = {'_'.join(map(str, key)): value for key, value in nbordict.items()}
    adata.uns['nbordict'] = string_nbordict
    adata.obs['goodpt'] = goodpt

    return adata

def calculate_missingpt(adata, orderidx):
    """Look up a fixed set of offset neighbours for every spot.

    For a spot at coordinate ``c`` (``obs['array_col']``, ``obs['array_row']``),
    each row ``o`` of ``orderidx`` is checked for a spot at exactly ``c + o``.
    A hit records that spot's row index; a miss records the sentinel
    ``len(rawcoord)`` (one past the last valid row).

    Args:
        adata: AnnData-like object exposing ``obs`` (DataFrame) and ``uns`` (dict).
        orderidx: Offsets added to each coordinate; iterated row by row.

    Returns:
        The same ``adata``, modified in place: ``uns['nbordict']`` maps
        ``"<x>_<y>"`` to the neighbour index list (sentinels included), and
        ``obs['goodpt']`` holds the number of *missing* neighbours per spot.
    """
    rawcoord = adata.obs[['array_col', 'array_row']].values.astype(float)
    missing = len(rawcoord)
    # Coordinate tuple -> row index; for duplicated coordinates the last row wins.
    row_of = {tuple(point): row for row, point in enumerate(rawcoord)}

    nbordict = {}
    for point in rawcoord:
        nbordict[tuple(point)] = [row_of.get(tuple(target), missing) for target in point + orderidx]

    goodpt = []
    for key in map(tuple, rawcoord):
        goodpt.append(np.sum([entry == missing for entry in nbordict[key]]))

    adata.uns['nbordict'] = {'_'.join(str(v) for v in key): neighbours for key, neighbours in nbordict.items()}
    adata.obs['goodpt'] = goodpt
    return adata

def find_dataset_and_index(a, datasets):
    """Map a global index over concatenated datasets to a per-dataset index.

    Args:
        a: Global (flattened) sample index. Negative values count from the end,
            as for Python sequences.
        datasets: Sizes of the concatenated datasets, in order. Zero-sized
            datasets are skipped.

    Returns:
        ``(dataset_index, in_dataset_index)`` as Python ints.

    Raises:
        IndexError: If ``a`` lies outside ``[-total, total)``.
    """
    cumulative_sizes = np.cumsum(datasets)
    total = int(cumulative_sizes[-1]) if len(cumulative_sizes) else 0
    if a < 0:
        # Without this, a negative index landed in dataset 0 with a negative offset.
        a = a + total
    if not 0 <= a < total:
        raise IndexError(f'index {a} out of range for {total} samples')

    # side='right': an index equal to a boundary belongs to the next dataset.
    dataset_index = int(np.searchsorted(cumulative_sizes, a, side='right'))
    offset = int(cumulative_sizes[dataset_index - 1]) if dataset_index > 0 else 0
    return dataset_index, int(a) - offset
 
from scipy.spatial import KDTree
def get_avg_mindistance(array):
    """Return the median nearest-neighbour distance among the points in ``array``.

    Despite the name, the statistic is the median, not the mean.

    Args:
        array: Point coordinates accepted by ``scipy.spatial.KDTree``.

    Returns:
        Median over all points of the distance to their closest other point.
    """
    # k=2: the tree is queried with its own points, so the first hit is the
    # query point itself at distance 0 and the second is its nearest neighbour.
    dist, _ = KDTree(array).query(array, k=2)
    return np.median(dist[:, 1])

class SpatialTarget(Dataset):
    """Dataset over the spots of one spatial section (a single AnnData object).

    Coordinates come from ``adata.obs[['array_col', 'array_row']]``. Expression
    is preprocessed once in ``__init__`` (optional raw-count swap, optional
    z-scoring, then binning or quantisation) and stored in ``self.expappend``.

    Items:
        * ``getsingle=True``: one spot's row of ``expappend`` as a tensor.
        * ``getsingle=False``: ``(expression, coords, mask)`` for every spot in
          the rectangular neighbourhood of the requested spot, ordered
          according to ``shuffle``.
    """

    def __init__(self, adata,exp_bins=50,cfg=None,shuffle='random',getsingle=False,codebook=None):
        """Copy ``adata``, build the neighbourhood index and preprocess expression.

        Args:
            adata: AnnData object. It is copied, so this constructor does not
                modify the caller's object.
            exp_bins: Stored as ``self.exp_bins``. The binning methods use their
                own hard-coded bin counts.
            cfg: Config object. Attributes read: ``block_size``, ``round0``,
                ``round1``, ``master_process``, ``bin_type``, ``task``, ``zscore``.
            shuffle: Neighbour ordering for sequence items: ``'random'``,
                ``'corner'``, or any other value to keep the stored order.
            getsingle: If True, items are single spots and no neighbourhood
                index (``nbordict``, ``l``) is built.
            codebook: ``(mapping, codebook)`` pair, used when
                ``cfg.task == 'quantize'``.

        Raises:
            ValueError: If ``cfg.bin_type`` is not one of ``'incell'``,
                ``'fix'``, ``'ingene'``, ``'nobin'``.
        """
        self.adata = adata.copy()
        self.exp_bins = exp_bins
        self.rawcoord = self.adata.obs[['array_col', 'array_row']].values.astype(float)
        self.block_size = cfg.block_size
        self.ngenes = adata.shape[1]
        self.shuffle = shuffle
        self.getsingle = getsingle
        # Populated only when cfg.zscore is set.
        self.genemean = None
        self.genestd = None

        if not getsingle:
            self.adata = calculate_regionidx(self.adata,cfg.round0,cfg.round1)
            # uns stores "x_y" string keys; parse them back into float tuples.
            self.nbordict = {tuple(map(float, key.split('_'))): value for key, value in self.adata.uns['nbordict'].items()}
            # Median nearest-neighbour spacing; distance scale for 'corner' ordering.
            self.l = get_avg_mindistance(self.rawcoord)

            if cfg.master_process:
                logger.info(f'shape {adata.shape}')
                logger.info(f"Mid distance:{self.l}")

        binners = {
            'incell': self.get_bin_exp,
            'fix': self.get_bin_exp_fix,
            'ingene': self.get_bin_exp_ingene,
            'nobin': self.naivebin,
        }
        if cfg.bin_type not in binners:
            # Fail loudly: an unknown bin type is a configuration error and must
            # not end the process with a success exit status.
            raise ValueError(f'unknown cfg.bin_type {cfg.bin_type!r}; expected one of {sorted(binners)}')
        self.get_bin = binners[cfg.bin_type]

        if cfg.task=='quantize':
            # Quantisation replaces binning entirely.
            self.get_bin = self.get_quantize
            self.codebook = codebook

        if cfg.task=='zinb':
            self.adata.X = self.adata.layers['rawcount']

        if cfg.zscore:
            sc.pp.scale(self.adata)
            self.genemean = self.adata.var['mean'].values
            self.genestd = self.adata.var['std'].values

        rawX = self.adata.X.toarray() if scipy.sparse.issparse(self.adata.X) else self.adata.X
        self.expappend = self.get_bin(rawX)

    def __len__(self):
        """Number of spots in the section."""
        return self.rawcoord.shape[0]

    def get_bin_exp_fix(self,exp):
        """Digitize ``exp`` against one fixed, global set of bin edges.

        Known limitation: for genes whose range is squeezed into a small
        interval, every value falls into the same one or two codes (e.g. 9, 10).
        """
        # Earlier alternative edges:
        #     bins = [0.01, 0.98469985, 1.08962703, 1.17639256, 1.26285386,
        #    1.35545528, 1.46813452, 1.64222777, 1.92745376, 2.51580477,
        #    10]
        bins = [-0.363387913,
                -0.242789909,
                -0.167018071,
                -0.100853093,
                -0.0578530394,
                -0.0362497307,
                -0.0234599896,
                -0.01,
                0.0,
                100]
        binned_expr = np.digitize(exp, bins,right=True)
        return binned_expr

    def get_bin_exp_ingene(self,exp):
        """Bin each gene (column) with 10 equal-width edges from 0 to that gene's maximum."""
        exp_bins = 10
        max_expr = np.max(exp,axis=0)
        bins = [np.linspace(0, max_expr[i], exp_bins) for i in range(len(max_expr))]
        binned_expr = np.array([np.digitize(exp[:,i], bins[i],right=True) for i in range(exp.shape[1])])
        return binned_expr.T

    def get_bin_exp(self,exp):
        """Bin each cell (row) with 10 equal-width edges from 0 to that cell's maximum."""
        exp_bins = 10
        max_expr = np.max(exp,axis=1)
        bins = [np.linspace(0, max_expr[i], exp_bins) for i in range(len(max_expr))]
        binned_expr = np.array([np.digitize(exp[i], bins[i],right=True) for i in range(len(exp))])
        return binned_expr

    def naivebin(self,exp):
        """Return ``exp`` unchanged (no binning)."""
        return exp

    def get_quantize(self,exp):
        """Replace each cell by the indices of its nearest codebook entries.

        ``self.codebook[0]`` is ``(nbrs, mean, pcaproj)``: each row of ``exp`` is
        centred by ``mean``, projected with ``pcaproj`` and passed to
        ``nbrs.kneighbors``. ``self.codebook[1]`` is not used here.

        Returns:
            The neighbour-index array from ``nbrs.kneighbors``, not a
            ``[ncells, ngenes]`` expression matrix. Presumably shaped
            ``(ncells, n_neighbors)`` for a scikit-learn ``NearestNeighbors``;
            the distance metric is whatever ``nbrs`` was configured with — confirm.
        """
        mapping,codebook = self.codebook[0],self.codebook[1]
        nbrs,mean,pcaproj = mapping
        _, maxid = nbrs.kneighbors((exp - mean)@pcaproj)
        return maxid

    def __getitem__(self, idx):
        if self.getsingle:
            return self._getOne_(idx)
        else:
            return self._getSeqence_(idx)

    def _getOne_(self, idx):
        """Preprocessed expression row of spot ``idx`` as a tensor."""
        rawexp = self.expappend[idx,:]
        return torch.tensor(rawexp)

    def _getSeqence_(self, idx):
        """Return ``(expression, coords, mask)`` for the neighbourhood of spot ``idx``.

        ``coords`` are neighbour coordinates relative to spot ``idx`` plus a
        constant offset of 30 (presumably to keep them non-negative — confirm).
        ``mask`` is an all-True bool tensor, one entry per neighbour.
        """
        nboridx = self.nbordict[tuple(np.round(self.rawcoord[idx,:],5))]
        # Work on a copy: random.shuffle is in-place and would otherwise reorder
        # the list stored in self.nbordict, which is shared across calls.
        shuffled_list = list(nboridx)
        if self.shuffle=='random':
            random.shuffle(shuffled_list)
        elif self.shuffle=='corner':
            # Pick a random corner of the neighbourhood's bounding box, then draw
            # an ordering without replacement weighted by exp(-d / self.l), so
            # spots near that corner tend to come first (weight 10 on the corner).
            neipt = self.rawcoord[nboridx]
            condition = np.random.choice([0,1,2,3])
            xmin,ymin,xmax,ymax = neipt[:,0].min(),neipt[:,1].min(),neipt[:,0].max(),neipt[:,1].max()
            cornerlist = np.array([[xmin,ymin],[xmin,ymax],[xmax,ymin],[xmax,ymax]])
            corner = cornerlist[condition]
            distances = [np.linalg.norm(corner - point) for point in neipt]
            weights = [np.exp(-d/self.l) if d != 0 else 10 for d in distances]
            shuffled_list = np.random.choice(nboridx, size=len(nboridx), replace=False, p=weights/np.sum(weights))
        mask = torch.ones([len(shuffled_list)]).bool()
        nborcoord = (self.rawcoord[shuffled_list] - self.rawcoord[idx])+30
        rawexp = self.expappend[shuffled_list,:]
        return torch.tensor(rawexp), torch.tensor(nborcoord), mask
    
class MultiSpatialTarget(Dataset):
    """Concatenation of one ``SpatialTarget`` per ``.h5ad`` file.

    A global index is mapped to ``(file position, spot index)`` with
    ``find_dataset_and_index`` over the per-file lengths.
    """

    def __init__(self, filelist,exp_bins=50,cfg=None,shuffle='random',getsingle=False,codebook=None):
        """Load every file in ``filelist`` and wrap it in a ``SpatialTarget``.

        Args:
            filelist: Paths readable by ``sc.read_h5ad``.
            exp_bins, shuffle, getsingle, codebook: Forwarded to ``SpatialTarget``.
            cfg: Config object; ``master_process``, ``zscore`` and ``task`` are
                read here and the whole object is forwarded.
        """
        raw_sum = 0
        self.allseqdata = []
        self.lenlist = []
        if cfg.master_process:
            logger.info(f'scale: {cfg.zscore}')
            if cfg.task=='quantize':
                logger.info(f'quantize: {cfg.task}')
                logger.info(f'codebook size: {codebook[1].shape}')
        for f in filelist:
            if cfg.master_process:
                logger.info(f'read file {f}')
            tmpadata = sc.read_h5ad(f)
            tmpseqdata = SpatialTarget(tmpadata,exp_bins=exp_bins,cfg=cfg,shuffle=shuffle,getsingle=getsingle,codebook=codebook)
            self.allseqdata.append(tmpseqdata)
            raw_sum = raw_sum + tmpadata.shape[0]
            self.lenlist.append(len(tmpseqdata))
        # Plain int sum: np.sum([]) is the float 0.0, which len() rejects.
        self.len_sum = int(sum(self.lenlist))
        if cfg.master_process:
            logger.info(f'pre/after process shape {raw_sum}/{self.len_sum}')
        self.getsingle = getsingle

    def __len__(self):
        """Total number of spots across all files."""
        return self.len_sum

    def _getSeqence_(self, idx):
        """Neighbourhood item ``(expression, coords, mask)`` for global index ``idx``."""
        i,subidx = find_dataset_and_index(idx, self.lenlist)
        nborexp, nborcoord, mask = self.allseqdata[i].__getitem__(subidx)
        return nborexp, nborcoord, mask

    def _getOne_(self, idx):
        """Single-spot expression tensor for global index ``idx``."""
        i,subidx = find_dataset_and_index(idx, self.lenlist)
        nborexp = self.allseqdata[i].__getitem__(subidx)
        return nborexp

    def __getitem__(self, idx):
        if self.getsingle:
            return self._getOne_(idx)
        else:
            return self._getSeqence_(idx)

class SpatialAnnotation(SpatialTarget):
    """``SpatialTarget`` whose sequence items also carry a class label for the requested spot.

    Labels come from ``adata.obs[cfg.labelname]``. If that column is missing,
    annotation columns are joined in from a CSV indexed by
    ``adata.obs.cluster_alias``.
    """

    def __init__(self, adata,exp_bins=50,cfg=None,shuffle='random',getsingle=False,codebook=None,strlabel=None):
        """Build the spatial dataset and encode labels.

        Args:
            strlabel: Ordered class names (any sequence, e.g. list or ndarray).
                If None, classes are learned with sklearn's ``LabelEncoder`` and
                the mapping is logged on the master process.
            Other arguments are forwarded to ``SpatialTarget``.

        Raises:
            ValueError: If ``strlabel`` is given and a label is not in it.
        """
        if adata.obs.get(cfg.labelname) is None:
            # NOTE(review): this replaces the caller's adata.obs in place.
            annotation = pd.read_csv('/PATH/MERFISH/Annotation/cluster_to_cluster_annotation_membership_pivoted.csv')
            annotation = annotation.set_index('cluster_alias')
            query = annotation.loc[adata.obs.cluster_alias.values,:]
            query.index = adata.obs.index
            adata.obs = pd.concat([adata.obs, query],axis=1)

        super().__init__(adata,exp_bins=exp_bins,cfg=cfg,shuffle=shuffle,getsingle=getsingle,codebook=codebook)

        # self.nbordict exists only when neighbourhoods were built (getsingle=False).
        if not getsingle:
            # Duplicate the first stored neighbour index of every neighbourhood.
            # NOTE(review): that is the lowest row index in the box, not
            # necessarily the centre spot — confirm this is the intended token.
            for neighbours in self.nbordict.values():
                neighbours.append(neighbours[0])

        if strlabel is None:
            from sklearn.preprocessing import LabelEncoder
            label_encoder = LabelEncoder()
            self.label = label_encoder.fit_transform(adata.obs[cfg.labelname].values)
            strlabel = label_encoder.classes_
            label_mapping = dict(zip(label_encoder.classes_, label_encoder.transform(label_encoder.classes_)))
            if cfg.master_process:
                logger.info(label_mapping)
        else:
            # One dict lookup per cell instead of a linear list.index scan.
            # First occurrence wins, matching list.index.
            position = {}
            for i, name in enumerate(strlabel):
                position.setdefault(name, i)
            rawlabel = adata.obs[cfg.labelname].values.tolist()
            try:
                self.label = [position[label] for label in rawlabel]
            except KeyError as err:
                raise ValueError(f'{err.args[0]!r} is not in strlabel') from err
        self.nclass = len(strlabel)

    def _getSeqence_(self, idx):
        """Neighbourhood item plus the label of spot ``idx`` as a tensor."""
        nborexp, nborcoord, mask = super()._getSeqence_(idx)
        return nborexp, nborcoord, mask, torch.tensor(self.label[idx])
    
class MultiSpatialAnnotation(MultiSpatialTarget):
    """Concatenation of one ``SpatialAnnotation`` per ``.h5ad`` file, sharing one label list."""

    def __init__(self, filelist,exp_bins=50,cfg=None,shuffle='random',getsingle=False,codebook=None,strlabel=None):
        """Load every file in ``filelist`` and wrap it in a ``SpatialAnnotation``.

        Args:
            strlabel: Ordered class names shared by all files. Required so that
                label codes are consistent across files.
            Other arguments are forwarded to ``SpatialAnnotation``.

        Raises:
            ValueError: If ``strlabel`` is None. This is checked before any file
                is read, instead of failing after everything has been loaded.
        """
        if strlabel is None:
            logger.error('must provide strlabel for multiple datasets annotation task')
            raise ValueError('must provide strlabel for multiple datasets annotation task')
        raw_sum = 0
        self.allseqdata = []
        self.lenlist = []
        if cfg.master_process:
            logger.info(f'scale: {cfg.zscore}')
            if cfg.task=='quantize':
                logger.info(f'quantize: {cfg.task}')
                logger.info(f'codebook size: {codebook[1].shape}')
            logger.info(f'strlabel: {strlabel} \n label size: {len(strlabel)} ')
        for f in filelist:
            if cfg.master_process:
                logger.info(f'read file {f}')
            tmpadata = sc.read_h5ad(f)
            tmpseqdata = SpatialAnnotation(tmpadata,exp_bins=exp_bins,cfg=cfg,shuffle=shuffle,getsingle=getsingle,codebook=codebook,strlabel=strlabel)
            self.allseqdata.append(tmpseqdata)
            raw_sum = raw_sum + tmpadata.shape[0]
            self.lenlist.append(len(tmpseqdata))
        # Plain int sum: np.sum([]) is the float 0.0, which len() rejects.
        self.len_sum = int(sum(self.lenlist))
        if cfg.master_process:
            logger.info(f'pre/after process shape {raw_sum}/{self.len_sum}')
        self.getsingle = getsingle
        self.nclass = len(strlabel)

    def _getSeqence_(self, idx):
        """Neighbourhood item plus label for global index ``idx``."""
        i,subidx = find_dataset_and_index(idx, self.lenlist)
        nborexp, nborcoord, mask, label = self.allseqdata[i].__getitem__(subidx)
        return nborexp, nborcoord, mask, label
