import argparse
import numpy as np
import anndata as ad
import pandas as pd
import rdata as rd
from scipy.sparse import csr_matrix
import matplotlib.pyplot as plt
from matplotlib import cm
from PIL import Image
import io
import umap
import glob
import ast
import torch
import gc
import pdb

from gmmot.utils.config_tools import get_paths
from gmmot.utils.data_tools import scLoader, load_param_file
from gmmot.vae.main import VAE
from gmmot.omt.solver import omt
from gmmot.omt.functions import sample_gmm_kdim, geometric_mean, points_transport, mmd_rbf, compute_mmd, points_transport_vectorized


# Command-line options for the per-region transport-cost script.
parser = argparse.ArgumentParser()

# Source / target groups and mixture sizes (0 lets main() derive a default).
parser.add_argument("--y_s", default=0, type=int, help="Source domain index")
parser.add_argument("--y_t", default=1, type=int, help="Target domain index")
parser.add_argument("--Ks", default=0, type=int, help="Number of clusters in source domain")
parser.add_argument("--Kt", default=0, type=int, help="Number of clusters in target domain")

# Transport solver settings.
parser.add_argument("--cov_type", default='diag', type=str, help="covariance type, either 'diag' or 'full'")
parser.add_argument("--geometry", default='linear', type=str, help="geometry type, either linear or geodesic")
parser.add_argument("--timepoints", default=100, type=int, help="number of timepoints for the transportation")

# Taxonomy columns in the cell metadata.
parser.add_argument("--type_level", default='CDM_supertype_name', type=str, help="set the level in the taxonomy for GMM clustering")
parser.add_argument("--annot_level", default='CDM_subclass_name', type=str, help="set the level in the taxonomy for annotation")

# Model / run selection.
parser.add_argument("--variational", default=False, action="store_true", help="enable variational mode")
parser.add_argument("--n_run", default=1, type=int, help="index of the run")
parser.add_argument("--toml_file", default='pyproject.toml', type=str, help="path to the toml file")
parser.add_argument("--CCF_level", default='', type=str, help="CCF level 1 annotation")
parser.add_argument("--verbose", default=False, action="store_true", help="verbose for omt solver")

# Device selection. `--cuda` is a plain on/off flag, so the help text says so.
parser.add_argument("--cuda", default=False, action="store_true", help="run on a CUDA GPU instead of the CPU")
parser.add_argument("--data_file", default='mouse_ageing_file', type=str, help="data file")
parser.add_argument("--device_num", default=0, type=int, help="cuda device number")
# main() reads `args.ws` when --cuda is set; declaring it here keeps that
# branch from failing with AttributeError. The default of 1 takes the
# `--device_num` path in main().
parser.add_argument("--ws", default=1, type=int, help="number of processes (world size); 1 selects the GPU via --device_num")


def _select_device(args, rank=0):
    """Return the torch device requested by the command-line options.

    Args:
        args: Parsed options; reads ``cuda``, ``device_num`` and, if present,
            ``ws``.
        rank: Index into the list of usable GPUs used when ``ws`` is not 1.

    Returns:
        ``torch.device("cpu")`` when ``args.cuda`` is false, otherwise a CUDA
        device.

    Raises:
        RuntimeError: If ``args.cuda`` is set but no GPU passes the check.
    """
    if not args.cuda:
        return torch.device("cpu")

    free_gpus = [
        idx
        for idx in range(torch.cuda.device_count())
        if torch.cuda.get_device_properties(idx).total_memory - torch.cuda.memory_allocated(idx) > 0
    ]
    if not free_gpus:
        raise RuntimeError("No free GPU devices available.")

    # `ws` may be absent from the namespace; treat a missing value as 1 so the
    # GPU is chosen by --device_num instead of raising AttributeError.
    if getattr(args, "ws", 1) == 1:
        device = torch.device(f"cuda:{free_gpus[args.device_num]}")
    else:
        device = torch.device(f"cuda:{free_gpus[rank]}")
    print('---> Using GPU(s): ' + torch.cuda.get_device_name(device))
    return device


def _parse_params(loaded_param):
    """Turn ``"key: value"`` strings into a dict of typed values.

    Each value is tried as a number first (integral floats become ``int``),
    then as a Python literal via ``ast.literal_eval``; anything else is kept
    as the raw string.
    """
    params = {}
    for entry in loaded_param:
        key = entry.split(":")[0]
        # Skip the colon and the single character that follows it.
        raw = entry[entry.find(":") + 2:]
        try:
            number = float(raw)
        except ValueError:
            value = raw
            try:
                value = ast.literal_eval(raw)
            except (ValueError, SyntaxError):
                pass
        else:
            # is_integer() is False for inf/nan, so those stay floats instead
            # of blowing up in int().
            value = int(number) if number.is_integer() else number
        params[key] = value
    return params


def _default_n_components(labels):
    """Return 1.5x the number of distinct labels that occur, truncated to int."""
    counts = labels.value_counts()
    return int((counts > 0).sum() * 1.5)


