import os
import yaml
import time
import torch
import numpy as np
from torch.utils.data import DataLoader, RandomSampler, BatchSampler
from loguru import logger
from cfg import build_cfg
from contextlib import nullcontext
from model import DaoConfig, scDaoModel
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from dataset import SpatialSeq,SpatialTarget
import scanpy as sc
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import scipy
from collections import deque

def points_in_a_not_in_b(a, b):
    """Return the distinct rows of ``a`` that do not occur as rows of ``b``.

    Rows are compared as tuples (exact equality). The result keeps the order in
    which rows first appear in ``a``, which makes it deterministic. It is always
    2-D: when nothing remains it has shape ``(0, ncols)``, so callers can still
    index ``[:, j]`` on it.
    """
    a = np.asarray(a)
    b_rows = set(map(tuple, b))
    seen = set()
    kept = []
    for row in map(tuple, a):
        if row in b_rows or row in seen:
            continue
        seen.add(row)
        kept.append(row)
    if not kept:
        ncols = a.shape[1] if a.ndim == 2 else 0
        return np.empty((0, ncols), dtype=a.dtype)
    return np.array(kept)

def expandright(rawcoord):
    """Step every (x, y) row one unit diagonally to the right.

    Each row yields (x + 1, y + 1) and (x + 1, y - 1). The combined set is
    de-duplicated and row-sorted by ``np.unique(..., axis=0)``.
    """
    steps = (np.array([1, 1]), np.array([1, -1]))
    candidates = np.vstack(tuple(rawcoord + step for step in steps))
    return np.unique(candidates, axis=0)

def recursive_expand(rawcoord, N=1, pattern='r&u'):
    """Apply ``expandright`` ``N`` times to ``rawcoord``.

    pattern == 'r': after every step, rows whose y value falls outside the
    original [ymin, ymax] range of ``rawcoord`` are discarded.
    Any other pattern (including the default 'r&u'): the expansion is
    unrestricted and the final result is passed through ``np.unique``.
    NOTE(review): despite its name, 'r&u' performs no "up" expansion here —
    only ``expandright`` is applied.
    """
    if pattern != 'r':
        grown = rawcoord
        for _ in range(N):
            grown = expandright(grown)
        return np.unique(grown, axis=0)

    ys = rawcoord[:, 1]
    lo, hi = ys.min(), ys.max()
    for _ in range(N):
        rawcoord = expandright(rawcoord)
        inside = (rawcoord[:, 1] >= lo) & (rawcoord[:, 1] <= hi)
        rawcoord = rawcoord[inside]
    return rawcoord

def update_nbordict(adata, extendcoord, orderidx):
    """Store a neighbour-index table for ``extendcoord`` in ``adata.uns['nbordict']``.

    For each coordinate in ``extendcoord``, every offset in ``orderidx`` is added.
    Each shifted coordinate is then looked up among the existing spots
    (``obs['array_col']``, ``obs['array_row']``). A hit records that spot's row
    position; a miss records the sentinel ``len(existing spots)``. Table keys are
    the coordinate components joined with '_' (via ``str``). The function also
    sets ``obs['goodpt']`` to -1 for every spot, and returns the same (mutated)
    ``adata``.
    """
    existing = adata.obs[['array_col', 'array_row']].values.astype(float)
    sentinel = len(existing)
    position_of = {tuple(xy): pos for pos, xy in enumerate(existing)}

    table = {}
    for center in extendcoord:
        key = '_'.join(map(str, tuple(center)))
        table[key] = [position_of.get(tuple(shifted), sentinel) for shifted in center + orderidx]

    adata.uns['nbordict'] = table
    adata.obs['goodpt'] = -1
    return adata

