import argparse
import numpy as np
import anndata as ad
import pandas as pd
import rdata as rd
from scipy.sparse import csr_matrix
import matplotlib.pyplot as plt
from matplotlib import cm
from PIL import Image
import io
import umap
import glob
import ast
import torch
import gc
import pdb

from gmmot.utils.config_tools import get_paths
from gmmot.utils.data_tools import scLoader, load_param_file
from gmmot.vae.main import VAE
from gmmot.omt.solver import omt
from gmmot.omt.functions import sample_gmm_kdim, geometric_mean, points_transport, mmd_rbf, compute_mmd, points_transport_vectorized


# Command-line interface of the script. `main(args)` reads every option below;
# `--ws` is required by its GPU-selection logic.
parser = argparse.ArgumentParser(
    description="Transport cells between age groups in the latent space of a trained VAE "
                "and save UMAP figures, transport costs and the transported cells."
)
parser.add_argument("--latent_dim", default=16, type=int, help="latent dimension")
parser.add_argument("--y_s", default=0, type=int, help="Source domain index")
parser.add_argument("--y_t", default=1, type=int, help="Target domain index")
parser.add_argument("--Ks", default=0, type=int, help="Number of clusters in source domain (0: derive from --type_level)")
parser.add_argument("--Kt", default=0, type=int, help="Number of clusters in target domain (0: derive from --type_level)")
parser.add_argument("--n_sc", default=1, type=int, help="scaling factor for the number of GMM components")
parser.add_argument("--cov_type", default='full', type=str, help="covariance type, either 'diag' or 'full'")
parser.add_argument("--reg_covar", default=1e-5, type=float, help="regularization for covariance matrix")
parser.add_argument("--eps_gs", default=0.05, type=float, help="epsilon for source domain")
parser.add_argument("--eps_w", default=0.05, type=float, help="epsilon for weights")
parser.add_argument("--geometry", default='linear', type=str, help="geometry type, either linear or geodesic")
parser.add_argument("--max_iter", default=20000, type=int, help="maximum number of iterations for the solver")
parser.add_argument("--stop_thr", default=1e-10, type=float, help="stopping threshold for the solver")
parser.add_argument("--timepoints", default=10, type=int, help="number of timepoints for the transportation")
parser.add_argument("--type_level", default='cluster_label', type=str, help="set the level in the taxonomy for GMM clustering")
parser.add_argument("--annot_level", default='supertype_label', type=str, help="set the level in the taxonomy for annotation")
parser.add_argument("--variational", default=False, action="store_true", help="enable variational mode")
parser.add_argument("--n_run", default=1, type=int, help="index of the run")
parser.add_argument("--toml_file", default='pyproject.toml', type=str, help="path to the toml file")
parser.add_argument("--class_label", default='', type=str, help="transcriptomic class label, e.g., 'OPC' ")
parser.add_argument("--verbose", default=False, action="store_true", help="verbose for omt solver")
parser.add_argument("--cuda", default=False, action="store_true", help="run on a CUDA GPU instead of the CPU")
parser.add_argument("--data_file", default='mouse_ageing_file', type=str, help="data file (currently unused; the file is resolved from --class_label)")
parser.add_argument("--device_num", default=0, type=int, help="cuda device number")
# `main` reads `args.ws` when --cuda is set; without this option it raised AttributeError.
parser.add_argument("--ws", default=1, type=int, help="world size (number of processes); with 1, --device_num selects the GPU")
parser.add_argument("--attempt", default=1, type=int, help="attempt number for the run")


def _select_device(args, rank=0):
    """Pick the torch device for this run.

    Returns the CPU unless ``args.cuda`` is set. With CUDA, the candidates are
    the visible GPUs whose total memory exceeds the memory currently allocated
    on them by this process. ``args.device_num`` indexes that list for a
    single-process run (``ws == 1``); otherwise ``rank`` does.

    Raises:
        RuntimeError: if CUDA is requested but no GPU qualifies, or the
            requested index is out of range.
    """
    if not args.cuda:
        return torch.device("cpu")

    free_gpus = [
        i for i in range(torch.cuda.device_count())
        if torch.cuda.get_device_properties(i).total_memory - torch.cuda.memory_allocated(i) > 0
    ]
    if not free_gpus:
        raise RuntimeError("No free GPU devices available.")

    # The parser may not define `--ws` (world size); a missing value means a
    # single-process run instead of an AttributeError.
    gpu_index = args.device_num if getattr(args, "ws", 1) == 1 else rank
    try:
        gpu = free_gpus[gpu_index]
    except IndexError:
        raise RuntimeError(
            f"GPU index {gpu_index} is out of range for {len(free_gpus)} available device(s)."
        ) from None

    device = torch.device(f"cuda:{gpu}")
    print('---> Using GPU(s): ' + torch.cuda.get_device_name(device))
    return device


