import sys
sys.path.append("/home/anon/branched_diffusion/src")
import importlib
import model.sdes as sdes 
import model.train_continuous_model as train_continuous_model
import model.image_unet as image_unet
import numpy as np
import torch
import torchvision
import os
import json

# Select the compute device: CUDA when PyTorch reports it available,
# otherwise fall back to the CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def class_time_to_branch(c, t):
    """
    Look up the branch index for a single class/time pair.

    The module-level `branch_defs` is scanned in order. Each entry is
    indexed as (classes, start time, end time); the first entry whose
    classes contain `c` and whose closed interval [start, end] contains
    `t` is the match.
    Arguments:
        `c`: a scalar class label
        `t`: a scalar time
    Returns the position of the matching entry within `branch_defs`.
    Raises a ValueError if no entry matches.
    """
    for branch_index, branch_def in enumerate(branch_defs):
        # Skip branches that do not cover this class at all
        if c not in branch_def[0]:
            continue
        # Both interval endpoints are inclusive
        if branch_def[1] <= t <= branch_def[2]:
            return branch_index
    raise ValueError("Undefined class and time")
        
def class_time_to_branch_tensor(c, t):
    """
    Map paired classes and times to their branch indices.

    `c` and `t` are iterated in lockstep, and each pair is resolved with
    `class_time_to_branch`.
    Arguments:
        `c`: an iterable of class labels (a tensor of classes in practice)
        `t`: an iterable of times, paired element-wise with `c`
    Returns a tensor of branch indices placed on `DEVICE`.
    """
    branch_indices = []
    for class_label, time_point in zip(c, t):
        branch_indices.append(class_time_to_branch(class_label, time_point))
    return torch.tensor(branch_indices, device=DEVICE)

def class_to_class_index_tensor(c):
    """
    Convert a tensor of class labels into positions within the
    module-level `classes` list.

    Each label is compared against every entry of `classes`, and the
    argmax over that comparison gives the position of the match.
    Arguments:
        `c`: a tensor of class labels (assumed 1-D; the `[:, None]`
            indexing adds a second axis for the comparison)
    Returns a tensor of class indices placed on `DEVICE`.
    NOTE: a label that is not in `classes` yields an all-zero comparison
    row, so its index is whatever argmax returns for a tie rather than
    an error.
    """
    known_classes = torch.tensor(classes, device=c.device)
    # One row per label and one column per known class; 1 marks a match
    match_matrix = (c[:, None] == known_classes).int()
    class_indices = torch.argmax(match_matrix, dim=1)
    return class_indices.to(DEVICE)

def import_classes_branch_points(json_path):
    """
    Load the class list and branch definitions from a JSON file.

    The file must hold an object with a "classes" entry and a "branches"
    entry. Each branch is indexed at positions 0, 1 and 2, and the item
    at position 0 is converted to a tuple.
    Arguments:
        `json_path`: path to the JSON file
    Returns a pair: the "classes" value as loaded, and a list of
    (tuple, item 1, item 2) triples, one per branch.
    """
    with open(json_path, "r") as handle:
        config = json.load(handle)

    class_list = config["classes"]
    branch_list = []
    for trip in config["branches"]:
        branch_list.append((tuple(trip[0]), trip[1], trip[2]))
    return class_list, branch_list


# Base directory under which each run's model directory is created
model_base_path = "/gstore/scratch/u/tsenga5/branched_diffusion/models/trained_models/"

# Load the classes and branch definitions for the two MNIST setups:
# the 0/4/9 configuration file and the all-digits configuration file
branch_points_dir = "/home/anon/branched_diffusion/data/config/classes_branch_points/mnist/"
classes_049, branch_defs_049 = import_classes_branch_points(
    os.path.join(branch_points_dir, "049.json")
)
classes_all, branch_defs_all = import_classes_branch_points(
    os.path.join(branch_points_dir, "all_digits.json")
)

# Train one model per configuration. Each entry is
# (model type, classes, branch definitions, number of epochs, extra options)
for model_type, classes, branch_defs, num_epochs, kwargs in [
    ("branched", classes_all, branch_defs_all, 200, None),
    ("branched", classes_049, branch_defs_049, 100, None),
    ("labelguided", classes_all, branch_defs_all, 200, None),
    ("labelguided", classes_049, branch_defs_049, 100, {"extra_classes": 1}),
]:
    # NOTE: `classes` and `branch_defs` are rebound at module level on every
    # iteration; the class/branch lookup functions defined earlier in this
    # file read them as globals, so these two names must not be renamed.

    # Each configuration gets its own model directory
    dir_name = "mnist_continuous_%s_%dclasses" % (model_type, len(classes))
    model_dir = os.path.join(model_base_path, dir_name)

    # MNIST training split; every image gets a leading axis and is rescaled
    # as x / 256 * 2 - 1
    dataset = torchvision.datasets.MNIST(
        "/data/anon/datasets",
        train=True,
        transform=(lambda img: (np.asarray(img)[None] / 256 * 2) - 1),
    )

    # Keep only the samples whose label is among the current classes
    inds = np.isin(dataset.targets, classes)
    dataset.data = dataset.data[inds]
    dataset.targets = dataset.targets[inds]

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=128, shuffle=True, num_workers=2
    )
    # Per-sample shape: first batch's inputs with the batch axis dropped
    input_shape = next(iter(data_loader))[0].shape[1:]

    # Define the SDE
    sde = sdes.VariancePreservingSDE(0.1, 20, input_shape)
    t_limit = 1

    # MODEL_DIR has to be set before the training module is reloaded
    # (presumably the module reads it at import time -- confirm in
    # model/train_continuous_model.py)
    os.environ["MODEL_DIR"] = model_dir
    importlib.reload(train_continuous_model)  # Reimport AFTER setting environment

    # Pick the model, the training command, and the label-mapping callback
    # that goes with this model type
    if model_type == "branched":
        model = image_unet.MultitaskMNISTUNetTimeConcat(
            len(branch_defs), t_limit=t_limit
        ).to(DEVICE)
        command_name = "train_branched_model"
        mapping_key = "class_time_to_branch_index"
        mapping_fn = class_time_to_branch_tensor
    elif model_type == "labelguided":
        # Optional additional label slots requested by the configuration
        extra_classes = (kwargs or {}).get("extra_classes", 0)
        model = image_unet.LabelGuidedMNISTUNetTimeConcat(
            len(classes) + extra_classes, t_limit=t_limit
        ).to(DEVICE)
        command_name = "train_label_guided_model"
        mapping_key = "class_to_class_index"
        mapping_fn = class_to_class_index_tensor
    else:
        # Unrecognised model types are skipped without training
        continue

    train_continuous_model.train_ex.run(
        command_name,
        config_updates={
            "model": model,
            "sde": sde,
            "data_loader": data_loader,
            mapping_key: mapping_fn,
            "num_epochs": num_epochs,
            "learning_rate": 0.001,
            "t_limit": t_limit,
            "loss_weighting_type": "empirical_norm",
        },
    )
