from torch.utils.data import Dataset, Subset
import numpy as np
import torch
from pathlib import Path
import pickle
import h5py
import pandas as pd
import scanpy as sc
from torchmetrics.functional.regression import spearman_corrcoef


def load_data(data_dir):
    """Open every ``*.h5`` file directly inside ``data_dir`` for reading.

    Files are visited in sorted path order. The key order of the returned
    dict is used by ``IMG2STDataset`` to assign integer sample codes, so
    relying on the directory-listing order (which is not guaranteed) would
    make those codes differ between machines.

    Args:
        data_dir: Directory path (``str`` or ``Path``).

    Returns:
        Dict mapping each file's stem (file name without ``.h5``) to an open,
        read-only ``h5py.File``. The handles are not closed here; the caller
        owns them for as long as the data is needed.
    """
    return {
        path.stem: h5py.File(path, 'r')
        for path in sorted(Path(data_dir).glob('*.h5'))
    }

def generate_datasets_all_samples(data_dict, val_ratio=0.1, test_ratio=0.1, generator=None, **kwargs):
    """Randomly split the rows of all samples into test/val/train subsets.

    Rows from the same sample may land in different splits; see
    ``generate_datasets_sample_wise`` for a split by sample.

    Args:
        data_dict: Mapping of sample id to data group, forwarded to
            ``IMG2STDataset``.
        val_ratio: Fraction of rows for validation (count truncated to int).
        test_ratio: Fraction of rows for testing (count truncated to int).
        generator: Optional ``torch.Generator`` that makes the split
            reproducible. When None, the global RNG is used, as before.
        **kwargs: Forwarded to ``IMG2STDataset``.

    Returns:
        ``(test_dataset, val_dataset, train_dataset)`` as returned by
        ``torch.utils.data.random_split``. Train receives every row not
        assigned to test or val.
    """
    total_dataset = IMG2STDataset(data_dict, **kwargs)
    n_total = len(total_dataset)
    test_num = int(n_total * test_ratio)
    val_num = int(n_total * val_ratio)
    train_num = n_total - val_num - test_num
    # Only pass a generator when one was given, so the default path is
    # exactly the original call.
    split_kwargs = {} if generator is None else {'generator': generator}
    test_dataset, val_dataset, train_dataset = torch.utils.data.random_split(
        total_dataset, [test_num, val_num, train_num], **split_kwargs
    )
    return test_dataset, val_dataset, train_dataset


def get_sample_indeces(total_dataset, sample_ids):
    """Return the positions in ``total_dataset`` that belong to the given samples.

    Args:
        total_dataset: Object exposing ``sample_id_map`` (sample id -> integer
            code) and ``sample_ids`` (1-D tensor with each row's code), such
            as ``IMG2STDataset``.
        sample_ids: Iterable of sample ids (keys of ``sample_id_map``).
            May be empty.

    Returns:
        1-D tensor of row indices in ascending order.

    Raises:
        KeyError: if a sample id is not present in ``sample_id_map``.
    """
    codes = torch.tensor(
        [total_dataset.sample_id_map[sample_id] for sample_id in sample_ids],
        # Match the codes' dtype: an empty list would otherwise become float32.
        dtype=total_dataset.sample_ids.dtype,
    )
    return torch.where(torch.isin(total_dataset.sample_ids, codes))[0]

def _as_id_list(sample_ids):
    """Normalize None / a single id string / any iterable of ids to a list."""
    if sample_ids is None:
        return []
    if isinstance(sample_ids, str):
        return [sample_ids]
    return list(sample_ids)


