import sys
sys.path.append("/home/anon/branched_diffusion/src")
import importlib
import model.sdes as sdes 
import model.train_continuous_model as train_continuous_model
import model.table_dnn as table_dnn 
import numpy as np
import torch
import os
import json

# Pick the compute device: the GPU when torch can see one, else the CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def class_time_to_branch(c, t, branches=None):
    """
    Given a class and a time (both scalars), return the corresponding
    branch index.

    The first branch whose class set contains `c` and whose closed time
    interval [start, end] contains `t` wins.

    Arguments:
        `c`: a class label
        `t`: a time value
        `branches`: optional list of (classes, start, end) triples; when
            omitted, the module-level `branch_defs` is used (looked up at
            call time, so later reassignments of `branch_defs` are honored)
    Returns the integer index of the matching branch.
    Raises ValueError if no branch covers the (class, time) pair.
    """
    if branches is None:
        branches = branch_defs
    for i, (branch_classes, t_start, t_end) in enumerate(branches):
        if c in branch_classes and t_start <= t <= t_end:
            return i
    raise ValueError("Undefined class and time")
        
def class_time_to_branch_tensor(c, t):
    """
    Batch form of `class_time_to_branch`: walk the parallel sequences of
    classes `c` and times `t`, look up the branch index for each pair,
    and return all indices as a single tensor on DEVICE.
    """
    branch_indices = []
    for class_val, time_val in zip(c, t):
        branch_indices.append(class_time_to_branch(class_val, time_val))
    return torch.tensor(branch_indices, device=DEVICE)

def import_classes_branch_points(json_path):
    """
    Load class names and branch definitions from a JSON config file.

    The file must hold an object with a "classes" entry and a "branches"
    entry; each element of "branches" is indexed as [classes, start, end].

    Arguments:
        `json_path`: path to the JSON config file
    Returns a pair: the "classes" value exactly as stored in the file, and
    a list of (tuple_of_classes, start, end) triples built from "branches".
    """
    with open(json_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    branches = [
        (tuple(trip[0]), trip[1], trip[2]) for trip in config["branches"]
    ]
    return config["classes"], branches


# Output directories for the branched model and the label-guided model
model_dir_1 = "/data/anon/branched_diffusion/models/trained_models/letters_continuous_allletters"
model_dir_2 = "/data/anon/branched_diffusion/models/trained_models/letters_continuous_allletters_labelguided"

# Read the letter names and branch definitions, then build the two-way
# mapping between integer class labels and letters
branch_points_dir = "/home/anon/branched_diffusion/data/config/classes_branch_points/letters/"
branch_config_path = os.path.join(branch_points_dir, "all_letters.json")
letters, branch_defs = import_classes_branch_points(branch_config_path)
class_to_letter = {index: letter for index, letter in enumerate(letters)}
letter_to_class = {letter: index for index, letter in class_to_letter.items()}

# Re-express each branch's member letters as integer class labels
branch_defs = [
    (tuple(letter_to_class[name] for name in branch[0]), branch[1], branch[2])
    for branch in branch_defs
]

# Create dataset
class LetterDataset(torch.utils.data.Dataset):
    """
    Letter-recognition table exposed as a torch Dataset of
    (features, class) pairs.

    Each non-blank line of the data file is split on commas: the first
    token is a letter label, the remaining tokens are integer features.
    Feature columns are standardized (zero mean, unit population standard
    deviation). Letter labels are mapped to integer classes through the
    module-level `letter_to_class` dict.
    """

    DEFAULT_DATA_PATH = "/data/anon/branched_diffusion/data/letter_recognition/letter-recognition.data"

    def __init__(self, data_path=DEFAULT_DATA_PATH):
        """
        Arguments:
            `data_path`: path to the comma-separated data file; defaults to
                `DEFAULT_DATA_PATH`
        """
        data = []
        targets = []
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # A blank line would yield an empty feature row and make
                    # np.stack fail on mismatched shapes; skip it instead
                    continue
                tokens = line.split(",")
                targets.append(tokens[0])
                data.append(np.array([int(tok) for tok in tokens[1:]]))
        self.data = np.stack(data)
        self.targets = np.array([letter_to_class[l] for l in targets])

        # Center/normalize the data. A constant column has std 0; dividing by
        # 1 there (centering only) avoids filling that column with NaNs.
        mean = np.mean(self.data, axis=0, keepdims=True)
        std = np.std(self.data, axis=0, keepdims=True)
        std[std == 0] = 1
        self.data = (self.data - mean) / std

    def __getitem__(self, index):
        """Return (float feature tensor, integer class label) for `index`."""
        return torch.tensor(self.data[index]).float(), self.targets[index]

    def __len__(self):
        """Return the number of examples."""
        return len(self.targets)
    
dataset = LetterDataset()

# Integer class labels to train on (the same labels `letter_to_class`
# produces); this configuration keeps every letter from the branch-point file
classes = [letter_to_class[letter] for letter in letters]

# Limit classes
inds = np.isin(dataset.targets, classes)
dataset.data = dataset.data[inds]
dataset.targets = dataset.targets[inds]

# Create data loader
data_loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2)
# Per-example input shape, read straight from the dataset so no worker
# processes are started and no batch is drawn just to inspect a shape
input_shape = dataset[0][0].shape

# Define the SDE
sde = sdes.VariancePreservingSDE(0.1, 20, input_shape)
t_limit = 1

# Train the branched model
os.environ["MODEL_DIR"] = model_dir_1
importlib.reload(train_continuous_model)  # Reimport AFTER setting environment

model = table_dnn.MultitaskTabularNet(
    len(branch_defs), input_shape[0], t_limit=t_limit
).to(DEVICE)

train_continuous_model.train_ex.run(
    "train_branched_model",
    config_updates={
        "model": model,
        "sde": sde,
        "data_loader": data_loader,
        "class_time_to_branch_index": class_time_to_branch_tensor,
        "num_epochs": 100,
        "learning_rate": 0.001,
        "t_limit": t_limit,
        "loss_weighting_type": "empirical_norm"
    }
)

# Train the label-guided model
# Position of each class label within `classes`
class_to_class_index = {c: i for i, c in enumerate(classes)}


def class_to_class_index_tensor(c):
    """
    Given a tensor (or sequence) of class labels, return each label's
    position within `classes` as a tensor on DEVICE.
    """
    return torch.tensor(
        [class_to_class_index[int(c_i)] for c_i in c], device=DEVICE
    )


os.environ["MODEL_DIR"] = model_dir_2
importlib.reload(train_continuous_model)  # Reimport AFTER setting environment

# NOTE(review): this is an MNIST image UNet; it is probably not a fit for flat
# tabular inputs of shape `input_shape`. A tabular label-guided network
# (e.g. from `table_dnn`) may be what was intended — TODO confirm.
import model.image_unet as image_unet

model = image_unet.LabelGuidedMNISTUNetTimeConcat(
    len(classes) + 1, t_limit=t_limit
).to(DEVICE)

train_continuous_model.train_ex.run(
    "train_label_guided_model",
    config_updates={
        "model": model,
        "sde": sde,
        "data_loader": data_loader,
        "class_to_class_index": class_to_class_index_tensor,
        "num_epochs": 100,
        "learning_rate": 0.001,
        "t_limit": t_limit,
        "loss_weighting_type": "empirical_norm"
    }
)
