import argparse
import numpy as np
import anndata as ad
import pandas as pd
import rdata as rd
from scipy.sparse import csr_matrix
import matplotlib.pyplot as plt
from matplotlib import cm
from PIL import Image
import io
import umap
import glob
import ast
import torch
import gc
import pdb

from gmmot.utils.config_tools import get_paths
from gmmot.utils.data_tools import scLoader, load_param_file
from gmmot.vae.main import VAE
from gmmot.omt.solver import omt
from gmmot.omt.functions import sample_gmm_kdim, geometric_mean, points_transport, mmd_rbf, compute_mmd, points_transport_vectorized


# Command-line interface of the transfer / plotting script.
parser = argparse.ArgumentParser(
    description="Transfer cells between mouse ageing timepoints with a trained VAE and "
                "GMM optimal transport, then plot UMAPs and save the transported cells."
)
parser.add_argument("--y_s", default=0, type=int,
                    help="source age index (0=P150, 1=P540, 2=P720)")
parser.add_argument("--y_t", default=1, type=int,
                    help="target age index (0=P150, 1=P540, 2=P720)")
parser.add_argument("--Ks", default=0, type=int,
                    help="number of GMM clusters in the source domain (0: derive from the cell types present)")
parser.add_argument("--Kt", default=0, type=int,
                    help="number of GMM clusters in the target domain (0: derive from the cell types present)")
parser.add_argument("--cov_type", default='diag', type=str,
                    help="covariance type, either 'diag' or 'full'")
parser.add_argument("--geometry", default='linear', type=str,
                    help="geometry type, either 'linear' or 'geodesic'")
parser.add_argument("--timepoints", default=5, type=int,
                    help="number of timepoints for the transportation")
parser.add_argument("--type_level", default='CDM_supertype_name', type=str,
                    help="taxonomy level (obs column) used for GMM clustering")
parser.add_argument("--annot_level", default='CDM_subclass_name', type=str,
                    help="taxonomy level (obs column) used for annotation")
parser.add_argument("--variational", default=False, action="store_true",
                    help="enable variational mode")
parser.add_argument("--n_run", default=1, type=int, help="index of the run")
parser.add_argument("--toml_file", default='pyproject.toml', type=str,
                    help="path to the toml config file")
parser.add_argument("--CCF_level", default='', type=str,
                    help="CCF level-1 region to process, or 'all' for every region")
parser.add_argument("--verbose", default=False, action="store_true",
                    help="verbose output from the OMT solver")
parser.add_argument("--cuda", default=False, action="store_true",
                    help="run on a GPU (default: CPU)")
parser.add_argument("--data_file", default='mouse_ageing_file', type=str,
                    help="name of the [files] entry in the toml config that points to the data file")
parser.add_argument("--device_num", default=0, type=int,
                    help="index of the CUDA device to use (among the detected GPUs)")
# main() reads args.ws when --cuda is set; it was previously never defined.
parser.add_argument("--ws", default=1, type=int,
                    help="world size (number of processes); with 1, --device_num selects the GPU")


def _select_device(args):
    """Return the torch device the VAE should run on.

    CPU is used unless ``--cuda`` was passed. With CUDA, every visible GPU whose
    total memory exceeds the memory allocated by this process is a candidate;
    ``--device_num`` picks among the candidates when the world size (``args.ws``)
    is 1, otherwise the candidate at ``rank`` 0 is used.

    Raises:
        RuntimeError: if ``--cuda`` was requested but no candidate GPU exists.
    """
    if not args.cuda:
        return torch.device("cpu")

    rank = 0
    # NOTE: memory_allocated() only counts this process's tensors, so in practice
    # nearly every visible device qualifies as "free" here.
    candidates = [
        i for i in range(torch.cuda.device_count())
        if torch.cuda.get_device_properties(i).total_memory - torch.cuda.memory_allocated(i) > 0
    ]
    if not candidates:
        raise RuntimeError("No free GPU devices available.")

    # `ws` may be absent from older parsers; default to a single process.
    world_size = getattr(args, 'ws', 1)
    gpu_index = args.device_num if world_size == 1 else rank
    device = torch.device(f"cuda:{candidates[gpu_index]}")
    print('---> Using GPU(s): ' + torch.cuda.get_device_name(device))
    return device


