import csv
import os
from typing import Optional
from anndata import AnnData
import numpy as np
import scanpy as sc
from scipy.sparse import spmatrix
from torch.utils.data import Dataset
from .scdata_util import pp_adata


class PerturbDataset(Dataset):
    """Torch dataset of expression profiles with optional perturbation labels.

    The dataset is built either from an ``.h5ad`` file (``load_path is None``)
    or from a directory previously written by :meth:`_save_data`.

    Each item is a tuple ``(x, mask, regime)``:

    * ``x`` -- one row of ``self.data`` as ``float32``.  When perturbation
      features exist, they are appended after the ``n_gene`` expression
      columns.
    * ``mask`` -- ``float32`` vector of length ``n_gene``; all ones, except
      that entries listed in ``self.masks[idx]`` are zero when both
      ``known_intv`` and ``hard_intv`` are set.
    * ``regime`` -- integer index of the row's perturbation in
      ``self.perturbations`` (0 when there is no perturbation key).

    On-disk layout: arrays are saved *unfiltered* (all rows, expression
    columns only) next to ``train_set.npy``; the train/test selection and the
    concatenation of perturbation features happen at construction time, on
    both the fresh and the load path.
    """

    # Constructor arguments that describe the current call rather than the
    # stored data; values found in ``hparams.npy`` must never overwrite them.
    _CALLER_OWNED_ATTRS = ("load_test", "load_path", "save_path", "data_idx")

    def __init__(
        self,
        file_path: str,
        data_idx: int = 0,
        perturb_key: Optional[str] = None,
        preprocess: bool = True,
        gene_perturbation: bool = False,
        known_intv: bool = False,
        hard_intv: bool = False,
        test_pert_ratio: float = 0.0,
        load_test: bool = False,
        load_path: Optional[str] = None,
        save_path: Optional[str] = None,
        **kwargs
    ):
        """Build the dataset from ``file_path`` or reload it from ``load_path``.

        Args:
            file_path: Path of the ``.h5ad`` file (used when ``load_path`` is
                ``None``).
            data_idx: Integer suffix used in the saved/loaded file names.
            perturb_key: Column of ``adata.obs`` holding perturbation labels,
                or ``None`` for purely observational data.
            preprocess: Whether to run ``pp_adata`` on the loaded AnnData.
            gene_perturbation: Whether labels are comma-separated gene names
                whose indices should be recorded as intervention masks.
            known_intv: Together with ``hard_intv``, enables per-item masks in
                ``__getitem__``.
            hard_intv: See ``known_intv``.
            test_pert_ratio: Fraction of rows randomly assigned to the test
                set when the file has no ``train_set`` column.
            load_test: Select the test rows instead of the training rows.
            load_path: Directory written by a previous run; when given, the
                ``.h5ad`` file is not read.
            save_path: Directory to write the prepared data to (fresh path
                only).
            **kwargs: Forwarded to ``pp_adata``.
        """
        super().__init__()
        self.file_path = file_path
        self.data_idx = data_idx
        self.perturb_key = perturb_key
        self.preprocess = preprocess
        self.gene_perturbation = gene_perturbation
        self.known_intv = known_intv
        self.hard_intv = hard_intv
        self.test_pert_ratio = test_pert_ratio
        self.load_test = load_test
        self.load_path = load_path
        self.save_path = save_path
        self.n_gene = 0
        self.n_intv = 0

        self.data = None
        self.regimes = None
        self.masks = None
        self.perturbations = None

        self.perturb_features = None
        self.splitted = False

        if self.load_path is None:
            adata = self._setup_anndata(**kwargs)
            self.n_gene = adata.n_vars

            data = adata.X
            if isinstance(data, spmatrix):
                data = data.toarray()
            self.data = data

            # Persist the full, unfiltered arrays: _load_data applies the
            # train/test selection itself, using the saved train_set.npy.
            self._save_data(adata)

            # Keep data, features, masks and regimes row-aligned.
            is_train_set = np.asarray(adata.obs['train_set']).astype(bool)
            keep = ~is_train_set if self.load_test else is_train_set
            (
                self.data,
                self.perturb_features,
                self.masks,
                self.regimes,
            ) = self._filter_rows(
                keep, self.data, self.perturb_features, self.masks, self.regimes
            )
        else:
            (
                self.data,
                self.perturb_features,
                self.masks,
                self.regimes,
            ) = self._load_data()
            self.n_gene = self.data.shape[1]

        # Add perturbation features (absent when there is no perturb_key).
        if self.perturb_features is not None:
            self.data = np.concatenate([self.data, self.perturb_features], axis=1)

    @staticmethod
    def _filter_rows(keep, data, perturb_features, masks, regimes):
        """Apply the boolean row selector ``keep`` to all per-row containers.

        ``None`` containers are passed through unchanged.  Returns the tuple
        ``(data, perturb_features, masks, regimes)``.
        """
        keep = np.asarray(keep, dtype=bool)
        data = data[keep]
        if perturb_features is not None:
            perturb_features = perturb_features[keep]
        if masks is not None:
            masks = [m for m, k in zip(masks, keep) if k]
        if regimes is not None:
            regimes = regimes[keep]
        return data, perturb_features, masks, regimes

    def _random_split(self, adata: "AnnData"):
        """Randomly mark ``test_pert_ratio`` of the rows as test rows.

        Writes the boolean column ``adata.obs['train_set']`` and sets
        ``self.splitted``.  Uses NumPy's global random state.
        """
        print('Splitting data.')
        n_test = int(adata.n_obs * self.test_pert_ratio)
        test_idx = np.random.choice(adata.n_obs, n_test, replace=False)
        train_set = np.ones(adata.n_obs, dtype=bool)
        train_set[test_idx] = False
        adata.obs['train_set'] = train_set
        self.splitted = True

    def _make_pert_feature(self, adata: "AnnData"):
        """One-hot encode each row's perturbation label.

        Sets ``self.perturb_features`` to an ``(n_obs, n_perturbations)``
        integer array whose columns follow ``adata.uns['perturbations']``.
        Does nothing when ``perturb_key`` is ``None``.
        """
        print('Prepare perturbation features.')
        if self.perturb_key is None:
            return

        pert_label = np.asarray(adata.obs[self.perturb_key])
        perturbations = adata.uns['perturbations']

        # encode perturbations as one-hot
        pert_features = np.zeros((adata.n_obs, len(perturbations)), dtype=int)
        for i, p in enumerate(perturbations):
            pert_features[pert_label == p, i] = 1

        self.perturb_features = pert_features

    def _add_mask_regime(self, adata: "AnnData"):
        """Assign a regime index and an intervention mask to every row.

        ``self.regimes[i]`` is the position of row ``i``'s label in
        ``self.perturbations`` (integer dtype, matching what ``_load_data``
        returns).  When ``gene_perturbation`` is set, ``self.masks[i]`` lists
        the ``var_names`` indices of the genes named in the comma-separated
        label; names not present in ``var_names`` are skipped.
        """
        print('Assign intervention mask and regime.')
        n = adata.n_obs
        self.regimes = np.zeros((n,), dtype=int)
        self.masks = [[] for _ in range(n)]
        if self.perturb_key is None:
            return

        # Get targets of all perturbations
        pert_label = np.asarray(adata.obs[self.perturb_key])
        target_map = []
        group_ind = []
        for i, p in enumerate(self.perturbations):
            members = np.where(pert_label == p)[0]
            target_map.append(p.split(','))
            group_ind.append(members)
            self.regimes[members] = i

        if not self.gene_perturbation:
            return

        # Gene name -> first index in var_names, built once instead of
        # scanning var_names for every target.
        gene_idx_map = {}
        for pos, name in enumerate(adata.var_names):
            gene_idx_map.setdefault(name, pos)

        # Add masks; targets that are not genes of this dataset are ignored
        # rather than raising KeyError.
        for targets, members in zip(target_map, group_ind):
            mask = [gene_idx_map[g] for g in targets if g in gene_idx_map]
            for i in members:
                self.masks[i] = list(mask)

    def _save_data(self, adata: "AnnData"):
        """Write the prepared arrays and metadata to ``self.save_path``.

        No-op when ``save_path`` is ``None``.  Files written (``k`` is
        ``data_idx``): ``data{k}.npy`` or ``data_interv{k}.npy``,
        ``perturb_features{k}.npy`` and ``intervention{k}.csv`` (only when
        ``n_intv > 0``), ``regime{k}.csv``, ``train_set.npy``,
        ``perturbations.npy`` and ``hparams.npy``.
        """
        if self.save_path is None:
            return
        # make data directory
        print(f'Saving data to {self.save_path}.')
        os.makedirs(self.save_path, exist_ok=True)

        # save data
        if self.n_intv == 0:
            data_path = os.path.join(self.save_path, f"data{self.data_idx}.npy")
            np.save(data_path, self.data)
        else:
            data_path = os.path.join(self.save_path, f"data_interv{self.data_idx}.npy")
            np.save(data_path, self.data)
            if self.perturb_features is not None:
                data_perturb_path = os.path.join(
                    self.save_path, f"perturb_features{self.data_idx}.npy"
                )
                np.save(data_perturb_path, self.perturb_features)

            data_path = os.path.join(self.save_path, f"intervention{self.data_idx}.csv")
            with open(data_path, "w", newline="") as f:
                csv.writer(f).writerows(self.masks)

        # save regimes
        if self.regimes is not None:
            regime_path = os.path.join(self.save_path, f"regime{self.data_idx}.csv")
            with open(regime_path, "w", newline="") as f:
                writer = csv.writer(f)
                for regime in self.regimes:
                    writer.writerow([regime])

        # train/test split
        if self.splitted:
            np.save(
                os.path.join(self.save_path, 'train_set.npy'),
                np.asarray(adata.obs['train_set']).astype(bool),
            )

        # Other class attributes
        np.save(os.path.join(self.save_path, 'perturbations.npy'), self.perturbations)
        hparams = {
            attr: val
            for attr, val in vars(self).items()
            if isinstance(val, (int, float, str))
        }
        np.save(os.path.join(self.save_path, 'hparams.npy'), hparams)

    def _load_data(self):
        """Load data, perturbation features, masks and regimes from disk.

        Restores the scalar attributes stored in ``hparams.npy`` (except the
        caller-owned ones listed in ``_CALLER_OWNED_ATTRS``), then selects the
        training or test rows according to ``self.load_test``.

        Returns:
            Tuple ``(data, perturb_features, masks, regimes)``;
            ``perturb_features`` is ``None`` when no feature file applies.
        """
        # NOTE(security): allow_pickle=True unpickles arbitrary objects; only
        # point load_path at directories produced by trusted runs.
        perturbations = np.load(
            os.path.join(self.load_path, 'perturbations.npy'), allow_pickle=True
        )
        # A saved ``None`` comes back as a 0-d object array.
        self.perturbations = None if perturbations.ndim == 0 else perturbations
        hparams = np.load(
            os.path.join(self.load_path, 'hparams.npy'), allow_pickle=True
        ).item()
        for attr, val in hparams.items():
            # Do not let the saved run override what this call asked for
            # (e.g. load_test=True on a directory saved with load_test=False).
            if attr in self._CALLER_OWNED_ATTRS:
                continue
            setattr(self, attr, val)

        has_intv = self.n_intv > 0
        if has_intv:
            name_data = f"data_interv{self.data_idx}.npy"
        else:
            name_data = f"data{self.data_idx}.npy"

        # Load data
        self.data_path = os.path.join(self.load_path, name_data)
        data = np.load(self.data_path, allow_pickle=True)
        perturb_features = None
        if has_intv:
            self.perturb_path = os.path.join(
                self.load_path, f"perturb_features{self.data_idx}.npy"
            )
            if os.path.exists(self.perturb_path):
                perturb_features = np.load(self.perturb_path, allow_pickle=True)

        # Load intervention masks and regimes
        if has_intv:
            interv_path = os.path.join(
                self.load_path, f"intervention{self.data_idx}.csv"
            )
            regimes = np.genfromtxt(
                os.path.join(self.load_path, f"regime{self.data_idx}.csv"),
                delimiter=",",
            )
            # genfromtxt yields a 0-d array for a single-row file.
            regimes = np.atleast_1d(regimes).astype(int)

            # read masks
            with open(interv_path, "r", newline="") as f:
                masks = [[int(x) for x in row] for row in csv.reader(f)]
        else:
            regimes = np.zeros(data.shape[0], dtype=int)
            # One (empty) mask per row so __getitem__ can always index it.
            masks = [[] for _ in range(data.shape[0])]

        # Filter out training/test set
        split_path = os.path.join(self.load_path, 'train_set.npy')
        if os.path.exists(split_path):
            is_train_set = np.load(split_path).astype(bool)
        else:
            # No split was saved: treat every row as a training row.
            is_train_set = np.ones(data.shape[0], dtype=bool)
        keep = ~is_train_set if self.load_test else is_train_set

        return self._filter_rows(keep, data, perturb_features, masks, regimes)

    def _setup_anndata(self, **kwargs) -> "AnnData":
        """Read the ``.h5ad`` file and derive split, labels, features, masks.

        Sets ``self.splitted``, ``self.perturbations``, ``self.n_intv``,
        ``self.perturb_features``, ``self.masks`` and ``self.regimes``.

        Returns:
            The (possibly preprocessed) AnnData object.
        """
        print('Setting up AnnData.')
        adata = sc.read_h5ad(self.file_path)
        self.splitted = 'train_set' in adata.obs
        # Preprocessing
        if self.preprocess:
            pp_adata(adata, self.perturb_key, self.gene_perturbation, **kwargs)

        # Parse perturbation labels.  They are (re)computed for unsplit files,
        # and also for pre-split files that do not carry them, so the feature
        # and mask builders below always find them.
        needs_labels = not self.splitted or 'perturbations' not in adata.uns
        if self.perturb_key is not None and needs_labels:
            pert_label = np.asarray(adata.obs[self.perturb_key])
            adata.uns['perturbations'] = np.unique(pert_label)

        # Random split
        if not self.splitted:
            self._random_split(adata)

        if 'perturbations' in adata.uns:
            self.perturbations = adata.uns['perturbations']
            self.n_intv = len(self.perturbations)

        # Make dataset
        self._make_pert_feature(adata)
        self._add_mask_regime(adata)

        return adata

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(x, mask, regime)`` for row ``idx`` (see class docstring)."""
        mask = np.ones((self.n_gene,), dtype=np.float32)
        if self.known_intv and self.hard_intv:
            # binarize mask from list: listed gene indices become 0
            targets = np.asarray(self.masks[idx], dtype=int)
            mask[targets] = 0
        return (
            self.data[idx].astype(np.float32),
            mask,
            self.regimes[idx],
        )