def _parse_params(lines):
    """Parse ``key: value`` lines from a model parameter file into a dict.

    Numeric values become ``int`` when they are whole numbers and ``float``
    otherwise. Other values go through ``ast.literal_eval``, so ``True`` and
    ``[1, 2]`` become Python objects. Values it cannot parse stay plain
    (whitespace-stripped) strings.
    """
    params = {}
    for line in lines:
        key, _, raw = line.partition(":")
        value = raw.strip()
        try:
            number = float(value)
        except ValueError:
            try:
                params[key] = ast.literal_eval(value)
            except (ValueError, SyntaxError):
                params[key] = value
            continue
        # is_integer() is False for inf/nan; int() on those would raise.
        params[key] = int(number) if number.is_integer() else number
    return params


def _n_clusters(obs, mask, type_level, k, n_sc):
    """Return ``k`` if it is non-zero.

    Otherwise return the number of non-empty ``type_level`` groups among the
    masked rows of ``obs``, scaled by ``n_sc`` and truncated to ``int``.
    """
    if k != 0:
        return k
    # value_counts on a categorical column also lists unused categories (count 0).
    counts = obs.loc[mask, type_level].value_counts()
    return int((counts > 0).sum() * n_sc)


def _plot_domains(xx_s, xx_t, xx_tt, label_s, label_t, out_file, fontsize=6, draw_arrows=False):
    """Save a two-panel 2-D scatter plot to ``out_file``.

    The left panel shows source vs. target points; the right panel shows
    source vs. transported points. When ``draw_arrows`` is set, a dashed arrow
    joins row i of ``xx_s`` to row i of ``xx_tt``. This assumes the two arrays
    are row-aligned.
    """
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].scatter(xx_s[:, 0], xx_s[:, 1], s=1, color='blue', marker='o', alpha=0.5, label=label_s)
    ax[0].scatter(xx_t[:, 0], xx_t[:, 1], s=1, color='red', marker='^', alpha=0.5, label=label_t)
    ax[0].legend(fontsize=fontsize)
    ax[1].scatter(xx_s[:, 0], xx_s[:, 1], s=1, color='blue', marker='o', alpha=0.5, label=label_s)
    ax[1].scatter(xx_tt[:, 0], xx_tt[:, 1], s=1, color='orange', marker='s', alpha=0.5, label='Transported cells')
    ax[1].legend(fontsize=fontsize)
    if draw_arrows:
        for (x0, y0), (x1, y1) in zip(xx_s, xx_tt):
            ax[1].arrow(x0, y0, x1 - x0, y1 - y0,
                        color='grey', linewidth=0.5, head_width=0.1, head_length=0.15,
                        linestyle='--', length_includes_head=True, alpha=0.7)
    fig.savefig(out_file, dpi=300, bbox_inches='tight')
    plt.close(fig)  # pyplot keeps every figure alive until closed


