import scanpy as sc
import numpy as np
import argparse
import faiss
import warnings
import scanpy as sc
import time
import matplotlib.pyplot as plt
from sklearn import metrics
import seaborn as sns
import random
import numpy as np
from ...retreival_utils.faiss_retreival import similarity_search
import pandas as pd
import os
from sklearn.preprocessing import StandardScaler
# Silence every warning for the whole run (process-wide filter).
warnings.filterwarnings('ignore')

# NOTE(review): not referenced anywhere in this script; kept in case other code
# relies on it -- confirm before removing.
N_COMPONENTS = 128

# Command-line interface. Every option is required.
parser = argparse.ArgumentParser(description='Process single-cell data.')
parser.add_argument('--input_adata', type=str, required=True, help='Path to input .h5ad file')
parser.add_argument('--input_embeddings', type=str, required=True, help='Path to input .npy file')
parser.add_argument('--method', type=str, required=True, help='Method used for embeddings; all if you want to run all methods')
parser.add_argument('--retrieved_for_each_cell', type=int, required=True, help='Number of cells to retrieve for each cell')
parser.add_argument('--faiss_search', type=str, required=True, help='Faiss search method')
parser.add_argument('--obs', type=str, required=True,
                    help='Comma-separated adata.obs column names; one CSV of retrieved labels is written per column')
parser.add_argument('--query_indexes', type=str, required=True,
                    help='Path to a text file of integer row indexes used as queries')
parser.add_argument('--target_indexes', type=str, required=True,
                    help='Path to a text file of integer row indexes searched as retrieval targets')
parser.add_argument('--output_dir', type=str, required=True,
                    help='Directory for the output CSV files (created if missing)')

args = parser.parse_args()

# A non-positive neighbour count would yield an empty or invalid result table
# downstream, so reject it up front with a standard argparse usage error.
if args.retrieved_for_each_cell < 1:
    parser.error('--retrieved_for_each_cell must be a positive integer')

# --- Retrieval pipeline ---------------------------------------------------
# 1. Load the AnnData object and the embedding matrix.
# 2. Standardise the embeddings.
# 3. Search the target rows for the nearest neighbours of each query row.
# 4. Write the neighbour row indexes (index.csv).
# 5. Write the matching adata.obs labels, one CSV per requested obs column.
adata = sc.read(args.input_adata)

# Accept "a,b" as well as "a, b", and ignore empty entries (e.g. a trailing
# comma). Untrimmed names would otherwise raise KeyError on adata.obs lookup.
obs_list = [name.strip() for name in args.obs.split(",")]
obs_list = [name for name in obs_list if name]
# Fail before the (potentially expensive) similarity search rather than after
# some outputs have already been written.
missing_obs = [name for name in obs_list if name not in adata.obs.columns]
if missing_obs:
    raise KeyError(f"--obs column(s) not found in adata.obs: {missing_obs}")

print(f"########### {args.method.capitalize()} ###########")
embeddings = np.load(args.input_embeddings)
# StandardScaler standardises columns. Fitting it on the transpose therefore
# gives every ROW of the loaded matrix zero mean and unit variance across its
# embedding dimensions (assumes rows are cells -- TODO confirm).
scaler = StandardScaler()
embeddings = scaler.fit_transform(embeddings.T)
embeddings = embeddings.T
print("Embeddings shape: ", embeddings.shape)

# ndmin=1 keeps a single-entry index file as a 1-D array. Without it,
# np.loadtxt returns a 0-d scalar and the fancy indexing below fails.
query_indexes = np.loadtxt(args.query_indexes, dtype=int, ndmin=1)
target_indexes = np.loadtxt(args.target_indexes, dtype=int, ndmin=1)

# `index` holds, for each query, positions into target_indexes.
distances, index = similarity_search(args, embeddings, query_indexes, target_indexes)
index = np.asarray(index)
# FAISS pads rows with -1 when fewer than k neighbours are found. NumPy would
# silently wrap such negative positions to the last target, producing wrong
# results, so refuse them explicitly.
if (index < 0).any():
    raise ValueError(
        "similarity search returned fewer neighbours than requested for some "
        "queries; lower --retrieved_for_each_cell or enlarge the target set"
    )

columns = ['Query'] + ['Result-' + str(i) for i in range(1, args.retrieved_for_each_cell + 1)]
# Each row: the query's row index followed by its neighbours' row indexes.
df = pd.DataFrame(
    np.concatenate([query_indexes.reshape(-1, 1), target_indexes[index]], axis=1),
    columns=columns,
)
os.makedirs(args.output_dir, exist_ok=True)
df.to_csv(os.path.join(args.output_dir, "index.csv"))

# Replace each row position in the table with that row's obs label.
positions = df.values
for obs_name in obs_list:
    target = pd.DataFrame(np.array(adata.obs[obs_name])[positions], columns=columns)
    target.to_csv(os.path.join(args.output_dir, obs_name + ".csv"))