def _parse_params(param_lines):
    """Parse ``"key: value"`` lines of a saved parameter file into a dict.

    Each value becomes an int (when it is an integral number), a float, a Python
    literal (via ``ast.literal_eval``), or otherwise stays the raw string.
    """
    params = {}
    for line in param_lines:
        key = line.split(":")[0]
        # The value starts two characters after the first colon (": " separator).
        raw_value = line[line.find(":") + 2:]
        try:
            value = float(raw_value)
        except ValueError:
            try:
                value = ast.literal_eval(raw_value)
            except (ValueError, SyntaxError):
                value = raw_value
        else:
            # is_integer() is False for inf/nan, so int() can never overflow here.
            if value.is_integer():
                value = int(value)
        params[key] = value
    return params


def _n_components(obs, mask, level, requested):
    """Return ``requested`` if non-zero, else 1.5x the number of ``level`` labels present under ``mask``."""
    if requested != 0:
        return requested
    counts = obs.loc[mask, level].value_counts()
    # `> 0` drops unused categories that value_counts() reports for categorical columns.
    return int((counts > 0).sum() * 1.5)


def _umap_2d(data):
    """Embed the rows of ``data`` in 2-D with the UMAP settings shared by all plots."""
    return umap.UMAP(n_neighbors=15, min_dist=0.5, random_state=10,
                     metric='euclidean').fit_transform(data)


def _split_embedding(embedding, n_s, n_t):
    """Split a stacked [source; target; transported] embedding into its three row blocks."""
    return embedding[:n_s, :], embedding[n_s:n_s + n_t, :], embedding[n_s + n_t:, :]


def _plot_age_umap(emb_s, emb_t, emb_tt, source_label, target_label, out_file, fontsize=6):
    """Save a two-panel scatter: source vs. target (left), source vs. transported (right)."""
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].scatter(emb_s[:, 0], emb_s[:, 1], s=1, color='blue', marker='o', alpha=0.5, label=source_label)
    ax[0].scatter(emb_t[:, 0], emb_t[:, 1], s=1, color='red', marker='^', alpha=0.5, label=target_label)
    ax[0].legend(fontsize=fontsize)
    ax[1].scatter(emb_s[:, 0], emb_s[:, 1], s=1, color='blue', marker='o', alpha=0.5, label=source_label)
    ax[1].scatter(emb_tt[:, 0], emb_tt[:, 1], s=1, color='orange', marker='s', alpha=0.5,
                  label='Transported cells')
    ax[1].legend(fontsize=fontsize)
    fig.savefig(out_file, dpi=300, bbox_inches='tight')
    # Close explicitly: figures are created once per region and would otherwise accumulate.
    plt.close(fig)


