import argparse


def parse_seed(argv=None):
    """Parse the command line and return the run seed as an ``int``.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        The integer passed via ``-v`` / ``--version``.

    Raises:
        SystemExit: Raised by argparse (exit status 2) when the option is
            missing or its value is not an integer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-v",
        "--version",
        type=int,
        # Without a value the script used to fail later with
        # ``int(None)``; make argparse report the problem instead.
        required=True,
        help="integer used as random seed and as suffix of the output directory",
    )
    return parser.parse_args(argv).version


if __name__ == "__main__":
    version = parse_seed()

    # Derived under the guard because ``version`` only exists when the file
    # is executed as a script. The remainder of the file runs at module
    # level and reads ``seed_i``, so the file is usable only as an entry
    # point.
    seed_i = int(version)
    print("SEED: ", seed_i)


import warnings

warnings.filterwarnings("ignore", ".*does not have many workers.*")
warnings.filterwarnings("ignore", category=FutureWarning)

import time
import os
import json
from pathlib import Path
from os import environ
import pytorch_lightning as pl
import torch
from bicycle.dictlogger import DictLogger
from datetime import datetime
from bicycle.model import BICYCLE
from bicycle.utils.data import (
    create_data,
    create_loaders,
    get_diagonal_mask,
    compute_inits,
)
from bicycle.utils.general import get_full_name
from bicycle.utils.plotting import plot_training_results
from pytorch_lightning.callbacks import RichProgressBar, StochasticWeightAveraging
from bicycle.callbacks import ModelCheckpoint, GenerateCallback, MyLoggerCallback, CustomModelCheckpoint
import numpy as np
import yaml

# Seed the run with the value from the command line and select the CPU.
SEED = seed_i
pl.seed_everything(SEED)
torch.set_float32_matmul_precision("high")
device = torch.device("cpu")

# define user directory here
# Outputs go to one directory per seed, with separate sub-directories for
# model checkpoints and for plots; both are created if missing.
user_dir = f"t_seed_{seed_i}"
MODEL_PATH = Path(user_dir) / "models"
PLOT_PATH = Path(user_dir) / "plots"
for out_dir in (MODEL_PATH, PLOT_PATH):
    out_dir.mkdir(parents=True, exist_ok=True)



# ---------------------------------------------------------------------------
# Initial run settings.
#
# NOTE(review): every name in this section except `correct_covariates` is
# assigned again further down in this script before it is used to build the
# loaders or the model, so the values here act only as placeholders.
# ---------------------------------------------------------------------------

# Model structure
n_factors = 100
rank_w_cov_factor = n_factors  # Same as dictys: #min(TFs, N_GENES-1)
perfect_interventions = True
lyapunov_penalty = True
use_latents = False
use_encoder = False
x_distribution = "Multinomial"
x_distribution_kwargs = {}

# Optimisation
optimizer = "adam"
lr = 1e-5
gradient_clip_val = 50.0
batch_size = 4096  # 128
validation_size = 0.2
n_epochs = 51_000
swa = 100
USE_INITS = False  # True

# Early stopping
early_stopping = False
early_stopping_patience = 500
early_stopping_min_delta = 0.01

# Hardware and periodic callbacks
GPU_DEVICE = 0
plot_epoch_callback = 1000
check_val_every_n_epoch = 1
log_every_n_steps = 1

# Result naming and persistence
name_prefix = f"boolODE-TEST-2_{seed_i}_{use_encoder}_{batch_size}_{lyapunov_penalty}"
SAVE_PLOT = True
OVERWRITE = True
CHECKPOINTING = False
VERBOSE_CHECKPOINTING = False

# Covariates: none are supplied and no correction is requested.
covariates = None
correct_covariates = False

# Load the N_SAMPLES x N_GENES data matrix.
samples = torch.tensor(np.load("PerturbSDE/simulated_data/refNet.npy"))

print('SAMPLES.shape:',samples.shape)

# Both sizes are taken before the rows are replicated below, so `n_samples`
# keeps the number of rows of the file as loaded.
n_samples, n_genes = samples.shape[0], samples.shape[1]

# Alternative kept for reference: treat all cells as unperturbed.
# regime = torch.zeros( (samples.shape[0],1), dtype = torch.int)

# All cells are perturbed: read one list of indices per JSON entry.
# NOTE(review): presumably the k-th JSON value lists the perturbed gene
# indices belonging to the k-th row of `samples` — confirm against the
# code that writes BoolODE_GeneToIndex.json.
with open("PerturbSDE/simulated_data/BoolODE_GeneToIndex.json") as handle:
    gene_to_index = json.load(handle)

gt_interv_idx = list(gene_to_index.values())

# Indicator matrix with the shape of `samples`: entry (row, col) becomes 1
# for every `col` listed in the row-th JSON value.
gt_interv = np.zeros(samples.size())
for row, targets in enumerate(gt_interv_idx):
    for col in targets:
        gt_interv[row, col] = 1
gt_interv = torch.from_numpy(gt_interv.astype(int)).float()

# regimes
# For every original row, draw EXTRAPOLATTION indices (with replacement, the
# np.random.choice default) from that row's list and lay them out flat.
EXTRAPOLATTION = 20
per_row_draws = [
    np.random.choice(gt_interv_idx[row], EXTRAPOLATTION)
    for row in range(samples.shape[0])
]
regimes = torch.from_numpy(np.array(per_row_draws).flatten())

# Repeat each row EXTRAPOLATTION times in a row so the data lines up with
# `regimes`. Going through `tolist()` first keeps the dtype that NumPy
# infers from plain Python numbers.
samples = torch.from_numpy(
    np.repeat(np.array(samples.tolist()), EXTRAPOLATTION, axis=0)
)

# Number of distinct rows in the intervention indicator matrix.
n_conditions = gt_interv.unique(dim=0).shape[0]

# String labels "0" .. "n_genes - 1", one per gene.
train_gene_ko = list(map(str, range(n_genes)))

# Split the condition indices into two lists:
#   train_regimes - used for training (and validation/hyperparameter tuning)
#   test_regimes  - held out for testing; a condition lands here when the
#                   corresponding column of `gt_interv` sums to more than 1.5
train_regimes, test_regimes = [], []
for c in range(n_conditions):
    target = test_regimes if gt_interv[:, c].sum() > 1.5 else train_regimes
    target.append(c)


# Build the train / validation / test data loaders.

validation_size = 0.2
batch_size = 4096

loader_args = (
    samples,
    regimes,
    validation_size,
    batch_size,
    SEED,
    train_regimes,
    test_regimes,
)
train_loader, validation_loader, test_loader = create_loaders(*loader_args)

covariates = None

# Short summary of what ended up in each split.
print("Training data:")
print(f"- Number of training samples: {len(train_loader.dataset)}")
print("Training regimes:", train_regimes)
if validation_size > 0:
    print(f"- Number of validation samples: {len(validation_loader.dataset)}")
if test_regimes:
    print(f"- Number of test samples: {len(test_loader.dataset)}")
    print("Test regimes:", test_regimes)


#
# Settings used for fitting. They are assigned after the earlier settings
# section of this script and therefore replace its values.
#

n_factors = 0

# NOTE(review): the comment that used to sit here referred to the Norman
# data (CRISPRa interventions). The inputs loaded by this script are BoolODE
# files, so confirm that "dCas9" is the intended intervention type.
intervention_type_inference = "dCas9"

# NOTE(review): re-seeding with a constant means the seed given on the
# command line no longer applies from this point on — confirm that this is
# intended.
SEED = 1
pl.seed_everything(SEED)
torch.set_float32_matmul_precision("high")
device = torch.device("cpu")

# TRAINING
optimizer = "adam"  # "rmsprop" #"adam"
# Optional optimizer arguments, e.g. {"betas": [0.5, 0.9]} for a faster decay
# of the gradient and squared-gradient estimates. Maybe this helps to stop
# the loss from growing late during training (see current version of
# Plot_Diagnostics.ipynb).
optimizer_kwargs = {}
lr = 1e-3
gradient_clip_val = 1e-3
n_epochs = 51_000
n_epochs_pretrain_latents = 1000  # 10000
validation_size = 0.2
swa = 250
USE_INITS = False
use_encoder = False
lyapunov_penalty = True

early_stopping = False
early_stopping_patience = 500
early_stopping_min_delta = 0.01

GPU_DEVICE = 0
plot_epoch_callback = 500

# MODEL
x_distribution = "Multinomial"
x_distribution_kwargs = {}
model_T = 1.0
learn_T = False
use_latents = True
perfect_interventions = True
rank_w_cov_factor = n_genes  # Fitting full covariance matrices for multivariate normals

# RESULTS
# NOTE(review): the prefix hard-codes "b1_0.5_b2_0.9" although
# `optimizer_kwargs` is empty — confirm the label still describes the run.
name_prefix = (
    f"2TEST_CHECKPOINTS_SERGIO_Demo_optim{optimizer}_b1_0.5_b2_0.9"
    f"_pretrain_epochs{n_epochs_pretrain_latents}_GRAD-CLIP"
    f"_INF:{intervention_type_inference}-slow_lr"
    f"_{use_encoder}_{batch_size}_{lyapunov_penalty}"
)
SAVE_PLOT = True
CHECKPOINTING = True
VERBOSE_CHECKPOINTING = True
OVERWRITE = False

# REST
n_samples_total = samples.shape[0]
check_val_every_n_epoch = 1
log_every_n_steps = 1

# Create Mask
mask = get_diagonal_mask(n_genes, device)

# No mask is passed to the model when latent factors are requested.
if n_factors > 0:
    mask = None


if USE_INITS:
    # `n_contexts` was referenced here without ever being defined, so turning
    # USE_INITS on raised a NameError.
    # NOTE(review): assumes the contexts are the columns of `gt_interv` —
    # confirm this matches what compute_inits and BICYCLE expect.
    n_contexts = gt_interv.shape[1]
    init_tensors = compute_inits(train_loader.dataset, rank_w_cov_factor, n_contexts)

# Make sure everything handed to the model lives on the chosen device.
device = torch.device("cpu")
gt_interv = gt_interv.to(device)
n_genes = samples.shape[1]

if covariates is not None and correct_covariates:
    covariates = covariates.to(device)

from itertools import product


def _remove_directory_files(directory):
    """Delete every entry found directly inside ``directory``."""
    for entry in os.listdir(directory):
        os.remove(os.path.join(directory, entry))


def _build_trainer(max_epochs, trainer_callbacks, trainer_loggers):
    """Create a single-device CPU ``pl.Trainer``.

    The latent pre-training phase and the main fit use the same trainer
    options apart from the epoch budget and the callback list, so both are
    built here.

    Args:
        max_epochs: Number of epochs the trainer may run.
        trainer_callbacks: Callback list handed to the trainer.
        trainer_loggers: Logger list handed to the trainer.

    Returns:
        The configured ``pl.Trainer``.
    """
    return pl.Trainer(
        max_epochs=max_epochs,
        accelerator="cpu",
        logger=trainer_loggers,
        log_every_n_steps=log_every_n_steps,
        enable_model_summary=True,
        enable_progress_bar=True,
        enable_checkpointing=CHECKPOINTING,
        check_val_every_n_epoch=check_val_every_n_epoch,
        devices=1,
        num_sanity_val_steps=0,
        callbacks=trainer_callbacks,
        gradient_clip_val=gradient_clip_val,
        default_root_dir=str(MODEL_PATH),
        gradient_clip_algorithm="value",
        deterministic=False,
    )


# Sweep over the regularisation weights (kl, l1, spectral, lyapunov). Each
# list currently holds one value, so exactly one model is fitted.
for scale_kl, scale_l1, scale_spectral, scale_lyapunov in product(
    [1.0], [1.0], [0.0], [0.1]
):
    file_dir = get_full_name(
        name_prefix,
        len(test_regimes),
        SEED,
        lr,
        n_genes,
        scale_l1,
        scale_kl,
        scale_spectral,
        scale_lyapunov,
        gradient_clip_val,
        swa,
    )
    model_dir = os.path.join(MODEL_PATH, file_dir)
    plot_dir = os.path.join(PLOT_PATH, file_dir)

    # If final plot or final model exists: do not overwrite by default
    print("Checking Model and Plot files...")
    final_file_name = os.path.join(model_dir, "last.ckpt")
    final_plot_name = os.path.join(plot_dir, "last.png")

    # The input data is dumped next to the plots for inspection and debugging.
    final_data_path = plot_dir
    if os.path.isdir(final_data_path):
        print(final_data_path, "exists")
    else:
        print("Creating", final_data_path)
        # makedirs also copes with a run name that contains sub-directories.
        os.makedirs(final_data_path, exist_ok=True)

    checkpoint_exists = Path(final_file_name).exists()
    plot_exists = Path(final_plot_name).exists()
    # Same truth table as the former `&` / `|` / `~` expression, but written
    # with boolean operators: `~` on a bool is deprecated.
    keep_existing = (not OVERWRITE) and (
        (checkpoint_exists and SAVE_PLOT) or (plot_exists and CHECKPOINTING)
    )

    if keep_existing:
        print("- Files already exists, skipping...")
        # NOTE(review): there is no `continue` here, so the fit below still
        # runs even though the message says "skipping" — confirm whether
        # skipping should be re-enabled.
    else:
        print("- Not all files exist, fitting model...")
        print("  - Deleting dirs")
        # Remove stale outputs of a previous run with the same name.
        if checkpoint_exists:
            print(f"  - Deleting {final_file_name}")
            _remove_directory_files(model_dir)
        if plot_exists:
            print(f"  - Deleting {final_plot_name}")
            _remove_directory_files(plot_dir)

        print("  - Creating dirs")
        Path(model_dir).mkdir(parents=True, exist_ok=True)
        Path(plot_dir).mkdir(parents=True, exist_ok=True)

    # Written only after the cleanup above; saving first meant the cleanup
    # could delete these files again right away.
    np.save(os.path.join(final_data_path, 'check_samples.npy'), samples.detach().cpu().numpy())
    np.save(os.path.join(final_data_path, 'check_regimes.npy'), regimes.detach().cpu().numpy())
    np.save(os.path.join(final_data_path, 'check_gt_interv.npy'), gt_interv.detach().cpu().numpy())

    model = BICYCLE(
        lr,
        gt_interv,
        n_genes,
        n_samples=n_samples_total,
        lyapunov_penalty=lyapunov_penalty,
        perfect_interventions=perfect_interventions,
        rank_w_cov_factor=rank_w_cov_factor,
        init_tensors=init_tensors if USE_INITS else None,
        optimizer=optimizer,
        optimizer_kwargs=optimizer_kwargs,
        device=device,
        scale_l1=scale_l1,
        scale_lyapunov=scale_lyapunov,
        scale_spectral=scale_spectral,
        scale_kl=scale_kl,
        early_stopping=early_stopping,
        early_stopping_min_delta=early_stopping_min_delta,
        early_stopping_patience=early_stopping_patience,
        early_stopping_p_mode=True,
        x_distribution=x_distribution,
        x_distribution_kwargs=x_distribution_kwargs,
        mask=mask,
        use_encoder=use_encoder,
        train_gene_ko=train_regimes,
        test_gene_ko=test_regimes,
        use_latents=use_latents,
        covariates=covariates,
        n_factors=n_factors,
        intervention_type=intervention_type_inference,
        T=model_T,
        learn_T=learn_T,
    )
    model.to(device)

    dlogger = DictLogger()
    loggers = [dlogger]

    # Callbacks of the main fit. Gene labels are not available in this
    # script, so none are passed to the plotting callback.
    callbacks = [
        RichProgressBar(refresh_rate=1),
        GenerateCallback(
            final_plot_name,
            plot_epoch_callback=plot_epoch_callback,
        ),
    ]
    if swa > 0:
        callbacks.append(StochasticWeightAveraging(0.01, swa_epoch_start=swa))
    if CHECKPOINTING:
        Path(model_dir).mkdir(parents=True, exist_ok=True)

        print('Checkpointing to:',MODEL_PATH)

        callbacks.append(
            CustomModelCheckpoint(
                dirpath=model_dir,
                filename="{epoch}",
                save_last=True,
                save_top_k=1,
                verbose=VERBOSE_CHECKPOINTING,
                monitor="valid_loss",
                mode="min",
                save_weights_only=True,
                start_after=1000,
                save_on_train_epoch_end=True,
                every_n_epochs=500,
            )
        )
        callbacks.append(MyLoggerCallback(dirpath=model_dir))

    # The main trainer is created before the optional pre-trainer, keeping
    # the construction order of the two trainers unchanged.
    trainer = _build_trainer(n_epochs, callbacks, loggers)

    if use_latents and n_epochs_pretrain_latents > 0:
        # Optional first phase: fit with `train_only_likelihood` switched on,
        # writing its plots to a separate "_pretrain" file.
        pretrain_callbacks = [
            RichProgressBar(refresh_rate=1),
            GenerateCallback(
                str(Path(final_plot_name).with_suffix("")) + '_pretrain',
                plot_epoch_callback=plot_epoch_callback,
            ),
        ]
        if swa > 0:
            pretrain_callbacks.append(StochasticWeightAveraging(0.01, swa_epoch_start=swa))
        pretrain_callbacks.append(MyLoggerCallback(dirpath=model_dir))

        pretrainer = _build_trainer(n_epochs_pretrain_latents, pretrain_callbacks, loggers)

        print('PRETRAINING LATENTS!')
        start_time = time.time()
        model.train_only_likelihood = True
        pretrainer.fit(model, train_loader, validation_loader)
        end_time = time.time()
        model.train_only_likelihood = False
        print(f"Pretraining took {end_time - start_time:.2f} seconds")

    start_time = time.time()
    trainer.fit(model, train_loader, validation_loader)
    end_time = time.time()
    print(f"Training took {end_time - start_time:.2f} seconds")

    # The fourth argument used to be `beta`, a name this script never
    # defines, so the run ended in a NameError after the fit had finished.
    # NOTE(review): None is passed on the assumption that
    # plot_training_results accepts a missing reference matrix — confirm, or
    # load the reference network and pass it here.
    plot_training_results(
        trainer,
        model,
        model.beta.detach().cpu().numpy(),
        None,
        scale_l1,
        scale_kl,
        scale_spectral,
        scale_lyapunov,
        final_plot_name,
        callback=False,
    )
            