def imputeTarget(adatahvg,model,cfg,N=7):
    """Re-predict every observed spot from its spatial context and store the averages.

    For each spot, the neighbourhood returned by ``SpatialTarget`` is run through
    ``model``. The model outputs after the first ``N`` context positions are
    mapped back to absolute coordinates. Each coordinate is then averaged over all
    windows that produced a prediction for it.

    Parameters:
    - adatahvg (AnnData): needs ``obs['array_col']`` and ``obs['array_row']``.
    - model: called as ``model(inputs_embeds=..., coord=..., mask=..., embedding=True)``
      and expected to return ``(logit, embedding)``. It is switched to eval mode.
    - cfg: configuration forwarded to ``SpatialTarget``.
    - N (int): number of leading neighbours used only as context. Windows with
      fewer than ``N + 1`` neighbours are skipped.

    Returns:
    - adatahvg, with ``layers['impute']`` (mean prediction; values < 0.1 are zeroed
      per window before averaging), ``layers['var']`` (variance across windows)
      and ``obsm['embedding']`` (mean embedding).

    Raises:
    - KeyError: if some spot received no prediction from any window.
    """
    # Inference only: disable dropout etc., as expandTarget_quantize also does.
    model.eval()
    hvgcoord = adatahvg.obs[['array_col', 'array_row']].values.astype(float)
    cell_ds = SpatialTarget(adatahvg,cfg=cfg)
    pred_by_coord = {}
    emb_by_coord = {}
    for index in tqdm(range(hvgcoord.shape[0])):
        Qspoint = hvgcoord[index]
        tmpdata = cell_ds[index]
        tmpcoord = tmpdata[1].cpu().numpy()
        if tmpcoord.shape[0] < N + 1:
            # Nothing would remain after the N context positions.
            continue
        rawlen = tmpdata[0].shape[0]

        with torch.no_grad():
            logit, tmpemb = model(inputs_embeds=tmpdata[0].unsqueeze(0).cuda(),coord=tmpdata[1].unsqueeze(0).cuda().float(),mask=tmpdata[2].unsqueeze(0).cuda(),embedding=True)
        logit = logit[0,rawlen-1+N:-1]
        # The sliced outputs must line up one-to-one with neighbours N..end.
        assert tmpcoord[N:].shape[0]==logit.shape[0]
        logit[logit<0.1]=0
        preds = logit.detach().cpu().numpy()
        embs = tmpemb[0,rawlen-1+N:-1].detach().cpu().numpy()
        # NOTE(review): SpatialTarget coords appear to be relative to the query spot
        # shifted by +30 (expandTarget_quantize applies the inverse); confirm.
        for i, key in enumerate(map(tuple, Qspoint+tmpcoord[N:]-30)):
            pred_by_coord.setdefault(key, []).append(preds[i])
            emb_by_coord.setdefault(key, []).append(embs[i])

    keys = [tuple(c) for c in hvgcoord.tolist()]
    n_missing = sum(k not in pred_by_coord for k in keys)
    if n_missing:
        raise KeyError(f'{n_missing} of {len(keys)} spots received no prediction from any neighbourhood window')

    adddata = np.array([np.mean(pred_by_coord[k], axis=0) for k in keys])
    vardata = np.array([np.var(pred_by_coord[k], axis=0) for k in keys])
    newemb = np.array([np.mean(emb_by_coord[k], axis=0) for k in keys])

    adatahvg.layers['impute']=adddata
    adatahvg.layers['var']=vardata
    adatahvg.obsm['embedding']=newemb
    return adatahvg