def generate_datasets_sample_wise(data_dict, test_sample_ids, val_sample_ids, train_sample_ids=None, **kwargs):
    """Split the data so that each sample lies entirely in one subset.

    Args:
        data_dict: Mapping of sample id to data group, forwarded to
            ``IMG2STDataset``.
        test_sample_ids: Sample ids for the test subset (iterable, a single id
            string, or None for none).
        val_sample_ids: Sample ids for the validation subset (same forms).
        train_sample_ids: Sample ids for training. When None, every key of
            ``data_dict`` not used for test or validation, in dict order.
        **kwargs: Forwarded to ``IMG2STDataset``.

    Returns:
        ``(test_dataset, val_dataset, train_dataset)`` as ``Subset`` objects
        over one shared ``IMG2STDataset``.

    Raises:
        KeyError: if any requested sample id is not a key of ``data_dict``.
    """
    test_sample_ids = _as_id_list(test_sample_ids)
    val_sample_ids = _as_id_list(val_sample_ids)
    if train_sample_ids is None:
        held_out = set(test_sample_ids) | set(val_sample_ids)
        train_sample_ids = [sample_id for sample_id in data_dict if sample_id not in held_out]
    else:
        train_sample_ids = _as_id_list(train_sample_ids)

    total_dataset = IMG2STDataset(data_dict, **kwargs)
    test_dataset = Subset(total_dataset, get_sample_indeces(total_dataset, test_sample_ids))
    val_dataset = Subset(total_dataset, get_sample_indeces(total_dataset, val_sample_ids))
    train_dataset = Subset(total_dataset, get_sample_indeces(total_dataset, train_sample_ids))
    return test_dataset, val_dataset, train_dataset


def get_common_genes(data_dict):
    """Return the genes listed in every sample's ``'gene_list'``, sorted.

    The result is sorted so that it, and every index derived from it, is
    identical across runs. Iterating a plain ``set`` of str/bytes follows hash
    order, which is randomized per interpreter process.

    Args:
        data_dict: Mapping of sample id to a group exposing ``'gene_list'``.

    Returns:
        Sorted list of gene names present in all samples, or ``[]`` when
        ``data_dict`` is empty.
    """
    gene_sets = [set(group['gene_list'][:]) for group in data_dict.values()]
    if not gene_sets:
        # set.intersection() with no arguments raises TypeError.
        return []
    return sorted(set.intersection(*gene_sets))
    

def idx_map2common_genes(data_dict, common_genes):
    """Map each sample's gene columns onto the shared ``common_genes`` order.

    For every sample, the position of each gene of ``common_genes`` within that
    sample's ``'gene_list'`` is collected into an int32 tensor, so
    ``matrix[:, result[sample_id]]`` reorders that sample's columns to follow
    ``common_genes``.

    Raises:
        KeyError: if a gene of ``common_genes`` is missing from some sample.
    """
    def _positions(genes):
        # Building the dict in list order means a duplicated gene resolves to
        # its last occurrence.
        lookup = dict(zip(genes, range(len(genes))))
        return torch.tensor([lookup[g] for g in common_genes], dtype=torch.int32)

    return {
        sample_id: _positions(group['gene_list'][:])
        for sample_id, group in data_dict.items()
    }


def extract_high_variance_genes(data_dict, ntop_genes=50, mean=False):
    """Return the ``ntop_genes`` shared genes with the largest per-gene score.

    The ``'exp'`` matrices of all samples are restricted to the genes common
    to every sample, stacked row-wise, and scored per gene column.

    Args:
        data_dict: Mapping of sample id to a group exposing ``'gene_list'``
            and ``'exp'``.
        ntop_genes: Number of genes to return.
        mean: If True, rank genes by mean instead of variance.

    Returns:
        Numpy array of gene names, highest score first.
    """
    common_genes = np.array(get_common_genes(data_dict))
    idx_map = idx_map2common_genes(data_dict, common_genes)
    orig_exp = torch.cat([
        torch.tensor(data_dict[sample_id]['exp'][:]).float()[:, idx_map[sample_id]]
        for sample_id in data_dict
    ])
    # Named neutrally: this holds means when mean=True, variances otherwise.
    gene_scores = torch.mean(orig_exp, dim=0) if mean else torch.var(orig_exp, dim=0)
    # Convert explicitly to a numpy index array instead of letting numpy
    # coerce a torch tensor during fancy indexing.
    top_genes_idx = torch.argsort(gene_scores, descending=True)[:ntop_genes].numpy()
    return common_genes[top_genes_idx]

