import sys
sys.path.append("/home/anon/branched_diffusion/src")
import importlib
import model.sdes as sdes 
import model.train_continuous_model as train_continuous_model
import model.scrna_ae as scrna_ae
import feature.scrna_dataset as scrna_dataset
import numpy as np
import torch
import os
import json

# Task indices to run, parsed from the command-line arguments; None when no
# indices are supplied
task_inds = [int(arg) for arg in sys.argv[1:]] or None

# Define device: use the GPU whenever torch reports one is available
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def class_time_to_branch(c, t, defs=None):
    """
    Given a class and a time (both scalars), return the index of the first
    branch that contains the class and whose time interval covers the time.

    Args:
        c: class label (a Python number, or a 0-d tensor when called from
            `class_time_to_branch_tensor`).
        t: time value (same kinds as `c`).
        defs: sequence of branch definitions `(class_collection, t_start,
            t_end)`; both interval endpoints are inclusive. Defaults to the
            module-level `branch_defs`, which the training loop rebinds for
            each task.

    Returns:
        The index into `defs` of the first matching branch.

    Raises:
        ValueError: if no branch covers the given class and time.
    """
    if defs is None:
        defs = branch_defs
    for i, branch_def in enumerate(defs):
        # First match wins, so overlapping boundaries resolve to the earlier branch
        if c in branch_def[0] and branch_def[1] <= t <= branch_def[2]:
            return i
    raise ValueError("Undefined class and time")
        
def class_time_to_branch_tensor(c, t):
    """
    Given tensors of classes and times, return the corresponding branch
    indices as a tensor.

    Each (class, time) pair is mapped independently with
    `class_time_to_branch`, so a ValueError propagates if any pair is not
    covered by a branch.

    Args:
        c: iterable of class labels (presumably a 1-D tensor — confirm with
            the training code).
        t: iterable of times paired element-wise with `c`; `zip` stops at the
            shorter of the two.

    Returns:
        An integer (`torch.long`) tensor of branch indices on `DEVICE`.
    """
    # Explicit dtype keeps the result an integer tensor even for an empty
    # batch; `torch.tensor([])` would otherwise default to float32
    return torch.tensor(
        [class_time_to_branch(c_i, t_i) for c_i, t_i in zip(c, t)],
        dtype=torch.long,
        device=DEVICE,
    )

def class_to_class_index_tensor(c):
    """
    Map each class label in `c` to its position in the module-level
    `classes` list, returning the positions as a tensor on `DEVICE`.

    NOTE(review): a label that is not in `classes` matches nothing, and the
    argmax of its all-zero row is 0, so it silently maps to index 0. Confirm
    that callers never pass such labels.
    """
    known = torch.tensor(classes, device=c.device)
    # Assumes `c` is 1-D: one row per label, one column per known class
    hits = (c.unsqueeze(1) == known).int()
    return hits.argmax(dim=1).to(DEVICE)

