import sys
sys.path.append("/home/anon/branched_diffusion/src")
import importlib
import feature.molecule_dataset as molecule_dataset
import model.graph_adj_x_diffusion as graph_adj_x_diffusion
import model.graph_net as graph_net
import numpy as np
import torch
import os
import json

# Integer task indices passed on the command line; None when no indices are given.
task_inds = [int(arg) for arg in sys.argv[1:]] or None

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def class_time_to_branch(c, t, defs=None):
    """
    Given a class and a time (both scalars), return the
    corresponding branch index.

    Args:
        c: class label; matched with `in` against each branch's class collection.
        t: time; a branch matches when start_time <= t <= end_time (inclusive).
        defs: optional sequence of (classes, start_time, end_time) triples.
            When omitted, the module-level `branch_defs` is used (the training
            loop rebinds that name for each branch configuration).

    Returns:
        Index of the first branch in `defs` that covers (c, t).

    Raises:
        ValueError: if no branch covers (c, t).
    """
    if defs is None:
        defs = branch_defs
    for i, branch_def in enumerate(defs):
        if c in branch_def[0] and branch_def[1] <= t <= branch_def[2]:
            return i
    raise ValueError("Undefined class and time")


def class_time_to_branch_tensor(c, t, defs=None, device=None):
    """
    Given tensors of classes and a times, return the
    corresponding branch indices as a tensor.

    Args:
        c: iterable of class labels (e.g. a 1-D tensor), paired element-wise with `t`.
        t: iterable of times, same length as `c`.
        defs: optional branch definitions; defaults to the module-level `branch_defs`.
        device: optional target device; defaults to the module-level `DEVICE`.

    Returns:
        A 1-D int64 tensor of branch indices on `device`.
    """
    # Resolve the globals once instead of once per element.
    if defs is None:
        defs = branch_defs
    if device is None:
        device = DEVICE
    # Explicit dtype: torch.tensor([]) would otherwise default to float32, which
    # is unusable as an index tensor when the batch happens to be empty.
    return torch.tensor(
        [class_time_to_branch(c_i, t_i, defs) for c_i, t_i in zip(c, t)],
        dtype=torch.long,
        device=device,
    )

def import_classes_branch_points(json_path):
    """
    Load class labels and branch definitions from a JSON config file.

    The file must contain a "classes" entry and a "branches" entry whose items
    are indexable as [classes, start_time, end_time].

    Returns:
        (classes, branches): `classes` is the "classes" value as stored in the
        file; `branches` is a list of (tuple(classes), start_time, end_time).
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(json_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    # Class collections become tuples so they are immutable and support `in`.
    branches = [(tuple(trip[0]), trip[1], trip[2]) for trip in config["branches"]]
    return config["classes"], branches


# Root directory; each trained model is written to its own subdirectory.
model_base_path = "/gstore/scratch/u/tsenga5/branched_diffusion/models/trained_models/"

# Load the class list and branch definitions from the JSON config
branch_points_dir = "/home/anon/branched_diffusion/data/config/classes_branch_points/zinc250k/"
classes, branch_defs_15 = import_classes_branch_points(os.path.join(branch_points_dir, "c01_b15.json"))

# Labelling schemes; command-line task indices select entries from this tuple.
all_dataset_types = ("bicyclicity", "aromaticity", "elements", "num_cycles", "macrocyclicity", "heteroaromaticity")
if task_inds is None:
    # No indices given on the command line: train on every dataset type
    # (iterating over None here would raise TypeError).
    dataset_types = list(all_dataset_types)
else:
    dataset_types = [all_dataset_types[i] for i in task_inds]

for dataset_type in dataset_types:
    # `branch_defs` must stay a module-level name: class_time_to_branch() reads
    # it as a global when mapping (class, time) pairs to branch indices.
    for branch_defs in (branch_defs_15,):
        # Define model path
        dir_name = "zinc250k_continuous_%s_b%s" % (dataset_type, branch_defs[0][1])
        model_dir = os.path.join(model_base_path, dir_name)

        # Set MODEL_DIR and reload the diffusion module so its import-time setup
        # presumably picks up this run's directory. Do this BEFORE constructing
        # anything from the module: instances created before a reload keep the
        # old class definitions and may not match checks in the reloaded code.
        os.environ["MODEL_DIR"] = model_dir
        importlib.reload(graph_adj_x_diffusion)

        # Define dataset
        if dataset_type == "num_cycles":
            dataset = molecule_dataset.ZINCDataset(
                label_method=dataset_type, nums_to_label=[0, 1]
            )
            num_epochs = 200
        else:
            dataset = molecule_dataset.ZINCDataset(label_method=dataset_type)
            num_epochs = 50

        # Keep only examples whose target is one of the configured classes
        inds = np.isin(dataset.target, classes)
        dataset.all_smiles = dataset.all_smiles[inds]
        dataset.target = dataset.target[inds]

        # Create dataloader; the input shape (minus batch dim) comes from one batch
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True, num_workers=0)
        input_shape = next(iter(data_loader))[0].shape[1:]

        # Define the SDE using the freshly reloaded module
        sde = graph_adj_x_diffusion.AXJointSDE(0.1, 1, 0.2, 1, input_shape)
        t_limit = 1

        model = graph_net.GraphJointNetwork(
            len(branch_defs), t_limit,
            a_shared_layers=[True, True, True, True, True, False, False],
            x_shared_layers=[True, True, False]
        ).to(DEVICE)

        graph_adj_x_diffusion.train_ex.run(
            "train_branched_model",
            config_updates={
                "model": model,
                "sde": sde,
                "data_loader": data_loader,
                "class_time_to_branch_index": class_time_to_branch_tensor,
                "num_epochs": num_epochs,
                "learning_rate": 1e-3,
                "t_limit": t_limit,
                "loss_weighting_type": "empirical_norm"
            }
        )
		