def _plot_cell_types(xx_s, xx_t, xx_tt, type_s, type_t, out_file):
    """Save a scatter plot colored by cell type to ``out_file``.

    Source points are circles and target points are triangles. Transported
    points are squares in their own 'Transported' color. ``type_s`` and
    ``type_t`` must be array-like labels, row-aligned with ``xx_s`` and ``xx_t``.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    all_types = np.unique(np.concatenate([type_s, type_t, np.array(['Transported'])]))
    # pyplot.get_cmap replaces cm.get_cmap, which matplotlib 3.9 removed.
    colors = plt.get_cmap('tab20', len(all_types))
    type_to_color = {cell_type: colors(i) for i, cell_type in enumerate(all_types)}

    for cell_type in np.unique(type_s):
        idx = type_s == cell_type
        ax.scatter(xx_s[idx, 0], xx_s[idx, 1], s=1, marker='o', color=type_to_color[cell_type], alpha=0.5, label=f'Source: {cell_type}')
    for cell_type in np.unique(type_t):
        idx = type_t == cell_type
        ax.scatter(xx_t[idx, 0], xx_t[idx, 1], s=1, marker='^', color=type_to_color[cell_type], alpha=0.5, label=f'Target: {cell_type}')
    ax.scatter(xx_tt[:, 0], xx_tt[:, 1], s=1, color=type_to_color['Transported'], marker='s', alpha=0.5, label='Transported cells')
    fig.savefig(out_file, dpi=300, bbox_inches='tight')
    plt.close(fig)


def main(args):
    """Transport cells from age group ``y_s`` to ``y_t`` in a trained VAE's latent space.

    Steps:

    1. Load the ``ageing`` AnnData file selected by ``args.class_label``.
    2. Restore the first (sorted) matching trained VAE.
    3. Run ``vae.transfer`` between the two groups and decode the source,
       target and transported latents.
    4. Write the outputs:

       * UMAP figures and a cost CSV, into the model directory;
       * the input cells with an ``obsm['vae']`` embedding, as ``.h5ad``;
       * the transported cells, as ``.h5ad``.

    Raises:
        FileNotFoundError: if no model directory, VAE weights or parameter
            file can be found.
    """
    rank = 0
    device = _select_device(args, rank)

    dataset = 'ageing'
    config = get_paths(toml_file=args.toml_file, sub_file='scrna')
    data_file_name = f'{args.class_label}_file'
    data_file = config['paths']['data_path'] / dataset / config['scrna'][data_file_name]
    adata = ad.read_h5ad(data_file)

    adata.obs['cell_id'] = adata.obs.index
    adata.obs['age'] = adata.obs['age_cat'].astype('category')
    n_gene = adata.n_vars
    print(f'Number of cells: {adata.X.shape[0]}')
    print(f'Number of genes: {adata.X.shape[1]}')

    list_ages = list(adata.obs['age_cat'].unique())
    print("ages: " + ", ".join(f"{ag} ({(adata.obs['age_cat'] == ag).sum()})" for ag in list_ages))
    print("Original unique age values:", adata.obs['age_cat'].unique())

    # NOTE(review): the masks below select age groups by position in
    # `list_ages` (the order returned by `unique()`). The plot labels use
    # `reverse_age_mapping` ({0: 'adult', 1: 'aged'}). The two agree only when
    # 'adult' comes first; confirm which ordering vae.transfer uses.
    age_mapping = {'adult': 0, 'aged': 1}
    reverse_age_mapping = {v: k for k, v in age_mapping.items()}
    adata.obs['age_numeric'] = adata.obs['age_cat'].map(age_mapping)

    results_path = config['paths']['main_dir'] / config['paths']['saving_path'] / 'mouse'
    model_pattern = str(results_path) + f'/run_{args.n_run}_zDim_{args.latent_dim}*_{args.class_label}_*'
    # glob order is filesystem-dependent; sort it so the chosen model is reproducible.
    available_models = sorted(glob.glob(model_pattern))
    if not available_models:
        raise FileNotFoundError(f"No models found matching {model_pattern}.")
    selected_model_file = available_models[0]
    print(selected_model_file)

    trained_models = sorted(glob.glob(selected_model_file + '/model/VAE*'))
    if not trained_models:
        raise FileNotFoundError(f"No trained VAE weights found in {selected_model_file}/model.")
    param_files = sorted(glob.glob(selected_model_file + '/param*'))
    if not param_files:
        raise FileNotFoundError(f"No parameter file found in {selected_model_file}.")
    params = _parse_params(load_param_file(param_files[0]))
    latent_dim = params["latent_dim"]

    # NOTE(review): these loaders are not used below.
    train_loader, test_loader, data_loader = scLoader(
        adata=adata,
        features=range(n_gene),
        batch_size=512,
    )

    vae = VAE(saving_folder=selected_model_file, device=device)
    vae.init_nn(
        input_dim=n_gene,
        network=params["network"],
        fc_dim=params["fc_dim"],
        lowD_dim=latent_dim,
        n_layer=params["n_layer"],
        x_drop=0.,
        variational=params["variational"],
    )
    vae.load_model(trained_models[0])

    mask_s = (adata.obs['age_cat'] == list_ages[args.y_s]).to_numpy()
    mask_t = (adata.obs['age_cat'] == list_ages[args.y_t]).to_numpy()
    # Annotate both domains at `annot_level`. The source previously used
    # annot_level and the target type_level, which mixed taxonomy levels in
    # the colored UMAP (saved under an annot_level filename).
    type_s = adata.obs.loc[mask_s, args.annot_level].to_numpy()
    label_t = adata.obs.loc[mask_t, args.annot_level]
    type_t = label_t.to_numpy()

    Ks = _n_clusters(adata.obs, mask_s, args.type_level, args.Ks, args.n_sc)
    Kt = _n_clusters(adata.obs, mask_t, args.type_level, args.Kt, args.n_sc)
    print(f'Number of clusters in source domain: {Ks}')
    print(f'Number of clusters in target domain: {Kt}')

    x_s, z_s, z_t, z_tt, cost_df, solver_dict = vae.transfer(
        adata=adata,
        y_s=args.y_s,
        y_t=args.y_t,
        Ks=Ks,
        Kt=Kt,
        eps_gs=args.eps_gs,
        eps_w=args.eps_w,
        alg='cvx',
        cov_type=args.cov_type,
        reg_covar=args.reg_covar,
        max_iter=args.max_iter,
        stop_thr=args.stop_thr,
        verbose=args.verbose,
        geometry=args.geometry,
        timepoints=args.timepoints,
        variational=args.variational,
    )

    # The last timepoint of the transport path gives the predicted target latents.
    z_t_predict = z_tt[-1, :, :]
    z_t_original = z_t[:, :latent_dim]
    z_s = z_s[:, :latent_dim]
    sc_target_predict = vae.generate(z_t_predict)
    sc_target_original = vae.generate(z_t_original)
    sc_source = vae.generate(z_s)
    n_s = sc_source.shape[0]
    n_t = sc_target_original.shape[0]

    # Assigned positionally: cost_df rows are taken to follow the row order of x_s.
    cost_df['cost_x'] = np.linalg.norm(sc_target_predict - x_s, axis=1)

    anno_predict = vae.annotation(z_t_original, label_t, z_t_predict)

    label_s = reverse_age_mapping[args.y_s]
    label_t_age = reverse_age_mapping[args.y_t]

    x = np.concatenate([sc_source, sc_target_original, sc_target_predict], axis=0)
    emb_x = umap.UMAP(n_neighbors=15, min_dist=0.5, random_state=10, metric='euclidean').fit_transform(x)
    xx_s, xx_t, xx_tt = emb_x[:n_s, :], emb_x[n_s:n_s + n_t, :], emb_x[n_s + n_t:, :]
    _plot_domains(xx_s, xx_t, xx_tt, label_s, label_t_age,
                  selected_model_file + f'/UMAP_x_{args.annot_level}_{Ks}_{Kt}.png')
    _plot_cell_types(xx_s, xx_t, xx_tt, type_s, type_t,
                     selected_model_file + f'/UMAP_x_{args.annot_level}_colored_{Ks}_{Kt}.png')

    z = np.concatenate([z_s, z_t_original, z_t_predict], axis=0)
    emb_z = umap.UMAP(n_neighbors=15, min_dist=0.5, random_state=10, metric='euclidean').fit_transform(z)
    _plot_domains(emb_z[:n_s, :], emb_z[n_s:n_s + n_t, :], emb_z[n_s + n_t:, :], label_s, label_t_age,
                  selected_model_file + f'/UMAP_z_{args.annot_level}_{Ks}_{Kt}.png')

    cost_df.to_csv(selected_model_file + f'/costs_{args.annot_level}_{Ks}_{Kt}_run_{args.attempt}.csv', index=False)

    vae_embedding = np.zeros((adata.shape[0], latent_dim))
    vae_embedding[mask_s, :] = z_s
    vae_embedding[mask_t, :] = z_t_original
    adata.obsm['vae'] = vae_embedding
    adata.write(config['paths']['data_path'] / dataset / f'cells_{args.class_label}_{args.annot_level}_{Ks}_{Kt}_run_{args.attempt}.h5ad')

    # An inner merge keeps the row order of the LEFT frame (adata.obs). The rows
    # of sc_target_predict / z_t_predict / anno_predict follow cost_df, so
    # realign df_ to cost_df explicitly. reset_index avoids an ambiguity error if
    # the obs index is itself named 'cell_id'.
    df_ = adata.obs.reset_index(drop=True).merge(cost_df, on='cell_id')
    row_order = pd.Index(df_['cell_id']).get_indexer(cost_df['cell_id'])
    if (row_order < 0).any():
        raise ValueError("Some cell_id values in cost_df are missing from adata.obs.")
    df_ = df_.iloc[row_order].reset_index(drop=True)
    df_[args.annot_level + '_aged'] = anno_predict

    adata_tmp = ad.AnnData(X=csr_matrix(sc_target_predict), obs=df_, var=adata.var)
    adata_tmp.obsm['vae'] = z_t_predict
    adata_tmp.write(config['paths']['data_path'] / dataset / f'transfer_cells_{args.class_label}_{args.annot_level}_{Ks}_{Kt}_run_{args.attempt}.h5ad')
    



if __name__ == "__main__":
    # Script entry point: parse the command line and run the pipeline.
    main(parser.parse_args())