def import_classes_branch_points(json_path):
    """
    Load a class set and its branch definitions from a JSON config file.

    The file must contain a "classes" entry and a "branches" entry. Each
    branch is a sequence whose first three elements are
    `[class_list, t_start, t_end]`; any further elements are ignored.

    Args:
        json_path: path to the JSON config.

    Returns:
        A pair `(classes, branch_defs)`. `classes` is exactly as stored in
        the file. `branch_defs` is a list of `(tuple(class_list), t_start,
        t_end)` tuples.
    """
    # JSON text is UTF-8; don't depend on the platform's locale encoding
    with open(json_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    branch_defs = [(tuple(trip[0]), trip[1], trip[2]) for trip in config["branches"]]
    return config["classes"], branch_defs

# Root directory; each task's trained-model directory is created beneath it
model_base_path = "/gstore/scratch/u/tsenga5/branched_diffusion/models/trained_models/"

# Class subsets and their branch definitions, loaded from JSON configs
branch_points_dir = "/home/anon/branched_diffusion/data/config/classes_branch_points/scrna_covid_flu/"
classes_01, branch_defs_01 = import_classes_branch_points(
    os.path.join(branch_points_dir, "01.json")
)
classes_redset, branch_defs_redset = import_classes_branch_points(
    os.path.join(branch_points_dir, "redset.json")
)

# Input data file, plus the autoencoder directory template; the `%d` is
# filled with the latent size of a latent task
data_file = "/data/anon/branched_diffusion/data/scrna/covid_flu/processed/covid_flu_processed_reduced_genes.h5"
autoencoder_path = "/data/anon/branched_diffusion/models/trained_models/scrna_vaes/covid_flu/covid_flu_processed_reduced_genes_ldvae_d%d/"

# Each (model type, class set, kwargs) group is run once without an
# autoencoder and once per latent size. Every task is a tuple
# (model_type, classes, branch_defs, latent, latent_size, kwargs).
_latent_settings = [(False, None), (True, 50), (True, 100), (True, 200)]
_task_groups = [
    ("branched", classes_redset, branch_defs_redset, None),
    ("labelguided", classes_redset, branch_defs_redset, None),
    ("branched", classes_01, branch_defs_01, None),
    ("labelguided", classes_01, branch_defs_01, {"extra_classes": 1}),
    ("labelguided-add", classes_redset, branch_defs_redset, None),
    ("labelguided-add", classes_01, branch_defs_01, {"extra_classes": 1}),
]
tasks = [
    (
        group_type,
        group_classes,
        group_branches,
        is_latent,
        size,
        # Fresh dict per task, matching the original literal table
        None if group_kwargs is None else dict(group_kwargs),
    )
    for group_type, group_classes, group_branches, group_kwargs in _task_groups
    for is_latent, size in _latent_settings
]
# With no command-line indices, `task_inds` is None; run every task instead
# of failing with "TypeError: 'NoneType' object is not iterable"
if task_inds is None:
    task_inds = range(len(tasks))

for task_i in task_inds:
    # NOTE: keep this loop at module level. `classes` and `branch_defs` are
    # read as globals by `class_to_class_index_tensor` and
    # `class_time_to_branch`, which are passed as training callbacks below.
    model_type, classes, branch_defs, latent, latent_size, kwargs = tasks[task_i]
    if model_type not in ("branched", "labelguided", "labelguided-add"):
        # Fail before the expensive dataset load rather than silently
        # skipping training for an unrecognized model type
        raise ValueError("Unknown model type: %s" % model_type)

    # Define model path
    dir_name = "scrna_covid_flu_continuous_%s_%dclasses" % (model_type, len(classes))
    if latent:
        dir_name += "_latent_d%d" % latent_size
    model_dir = os.path.join(model_base_path, dir_name)

    # Define dataset; the autoencoder is used exactly when the task is latent,
    # so the data representation always agrees with the directory name above
    ae_path = autoencoder_path % latent_size if latent else None
    dataset = scrna_dataset.SingleCellDataset(data_file, autoencoder_path=ae_path)

    # Limit classes
    inds = np.isin(dataset.cell_cluster, classes)
    dataset.data = dataset.data[inds]
    dataset.cell_cluster = dataset.cell_cluster[inds]

    # Create dataloader; the per-sample input shape is taken from one batch
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=128, shuffle=True, num_workers=0
    )
    input_shape = next(iter(data_loader))[0].shape[1:]

    # Define the SDE
    sde = sdes.VariancePreservingSDE(0.1, 5, input_shape)
    t_limit = 1

    os.environ["MODEL_DIR"] = model_dir
    importlib.reload(train_continuous_model)  # Reimport AFTER setting environment

    if model_type == "branched":
        model = scrna_ae.MultitaskResNet(
            len(branch_defs), input_shape[0], t_limit=t_limit
        ).to(DEVICE)

        train_continuous_model.train_ex.run(
            "train_branched_model",
            config_updates={
                "model": model,
                "sde": sde,
                "data_loader": data_loader,
                "class_time_to_branch_index": class_time_to_branch_tensor,
                "num_epochs": 120,
                "learning_rate": 0.001,
                "t_limit": t_limit,
                "loss_weighting_type": "empirical_norm",
            },
        )
    else:
        # "labelguided" and "labelguided-add" differ only in the network
        # class; both use the same training routine and settings
        if model_type == "labelguided-add":
            model_class = scrna_ae.LabelGuidedResNetAdd
        else:
            model_class = scrna_ae.LabelGuidedResNet
        extra_classes = (kwargs or {}).get("extra_classes", 0)
        model = model_class(
            len(classes) + extra_classes, input_shape[0], t_limit=t_limit
        ).to(DEVICE)

        train_continuous_model.train_ex.run(
            "train_label_guided_model",
            config_updates={
                "model": model,
                "sde": sde,
                "data_loader": data_loader,
                "class_to_class_index": class_to_class_index_tensor,
                "num_epochs": 100,
                "learning_rate": 0.001,
                "t_limit": t_limit,
                "loss_weighting_type": "empirical_norm",
            },
        )