def _plot_type_umap(emb_s, emb_t, emb_tt, type_s, type_t, out_file):
    """Save a scatter of source/target cells coloured by cell type plus the transported cells.

    ``type_s`` / ``type_t`` are label arrays row-aligned with ``emb_s`` / ``emb_t``.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    # 'Transported' always gets a colour, even when there are no source cells.
    all_types = np.unique(np.concatenate([type_s, type_t, ['Transported']]))
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, removed in matplotlib 3.9.
    cmap = plt.get_cmap('tab20', len(all_types))
    type_to_color = {cell_type: cmap(i) for i, cell_type in enumerate(all_types)}

    for cell_type in np.unique(type_s):
        idx = type_s == cell_type
        ax.scatter(emb_s[idx, 0], emb_s[idx, 1], s=1, marker='o', color=type_to_color[cell_type],
                   alpha=0.5, label=f'Source: {cell_type}')
    for cell_type in np.unique(type_t):
        idx = type_t == cell_type
        ax.scatter(emb_t[idx, 0], emb_t[idx, 1], s=1, marker='^', color=type_to_color[cell_type],
                   alpha=0.5, label=f'Target: {cell_type}')
    ax.scatter(emb_tt[:, 0], emb_tt[:, 1], s=1, color=type_to_color['Transported'], marker='s',
               alpha=0.5, label='Transported cells')
    fig.savefig(out_file, dpi=300, bbox_inches='tight')
    plt.close(fig)


def main(args):
    """Transfer cells from a source to a target age with a trained VAE and plot the result.

    For each selected CCF level-1 region (``--CCF_level``; ``'all'`` processes every
    region found in the data) this:

    1. loads the region's cells from the h5ad file into memory,
    2. builds the VAE from the first matching run of ``--n_run`` and loads its weights,
    3. calls ``vae.transfer`` from age index ``--y_s`` to ``--y_t``,
    4. decodes source, target and transported latents with ``vae.generate``,
    5. annotates the transported cells with ``vae.annotation``,
    6. saves UMAP figures (gene space and latent space), the per-cell cost table
       and an h5ad of the transported cells.

    Raises:
        FileNotFoundError: if no run directory, VAE checkpoint or parameter file is found.
        RuntimeError: if ``--cuda`` is requested but no GPU is available.
    """
    device = _select_device(args)

    dataset = 'ageing'
    config = get_paths(toml_file=args.toml_file, sub_file='files')
    data_dir = config['paths']['data_path'] / dataset
    # `--data_file` names the [files] entry of the toml config (default 'mouse_ageing_file').
    data_file = data_dir / config['files'][args.data_file]
    cirro_file = data_dir / config['files']['cirro_corr']

    # Age label <-> integer index used by --y_s / --y_t and by obs['age_numeric'].
    age_mapping = {'P150': 0, 'P540': 1, 'P720': 2}
    reverse_age_mapping = {v: k for k, v in age_mapping.items()}
    source_age = reverse_age_mapping[args.y_s]
    target_age = reverse_age_mapping[args.y_t]

    # Model discovery and hyper-parameters do not depend on the region: resolve them once.
    results_path = config['paths']['main_dir'] / config['paths']['saving_path'] / 'mouse'
    model_pattern = str(results_path) + f'/run_{args.n_run}_*nepoch_5000_all*'
    # glob order is arbitrary; sort so the selected run is deterministic.
    available_models = sorted(glob.glob(model_pattern))
    if not available_models:
        raise FileNotFoundError(f"No models found matching {model_pattern}")
    selected_model_file = available_models[0]
    print(selected_model_file)

    trained_models = sorted(glob.glob(selected_model_file + '/model/VAE*'))
    if not trained_models:
        raise FileNotFoundError(f"No VAE checkpoint found in {selected_model_file}/model")
    param_files = sorted(glob.glob(selected_model_file + '/param*'))
    if not param_files:
        raise FileNotFoundError(f"No parameter file found in {selected_model_file}")
    params = _parse_params(load_param_file(param_files[0]))
    print(params)
    latent_dim = params["latent_dim"]

    cirro_coords = pd.read_csv(cirro_file, index_col=0)

    # Read-only backed mode: this script never writes to the source h5ad.
    adata_backed = ad.read_h5ad(data_file, backed='r')
    try:
        n_gene = adata_backed.n_vars
        if args.CCF_level == 'all':
            ccfs = [str(val) for val in adata_backed.obs['clean_CCF_level1'].unique()]
        else:
            ccfs = [args.CCF_level]

        for ccf in ccfs:
            region_mask = adata_backed.obs['clean_CCF_level1'].isin([ccf]).values
            adata = adata_backed[region_mask].to_memory()
            if adata.n_obs == 0:
                print(f'No cells found for CCF level-1 region {ccf!r}; skipping.')
                continue

            list_ages = list(adata.obs['age'].unique())

            # Join the per-cell cirro coordinates with this region's cells.
            # NOTE(review): the coordinates are not used further down yet.
            region_coords = cirro_coords.merge(adata.obs, on='cell_id')
            x_corrd, y_corrd = region_coords['cirro_x'], region_coords['cirro_y']

            print(f'Number of cells in {ccf}: {adata.X.shape[0]}')
            print(f'Number of genes: {adata.X.shape[1]}')
            print("ages: " + ", ".join([f"{ag} ({sum(adata.obs['age']==ag)})" for ag in list_ages]))
            print("Original unique age values:", adata.obs['age'].unique())

            adata.obs['age_numeric'] = adata.obs['age'].map(age_mapping)

            print('preparing the data loaders')
            # NOTE(review): the loaders are not used below; the call is kept in case
            # scLoader has side effects (e.g. on RNG state) that later steps rely on.
            _train_loader, _test_loader, _data_loader = scLoader(
                adata=adata,
                features=range(n_gene),
                batch_size=512,
            )

            print('loading the trained model')
            vae = VAE(saving_folder=selected_model_file, device=device)
            vae.init_nn(
                input_dim=n_gene,
                fc_dim=params["fc_dim"],
                lowD_dim=latent_dim,
                n_layer=params["n_layer"],
                x_drop=params["p_drop"],
                variational=params["variational"],
            )
            vae.load_model(trained_models[0])

            # Select source/target cells by the same age <-> index mapping used for
            # obs['age_numeric'] and for the plot labels (unique() order is not
            # guaranteed to follow it).
            # NOTE(review): assumes vae.transfer selects the same rows, in the same
            # order, via obs['age_numeric'] == y_s / y_t — confirm.
            source_mask = (adata.obs['age'] == source_age).values
            target_mask = (adata.obs['age'] == target_age).values
            type_s = adata.obs.loc[source_mask, args.type_level].astype(str).to_numpy()
            type_t = adata.obs.loc[target_mask, args.type_level].astype(str).to_numpy()
            label_t = adata.obs.loc[target_mask, args.annot_level]

            # Automatic cluster counts follow the GMM clustering level (--type_level).
            Ks = _n_components(adata.obs, source_mask, args.type_level, args.Ks)
            Kt = _n_components(adata.obs, target_mask, args.type_level, args.Kt)

            x_s, z_s, z_t, z_tt, cost_df, solver_dict = vae.transfer(
                adata=adata,
                y_s=args.y_s,
                y_t=args.y_t,
                Ks=Ks,
                Kt=Kt,
                eps_gs=0.1,
                eps_w=0.1,
                alg='cvx',
                cov_type=args.cov_type,
                max_iter=10000,
                stop_thr=1e-10,
                verbose=args.verbose,
                geometry=args.geometry,
                timepoints=args.timepoints,
            )

            # Last slice along axis 0 of the transport path (presumably the final timepoint).
            z_t_predict = z_tt[-1, :, :]
            z_t_original = z_t[:, :latent_dim]
            z_s = z_s[:, :latent_dim]
            sc_target_predict = vae.generate(z_t_predict)
            sc_target_original = vae.generate(z_t_original)
            sc_source = vae.generate(z_s)
            n_s = sc_source.shape[0]
            n_t = sc_target_original.shape[0]

            # Row-wise Euclidean distance between decoded transported cells and x_s.
            cost_df['cost_x'] = np.linalg.norm(sc_target_predict - x_s, axis=1)

            anno_predict = vae.annotation(z_t_original, label_t, z_t_predict)

            # UMAP of the decoded (gene-space) cells.
            emb_s, emb_t, emb_tt = _split_embedding(
                _umap_2d(np.concatenate([sc_source, sc_target_original, sc_target_predict], axis=0)),
                n_s, n_t,
            )
            _plot_age_umap(emb_s, emb_t, emb_tt, source_age, target_age,
                           selected_model_file + f'/UMAP_x_{ccf}_{args.y_s}_{args.y_t}.png')
            _plot_type_umap(emb_s, emb_t, emb_tt, type_s, type_t,
                            selected_model_file + f'/UMAP_x_{ccf}_colored_{args.y_s}_{args.y_t}.png')

            # UMAP of the latent codes.
            emb_s, emb_t, emb_tt = _split_embedding(
                _umap_2d(np.concatenate([z_s, z_t_original, z_t_predict], axis=0)),
                n_s, n_t,
            )
            _plot_age_umap(emb_s, emb_t, emb_tt, source_age, target_age,
                           selected_model_file + f'/UMAP_z_{ccf}_{args.y_s}_{args.y_t}.png')

            # Include the region in the name so `--CCF_level all` does not overwrite it.
            cost_df.to_csv(selected_model_file + f'/costs_{ccf}_{args.y_s}_{args.y_t}.csv', index=False)

            df_ = adata.obs.merge(cost_df, on='cell_id')
            df_[args.annot_level + '_aged'] = anno_predict
            adata_tmp = ad.AnnData(X=csr_matrix(sc_target_predict), obs=df_, var=adata.var)
            adata_tmp.write(data_dir / f'transfer_cells_{ccf}_{args.y_s}_{args.y_t}.h5ad')
    finally:
        adata_backed.file.close()
        


if __name__ == "__main__":
    # Script entry point: parse the command line and hand the options to main().
    main(parser.parse_args())