import sys
sys.path.append("/home/anon/branched_diffusion/src")
import importlib
import model.sdes as sdes 
import model.train_continuous_model as train_continuous_model
import model.image_unet as image_unet
import numpy as np
import torch
import torchvision
import os
import json

# Trial indices to run, taken from the command line; None means "run all trials"
task_inds = [int(arg) for arg in sys.argv[1:]] or None

# Define device: prefer the GPU whenever PyTorch can see one
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def class_time_to_branch(c, t, branches=None):
    """
    Given a class and a time (both scalars), return the index of the first
    branch that covers them.

    A branch definition is a triple `(class_tuple, t_start, t_end)`. Branch `i`
    matches when `c` is a member of `class_tuple` and `t_start <= t <= t_end`
    (both ends inclusive). If several branches match, the lowest index wins.

    Args:
        c: class label (scalar).
        t: time (scalar).
        branches: optional list of branch definitions to search. When omitted,
            the module-level `branch_defs` is used. The training loop rebinds
            that global for every trial.

    Returns:
        The integer index of the matching branch.

    Raises:
        ValueError: if no branch covers the given (class, time) pair.
    """
    # Only look up the global when no explicit definitions were supplied, so
    # callers that pass `branches` do not need `branch_defs` to exist.
    defs = branch_defs if branches is None else branches
    for i, branch_def in enumerate(defs):
        branch_classes, t_start, t_end = branch_def[0], branch_def[1], branch_def[2]
        if c in branch_classes and t_start <= t <= t_end:
            return i
    raise ValueError("Undefined class and time: class=%s, time=%s" % (c, t))
        
def class_time_to_branch_tensor(c, t):
    """
    Vectorised wrapper around `class_time_to_branch`.

    Walks the paired class and time sequences in lockstep, stopping at the
    shorter one. Each pair is looked up with `class_time_to_branch`, and the
    resulting branch indices are packed into a tensor placed on `DEVICE`.
    """
    branch_indices = list(map(class_time_to_branch, c, t))
    return torch.tensor(branch_indices, device=DEVICE)

def import_classes_branch_points(json_path):
    """
    Load the class list and branch definitions from a JSON config file.

    The file must hold a "classes" entry and a "branches" entry. Each element
    of "branches" is a sequence whose first item is the list of classes in that
    branch. Its second and third items are the branch's start and end times.

    Args:
        json_path: path to the JSON config file.

    Returns:
        A pair `(classes, branches)`:
        - `classes` is the "classes" value exactly as loaded.
        - `branches` is a list of `(tuple_of_classes, t_start, t_end)` triples.
        The class lists are converted to tuples.
    """
    # JSON is UTF-8 by specification; don't let the platform's locale decide.
    with open(json_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    branches = [(tuple(trip[0]), trip[1], trip[2]) for trip in config["branches"]]
    return config["classes"], branches


# Root directory under which each trial's trained models are stored
model_base_path = "/gstore/scratch/u/tsenga5/branched_diffusion/models/trained_models/"

# MNIST digit classes, plus the branch definitions read from each of the
# ten trial_<k>.json configs (index k of the list holds trial k's branches)
branch_points_dir = "/home/anon/branched_diffusion/data/config/classes_branch_points/mnist/branch_point_discovery_variants"
classes = list(range(10))
all_branch_defs = [
    import_classes_branch_points(os.path.join(branch_points_dir, "trial_%d.json" % trial))[1]
    for trial in range(10)
]

# The dataset, data loader and SDE do not depend on the trial's branch
# definitions, so build them once instead of reloading MNIST on every trial.
# The same loader/SDE objects were already shared across a trial's repeats.

# Create dataset
# NOTE(review): `/ 256 * 2 - 1` maps pixel values 0..255 to [-1, 1 - 2/256],
# not exactly [-1, 1]; kept as-is because samplers/trained models may rely on it.
dataset = torchvision.datasets.MNIST(
    "/data/anon/datasets", train=True,
    transform=(lambda img: (np.asarray(img)[None] / 256 * 2) - 1)
)

# Limit classes
inds = np.isin(dataset.targets, classes)
dataset.data = dataset.data[inds]
dataset.targets = dataset.targets[inds]

# Create data loader; the per-example shape is read from one collated batch
data_loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2)
input_shape = next(iter(data_loader))[0].shape[1:]

# Define the SDE
sde = sdes.VariancePreservingSDE(0.1, 20, input_shape)
t_limit = 1

# Train the requested trials (all trials when none were given on the command line)
for i in (task_inds if task_inds else range(len(all_branch_defs))):
    # Rebinding the module-level `branch_defs` is what makes
    # `class_time_to_branch` (and so `class_time_to_branch_tensor`) use this
    # trial's branch definitions.
    branch_defs = all_branch_defs[i]

    # Define model path
    dir_name = "mnist_continuous_branch_variation/trial_%d" % i
    model_dir = os.path.join(model_base_path, dir_name)

    # Train three independent models (fresh weights each time) for this trial
    for _ in range(3):
        os.environ["MODEL_DIR"] = model_dir
        # Reimport AFTER setting environment; presumably the module reads
        # MODEL_DIR at import time (TODO confirm in train_continuous_model)
        importlib.reload(train_continuous_model)

        model = image_unet.MultitaskMNISTUNetTimeConcat(
            len(branch_defs), t_limit=t_limit
        ).to(DEVICE)

        train_continuous_model.train_ex.run(
            "train_branched_model",
            config_updates={
                "model": model,
                "sde": sde,
                "data_loader": data_loader,
                "class_time_to_branch_index": class_time_to_branch_tensor,
                "num_epochs": 150,
                "learning_rate": 0.001,
                "t_limit": t_limit,
                "loss_weighting_type": "empirical_norm"
            }
        )