def main(args):
    """Compute the transport cost for each CCF level-1 region and plot it.

    For every region in a fixed list, the cells of that region are loaded,
    a saved VAE is restored, ``vae.transfer`` is run between the source and
    target age groups, and the reported ``total_cost`` is collected. The costs
    are saved as a bar chart inside the selected model folder.

    Args:
        args: Parsed command-line options. Reads ``cuda``, ``device_num``,
            ``toml_file``, ``n_run``, ``y_s``, ``y_t``, ``Ks``, ``Kt``,
            ``type_level``, ``cov_type``, ``geometry``, ``timepoints`` and
            ``verbose`` (plus ``ws`` when present).

    Raises:
        FileNotFoundError: If no run folder, model checkpoint or parameter
            file matches the expected glob patterns.
        RuntimeError: If ``--cuda`` is set and no GPU is usable.
    """
    ccfs = ['CTXsp', 'fiber tracts', 'HIP', 'HY', 'Isocortex', 'MB', 'OLF', 'PAL', 'RHP', 'sAMY', 'STRd', 'TH', 'VS']
    ccfs_colors = ['#393b79', '#5254a3', '#6b6ecf', '#637939', '#8ca252', '#b5cf6b', '#cedb9c', '#8c6d31',
                   '#bd9e39', '#e7ba52', '#843c39', '#ad494a', '#d6616b']

    device = _select_device(args)

    dataset = 'ageing'
    config = get_paths(toml_file=args.toml_file, sub_file='files')
    # NOTE(review): --data_file is not consulted here; the config key is fixed.
    data_file = config['paths']['data_path'] / dataset / config['files']['mouse_ageing_file']

    # Locate the trained model once, before any data is loaded: none of this
    # depends on the region, and a missing model should fail immediately.
    results_path = config['paths']['main_dir'] / config['paths']['saving_path'] / 'mouse'
    # Sorted so the chosen run does not depend on directory listing order.
    available_models = sorted(glob.glob(str(results_path) + f'/run_{args.n_run}_*nepoch_5000_all*'))
    if not available_models:
        raise FileNotFoundError("No models found.")

    selected_model_file = available_models[0]
    print(selected_model_file)

    trained_models = sorted(glob.glob(selected_model_file + '/model/VAE*'))
    if not trained_models:
        raise FileNotFoundError(f"No trained model matching '{selected_model_file}/model/VAE*'.")

    param_files = sorted(glob.glob(selected_model_file + '/param*'))
    if not param_files:
        raise FileNotFoundError(f"No parameter file matching '{selected_model_file}/param*'.")

    params = _parse_params(load_param_file(param_files[0]))
    print(params)

    # Only the number of genes is needed up front. The file is only read, so
    # open it read-only and always release the handle.
    backed = ad.read_h5ad(data_file, backed='r')
    try:
        n_gene = len(backed.var.index)
    finally:
        backed.file.close()
    del backed

    # Numeric code stored in obs['age_numeric'] for each age label.
    age_mapping = {'P150': 0, 'P540': 1, 'P720': 2}

    tr_cost = np.zeros(len(ccfs))

    for iccf, ccf in enumerate(ccfs):
        # Pull just this region's cells into memory, then release the file.
        backed = ad.read_h5ad(data_file, backed='r')
        try:
            in_region = backed.obs['clean_CCF_level1'].isin([ccf]).values
            adata = backed[in_region].to_memory()
        finally:
            backed.file.close()
        del backed

        age_col = adata.obs['age']
        # NOTE(review): unique() keeps order of appearance, so the age that
        # index y_s / y_t refers to may differ between regions — confirm.
        list_ages = list(age_col.unique())

        print(f'Number of cells in {ccf}: {adata.X.shape[0]}')
        print(f'Number of genes: {adata.X.shape[1]}')
        age_summary = ", ".join(f"{ag} ({(age_col == ag).sum()})" for ag in list_ages)
        print("ages: " + age_summary)
        print("Original unique age values:", age_col.unique())

        adata.obs['age_numeric'] = adata.obs['age'].map(age_mapping)

        print('preparing the data loaders')
        # The returned loaders are not used below; the call is kept in case
        # scLoader prepares `adata` in place — TODO confirm.
        scLoader(
            adata=adata,
            features=range(n_gene),
            batch_size=512,
        )

        # A fresh model is built and restored for every region, as before.
        vae = VAE(saving_folder=selected_model_file, device=device)
        vae.init_nn(
            input_dim=n_gene,
            fc_dim=params["fc_dim"],
            lowD_dim=params["latent_dim"],
            n_layer=params["n_layer"],
            x_drop=params["p_drop"],
            variational=params["variational"],
        )
        vae.load_model(trained_models[0])

        # Default mixture sizes follow --type_level, the column documented as
        # the clustering level (its default is 'CDM_supertype_name').
        if args.Ks == 0:
            source_mask = adata.obs['age'] == list_ages[args.y_s]
            Ks = _default_n_components(adata.obs.loc[source_mask, args.type_level])
        else:
            Ks = args.Ks

        if args.Kt == 0:
            target_mask = adata.obs['age'] == list_ages[args.y_t]
            Kt = _default_n_components(adata.obs.loc[target_mask, args.type_level])
        else:
            Kt = args.Kt

        _, _, _, _, _, solver_dict = vae.transfer(
            adata=adata,
            y_s=args.y_s,
            y_t=args.y_t,
            Ks=Ks,
            Kt=Kt,
            eps_gs=0.1,
            eps_w=0.1,
            alg='cvx',
            cov_type=args.cov_type,
            max_iter=10000,
            stop_thr=1e-10,
            verbose=args.verbose,
            geometry=args.geometry,
            timepoints=args.timepoints,
            transport=False,
        )

        tr_cost[iccf] = solver_dict['total_cost']

    plt.bar(ccfs, tr_cost, color=ccfs_colors)

    plt.xticks(rotation=90, ha='right')
    plt.xlabel('CCF Level 1')
    plt.ylabel('Transport Cost')
    plt.title('Transport Cost per region')
    plt.tight_layout()
    plt.savefig(selected_model_file + f'/transport_cost_per_region_{args.y_s}_{args.y_t}.png', bbox_inches='tight', dpi=300)
    plt.close()


if __name__ == "__main__":
    # Script entry point: parse the command line and hand the options to main().
    main(parser.parse_args())