from collections import deque
def expandTarget_quantize(adatahvg, model, valcoord, cfg,roundsize=None, thres=100,verbose=True,mode='meansoftmax'):
    """
    Expand the observed spots toward ``valcoord`` by autoregressively predicting new spots.

    Target coordinates not yet present in ``adatahvg`` are the candidates. Seed
    spots are observed spots that have a candidate within the current
    ``roundsize`` window and more than ``thres`` neighbours. Each seed then
    predicts every candidate within ``(cfg.round0, cfg.round1)`` of it, one
    token at a time. Per-candidate predictions from all seeds are merged
    according to ``mode``.

    Parameters:
    - adatahvg (AnnData): observed data; needs obs['array_col'] and obs['array_row'].
    - model: exposes ``codebook`` (with ``.weight``) and ``query_index``. It is
      called as ``model(inputs_embeds=..., coord=..., mask=..., embedding=True)``
      and is switched to eval mode.
    - valcoord: target coordinates, one (x, y) pair per row.
    - cfg: configuration; reads round0/round1, codebook, n_embd and is passed to SpatialTarget.
    - roundsize (tuple or None): initial (x, y) half-widths for choosing seeds.
      Defaults to (cfg.round0, cfg.round1). Both are doubled whenever a full
      threshold sweep finds no seed.
    - thres (int): a seed needs more than this many neighbours. It is halved
      after each unsuccessful pass and reset once it drops below 10.
    - verbose (bool): print a summary and plot the preview.
    - mode (str): how per-seed code probabilities are merged:
      'meansoftmax' -> codebook embedding of the argmax of the mean probabilities;
      'mode'        -> codebook embedding of the most frequent per-seed argmax (smallest on ties);
      'weighted'    -> mean probabilities multiplied by the codebook weight matrix.

    Returns:
    - predadata (AnnData): the original spots followed by the new spots. The new
      spots carry a new 'status' id.
    - vardata (np.ndarray): for each new spot, the mean softmax distribution over
      codebook entries across seeds (despite the name, this is not a variance).
    - touchpoint (np.ndarray): coordinates of the new spots, in predadata order.
    When there is nothing to add, returns ``(adatahvg, 0, 0)``.

    Raises:
    - ValueError: unknown ``mode``, non-positive ``roundsize``, or no observed spots.
    - RuntimeError: no observed spot ever qualifies as a seed, even with windows
      covering every candidate.
    """
    # Explicit submodule import: a bare `import scipy` does not guarantee scipy.sparse is loaded.
    import scipy.sparse
    from sklearn.neighbors import NearestNeighbors

    valid_modes = ('meansoftmax', 'mode', 'weighted')
    if mode not in valid_modes:
        # Fail before the expensive inference instead of with a NameError afterwards.
        raise ValueError(f'mode must be one of {valid_modes}, got {mode!r}')

    model.eval()
    hvgcoord = adatahvg.obs[['array_col', 'array_row']].values.astype(float)
    rd_set = set(map(tuple,valcoord)).difference(set(map(tuple,hvgcoord)))
    if len(rd_set)==0:
        logger.info('No new points')
        return adatahvg,0,0
    if hvgcoord.shape[0] == 0:
        raise ValueError('adatahvg has no observed spots to expand from')
    rd = np.array(list(rd_set))

    includedidx = []
    touchidx = []
    nbordict = {}
    if roundsize is None:
        roundsize=(cfg.round0,cfg.round1)
        logger.info(f'Using default roundsize {roundsize}')
    round0 = roundsize[0]
    round1 = roundsize[1]
    if not (round0 > 0 and round1 > 0):
        # Doubling a non-positive radius never widens the window -> endless search.
        raise ValueError(f'roundsize must be positive, got {roundsize}')

    codebook = np.load(cfg.codebook)
    basedir = cfg.codebook.split('codebook')[0]
    pca = pd.read_csv(f'{basedir}meta_cells_pca.csv',index_col=0).values
    nbrs = NearestNeighbors(n_neighbors=1, algorithm='auto').fit(pca)
    pcaproj = np.load(f'{basedir}PCs.npy')
    mean = np.load(f'{basedir}mean.npy')
    codebook = ((nbrs,mean,pcaproj),codebook)

    cell_ds = SpatialTarget(adatahvg,cfg=cfg,codebook=codebook)
    inputthres = thres

    # Largest per-axis gap between any candidate and any observed spot. Once both
    # radii exceed these, every window already contains every candidate, so a
    # failed threshold sweep means no spot can ever qualify.
    span0 = max(rd[:, 0].max() - hvgcoord[:, 0].min(), hvgcoord[:, 0].max() - rd[:, 0].min())
    span1 = max(rd[:, 1].max() - hvgcoord[:, 1].min(), hvgcoord[:, 1].max() - rd[:, 1].min())

    while len(includedidx)==0:
        for i, row in enumerate(hvgcoord):
            mask = (rd[:,0]<(row[0]+round0))&(rd[:,0]>(row[0]-round0))&(rd[:,1]<(row[1]+round1))&(rd[:,1]>(row[1]-round1))
            if np.sum(mask)>0:
                if cell_ds[i][1].size(0)>thres:
                    includedidx.append(i)
                    # Candidates assigned to this seed use the cfg window, not the (possibly grown) search window.
                    mask = (rd[:,0]<(row[0]+cfg.round0))&(rd[:,0]>(row[0]-cfg.round0))&(rd[:,1]<(row[1]+cfg.round1))&(rd[:,1]>(row[1]-cfg.round1))
                    maskidx = np.nonzero(mask)[0].tolist()
                    touchidx += maskidx
                    nbordict[i] = rd[maskidx]
        thres = thres //2
        if thres<10:
            if not includedidx and round0 > span0 and round1 > span1:
                raise RuntimeError('No observed spot has enough neighbours to seed the expansion, '
                                   'even with search windows covering every candidate point')
            logger.info(f'thres smaller than 10: {thres}')
            round0 = round0*2
            round1 = round1*2
            thres = inputthres
    touchidx = list(set(touchidx))
    touchpoint = rd[touchidx]
    if verbose:
        print(f'preview added points {len(touchpoint)} we need {len(includedidx)} orignal points to estimate new points.')
        plt.figure()
        plt.scatter(hvgcoord[:,0],-hvgcoord[:,1],s=1)
        plt.scatter(rd[:,0],-rd[:,1],s=1,c='r')
        plt.scatter(hvgcoord[includedidx,0],-hvgcoord[includedidx,1],s=1,c='green')
        plt.scatter(touchpoint[:,0],-touchpoint[:,1],s=1,c='black')
        plt.show()
    if len(touchpoint) == 0:
        # Seeds were found with the grown search window, but none has a candidate
        # inside (cfg.round0, cfg.round1); there is nothing to predict.
        logger.warning('No candidate points within (cfg.round0, cfg.round1) of the seed spots; nothing added')
        return adatahvg,0,0

    average_dict = {}
    averageemb_dict = {}
    for x in touchpoint:
        average_dict[tuple(x)] = deque()
        averageemb_dict[tuple(x)] = deque()

    for index in tqdm(includedidx):
        if len(nbordict[index]) == 0:
            # This seed was chosen via the grown search window but owns no cfg-window candidates.
            continue
        Qspoint = hvgcoord[index]
        tmpdata = cell_ds[index]
        tmpcoord = tmpdata[1].cpu().numpy()

        # Re-order context and targets by distance from the context point farthest from the targets.
        appendcoord = nbordict[index] - Qspoint +30
        distances =  np.linalg.norm(tmpcoord - appendcoord.mean(0),axis=1)
        farpt = tmpcoord[np.argsort(distances)[-1]]
        tmporder =  np.argsort(np.linalg.norm(tmpcoord - farpt,axis=1))
        tmpdata = (tmpdata[0][tmporder],tmpdata[1][tmporder])
        tmpcoord = tmpdata[1].cpu().numpy()

        distances = np.linalg.norm(appendcoord - farpt, axis=1)
        nbordict[index] = nbordict[index][np.argsort(distances)]
        appendcoord = nbordict[index] - Qspoint +30

        blocksize = len(appendcoord)+len(tmpcoord)
        rawlen = tmpdata[0].shape[0]
        inputidx = torch.zeros([1,blocksize,tmpdata[0].shape[1]]).cuda()
        inputcoord = torch.zeros([1,blocksize,2]).cuda()
        inputmask = torch.ones([1,blocksize]).to(bool).cuda()
        inputidx[0,:rawlen,:]=tmpdata[0]
        inputcoord[0,:rawlen,:]=tmpdata[1]
        inputcoord[0,rawlen:,:]=torch.tensor(appendcoord)
        outputemb = torch.zeros([1,blocksize,cfg.n_embd]).cuda()
        outputlogit = torch.zeros([1,blocksize,model.codebook.weight.shape[0]]).cuda()

        with torch.no_grad():
            # Autoregressive decoding: each step writes the predicted token back into the input.
            for i in range(len(appendcoord)):
                logit, tmpemb = model(inputs_embeds=inputidx,coord=inputcoord,mask=inputmask,embedding=True)
                assert inputidx[0,rawlen+i].sum()==0
                nexttoken = model.query_index(logit[0,blocksize+rawlen+i-1]).item()
                outputemb[0,rawlen+i]=tmpemb[0,blocksize+rawlen+i-1]
                inputidx[0,rawlen+i]=nexttoken
                outputlogit[0,rawlen+i] = torch.nn.functional.softmax(model.query_index(logit[0,blocksize+rawlen+i-1],iflogit=True),dim=-1)

        tmpexppred = outputlogit[0,rawlen:,:].detach().cpu().numpy().tolist()
        tmpembpred = outputemb[0,rawlen:].detach().cpu().numpy().tolist()

        for i, element in enumerate(map(tuple, nbordict[index])):
            average_dict[tuple(element)].append(tmpexppred[i])
            averageemb_dict[tuple(element)].append(tmpembpred[i])

    codebook_weight = model.codebook.weight

    def _decode(probs):
        # Merge the per-seed softmax rows for one point into one expression vector.
        probs = np.array(probs)
        if mode == 'meansoftmax':
            code = torch.tensor(probs).mean(0).argmax(dim=-1)
            return model.codebook(code.to(codebook_weight.device)).cpu().numpy()
        if mode == 'mode':
            # bincount().argmax() == most frequent code, smallest on ties (same as scipy.stats.mode),
            # and always a scalar regardless of the SciPy version.
            code = np.bincount(np.argmax(probs, 1)).argmax()
            return model.codebook(torch.tensor(code).to(codebook_weight.device)).cpu().numpy()
        mean_probs = torch.tensor(probs).mean(0).to(codebook_weight.device).to(codebook_weight.dtype)
        return torch.matmul(mean_probs, codebook_weight.data).cpu().numpy()

    # no_grad: codebook outputs would otherwise require grad and `.numpy()` would raise.
    with torch.no_grad():
        mean_dict = {k: _decode(v) for k, v in average_dict.items()}
    # NOTE: despite the name this is the mean softmax distribution, not a variance.
    var_dict = {k: np.mean(v,axis=0) for k, v in average_dict.items()}
    count_dict = {k: len(v) for k, v in average_dict.items()}
    adddata = np.array([mean_dict[tuple(i)] for i in touchpoint.tolist()])
    vardata = np.array([var_dict[tuple(i)] for i in touchpoint.tolist()])
    countdata = np.array([count_dict[tuple(i)] for i in touchpoint.tolist()])

    averageemb_dict = {k: np.mean(v, axis=0) for k, v in averageemb_dict.items()}
    newemb = np.array([averageemb_dict[tuple(i)] for i in touchpoint.tolist()])

    if 'embedding' not in adatahvg.obsm.keys():
        rawemb = np.zeros([adatahvg.shape[0],cfg.n_embd])
        logger.warning('Cannot find raw data embedding. Please generate embeddings by using the imputation funciton! Initialized with Zero')
    else:
        rawemb = adatahvg.obsm['embedding']
    updateemb = np.concatenate([rawemb,newemb],axis=0)

    if scipy.sparse.issparse(adatahvg.X):
        rawdata = adatahvg.X.toarray()
    else:
        rawdata = adatahvg.X

    if 'predcount' in adatahvg.obs.columns.tolist():
        rawpcdata = adatahvg.obs['predcount'].values
    else:
        rawpcdata = np.zeros([rawdata.shape[0]])

    updatepcdata = np.concatenate([rawpcdata,countdata])
    updatedata = np.concatenate([rawdata,adddata],axis=0)
    updatecoord = np.concatenate([hvgcoord,touchpoint],axis=0)
    predexpdf = pd.DataFrame(updatedata,columns=adatahvg.var.index.tolist())

    ori_meta = adatahvg.obs
    metadf = pd.DataFrame('',index=range(updatecoord.shape[0]),columns=ori_meta.columns)
    metadf.iloc[:ori_meta.shape[0]] = ori_meta
    metadf['array_col'] = updatecoord[:,0]
    metadf['array_row'] = updatecoord[:,1]
    metadf['predcount'] = updatepcdata
    if 'status' not in metadf:
        metadf['status']='0'
    # New spots get the next integer status id after all existing ones.
    status_idx = metadf.columns.tolist().index('status')
    newid = np.max(np.unique(metadf.iloc[:hvgcoord.shape[0],status_idx]).astype(int))+1
    metadf.iloc[hvgcoord.shape[0]:,status_idx]=str(newid)

    predadata = sc.AnnData(predexpdf,obs=metadf,var=adatahvg.var)
    predadata.obsm['spatial']=updatecoord
    predadata.obsm['embedding']=updateemb
    return predadata,vardata,touchpoint