class IMG2STDataset(Dataset):
    """Dataset pairing per-row image features with gene-expression targets.

    Rows of every sample in ``data_dict`` are concatenated in the dict's key
    order. Only genes present in all samples are considered. Of those, either
    the ``ntop_genes`` genes with the highest variance of the stored ``'exp'``
    values are kept (when ``use_gene_list`` is None), or exactly the genes in
    ``use_gene_list``, in that order. ``exp`` is then recomputed from the kept
    raw counts as ``log(1e4 * count / row_total + 1)``.

    Args:
        data_dict: Mapping of sample id to an h5py-like group providing
            ``'img_feat'`` (or ``'features'``), ``'count'``, ``'exp'`` and
            ``'gene_list'``.
        transform: Accepted for API compatibility; currently unused.
        ntop_genes: Number of highest-variance genes to keep when
            ``use_gene_list`` is None.
        use_gene_list: Optional gene names, compared as ``str``, to keep.

    Raises:
        ValueError: if a gene in ``use_gene_list`` is not shared by all samples.
    """

    def __init__(self, data_dict, transform=None, ntop_genes=250, use_gene_list=None):
        sample_ids = list(data_dict.keys())
        # sample id -> integer code stored per row in self.sample_ids
        self.sample_id_map = {sample_id: i for i, sample_id in enumerate(sample_ids)}

        # The feature key is chosen from the first sample only; all samples
        # are assumed to use the same key.
        first_group = data_dict[sample_ids[0]]
        feat_key = 'img_feat' if 'img_feat' in first_group.keys() else 'features'
        self.img_feat = torch.cat([torch.tensor(data_dict[sample_id][feat_key][:]) for sample_id in sample_ids])

        self.common_genes = get_common_genes(data_dict)
        # per-sample column indices that reorder columns to self.common_genes
        self.idx_map = idx_map2common_genes(data_dict, self.common_genes)
        self.count = torch.cat([
            torch.tensor(data_dict[sample_id]['count'][:]).float()[:, self.idx_map[sample_id]]
            for sample_id in sample_ids
        ])

        if use_gene_list is None:
            orig_exp = torch.cat([
                torch.tensor(data_dict[sample_id]['exp'][:]).float()[:, self.idx_map[sample_id]]
                for sample_id in sample_ids
            ])
            gene_variances = torch.var(orig_exp, dim=0)
            use_gene_idx_list = torch.argsort(gene_variances, descending=True)[:ntop_genes]
        else:
            gene_names = np.array(self.common_genes).astype(str)
            gene_to_idx = {gene: i for i, gene in enumerate(gene_names)}
            missing = [gene for gene in use_gene_list if gene not in gene_to_idx]
            if missing:
                raise ValueError(f"genes not shared by all samples: {missing}")
            # Explicit long dtype so an empty use_gene_list is still a valid index.
            use_gene_idx_list = torch.tensor([gene_to_idx[gene] for gene in use_gene_list], dtype=torch.long)

        self.count = self.count[:, use_gene_idx_list]
        row_totals = self.count.sum(dim=1, keepdim=True)
        # Rows whose kept genes are all zero would give 0/0 = NaN. Dividing
        # by 1 instead yields log(1) = 0 for those rows.
        safe_totals = torch.where(row_totals > 0, row_totals, torch.ones_like(row_totals))
        self.exp = torch.log(1.0e4 * self.count / safe_totals + 1.0)
        self.common_genes = [list(self.common_genes)[i] for i in use_gene_idx_list.tolist()]
        # Per-row sample code. Lengths come from the dataset shape so the
        # 'exp' arrays are not read again.
        self.sample_ids = torch.cat([
            torch.full((data_dict[sample_id]['exp'].shape[0],), self.sample_id_map[sample_id], dtype=torch.long)
            for sample_id in sample_ids
        ])

    def __len__(self):
        return len(self.img_feat)

    def __getitem__(self, idx):
        data = {
            'img_feat': self.img_feat[idx],
            'exp': self.exp[idx],
            'count': self.count[idx],
            'sample_id': self.sample_ids[idx]
        }
        return data