import pathlib

if __name__ == "__main__":
    # Driver: for each dataset, grow the reference slice toward the query coordinates
    # with repeated expandTarget_quantize calls, then save the predictions at the query spots.
    savedir = './tmp/'
    alldataname = ['Zhuang-ABCA-2']
    n_expand_rounds = 11  # one initial expansion followed by 10 more rounds

    config_file = pathlib.Path('./PATH/hparams.yaml')
    if not config_file.exists():
        raise FileNotFoundError(f'Config file not found: {config_file}')

    class setting( object ):
        """Bare attribute container filled from the hparams YAML."""
        pass

    cfg = setting()
    with config_file.open('r') as f:
        # NOTE(security): yaml.unsafe_load can instantiate arbitrary Python objects;
        # only point this at trusted hparams files produced by our own training runs.
        d = yaml.unsafe_load(f)
    for k, v in d.items():
        setattr(cfg, k, v)
    cfg.ckpt_path = './dir/PATH/ckpt/ckpt_best.pt'

    torch.cuda.init()
    torch.manual_seed(19491001)
    ckpt = torch.load(cfg.ckpt_path)
    print(f'load cfg.ckpt_path:{cfg.ckpt_path}')
    torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
    device_type = 'cuda'

    # model init
    from model import DaoConfig, GeST
    model_args = dict(n_layer=cfg.n_layer, n_head=cfg.n_head, n_embd=cfg.n_embd, block_size=4096, batch_size=cfg.batch_size,
                      bias=cfg.bias, dropout=cfg.dropout, train_mode=cfg.train_mode, task=cfg.task,
                      vocab_size=cfg.vocab_size, loss_len=cfg.loss_len,
                      encoder=cfg.encoder, decoder=cfg.decoder,
                      skipconnect=cfg.skipconnect, noise=cfg.noise, rope_base=cfg.rope_base, loc_emb=cfg.loc_emb,
                      device_type=cfg.device, modeltype=cfg.model,
                      codebook=cfg.codebook)
    gptconf = DaoConfig(**model_args)
    model = GeST(gptconf)
    model.load_state_dict(ckpt['model'])
    model.to(device_type)
    model.eval()

    for adataname in alldataname:
        # Names may be listed with or without the '.h5ad' extension; strip it only if present
        # (unconditional [:-5] would truncate e.g. 'Zhuang-ABCA-2' to 'Zhuang-A').
        if adataname.endswith('.h5ad'):
            adataname = adataname[:-len('.h5ad')]
        refadata = sc.read_h5ad(f'{savedir}{adataname}_ref.h5ad')
        valadata = sc.read_h5ad(f'{savedir}{adataname}_query.h5ad')
        valcoord = valadata.obs[['array_col', 'array_row']].values.astype(float)

        expandadata = refadata
        for _round in range(n_expand_rounds):
            expandadata, _vardata, _touched = expandTarget_quantize(
                expandadata, model, valcoord, cfg, roundsize=(0.01, 0.01), thres=50, verbose=False, mode='weighted')

        valcoordlist = [tuple(row) for row in valcoord]
        expandcoord = expandadata.obs[['array_col', 'array_row']].values.astype(float)
        expandcoorddict = {tuple(coord): idx for idx, coord in enumerate(expandcoord)}
        n_missing = sum(c not in expandcoorddict for c in valcoordlist)
        if n_missing:
            raise KeyError(f'{n_missing} query spots were not reached after {n_expand_rounds} expansion rounds')
        cidx = [expandcoorddict[c] for c in valcoordlist]
        predval = expandadata[cidx, :].copy()
        predval.write_h5ad(f'{savedir}val_pred_weighted_{adataname}.h5ad')