import torch
import numpy as np
import tqdm
import os
import sacred
import model.util as util

# Root directory under which run outputs and checkpoints are written; taken
# from the MODEL_DIR environment variable when set, else a hard-coded default
MODEL_DIR = os.getenv("MODEL_DIR", "/home/username/projects/masa4")

# Sacred experiment that owns the config and commands defined in this module
train_ex = sacred.Experiment("train")

# Record every run of the experiment under MODEL_DIR
train_ex.observers.append(
    sacred.observers.FileStorageObserver.create(MODEL_DIR)
)

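# NOTE: Disabled legacy device selection, kept for reference only. The device
# is now chosen per call through the `DEVICE` argument of the training
# functions below.
#
# # Define device
# if torch.cuda.is_available():
#     DEVICE = "cuda:3"
# else:
#     DEVICE = "cpu"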

@train_ex.config
def config():
    """
    Default configuration values for the `train` experiment. The variables
    assigned here share their names with the `num_epochs` and `learning_rate`
    parameters of the training commands in this module; sacred is expected to
    supply them to those commands when the caller does not pass them
    explicitly (see sacred's config-scope documentation).
    """
    # Number of training epochs (full passes over the data loader)
    num_epochs = 30

    # Learning rate passed to the optimizer
    learning_rate = 0.001


# Accepted values for the `loss_weighting_type` argument of `train_model`
_LOSS_WEIGHTING_TYPES = ("ml", "expected_norm", "empirical_norm", None)


def _compute_loss_weight(sde, loss_weighting_type, x0, xt, t, true_score):
    """
    Computes the per-batch loss weight (used as a divisor of the loss) for the
    given weighting scheme.
    Arguments:
        `sde`: the SDE object used for training
        `loss_weighting_type`: one of "ml", "expected_norm", "empirical_norm",
            or None
        `x0`: batch of original inputs
        `xt`: batch of inputs diffused to time `t`
        `t`: B-tensor of diffusion times
        `true_score`: true score at `xt`, as returned by `sde.forward`
    Returns a tensor of weights. Raises a ValueError if `loss_weighting_type`
    is not recognized.
    """
    if loss_weighting_type == "ml":
        # NOTE(review): the weight is a divisor, so this matches "weight by
        # g^2" only if `diff_coef_func` returns g^2 -- confirm against the SDE
        return 1 / sde.diff_coef_func(xt, t)
    if loss_weighting_type == "expected_norm":
        return sde._inflate_dims(sde.mean_score_mag(t))
    if loss_weighting_type == "empirical_norm":
        # Mean squared true score over every non-batch dimension
        return sde._inflate_dims(torch.mean(
            torch.square(true_score), dim=tuple(range(1, len(x0.shape)))
        ))
    if loss_weighting_type is None:
        return torch.ones_like(x0)
    raise ValueError(
        "Unknown loss_weighting_type: %r" % (loss_weighting_type,)
    )


@train_ex.command
def train_model(
    model, sde, data_loader, model_type, num_epochs, learning_rate, _run,
    class_to_class_index=None, loss_weighting_type="empirical_norm",
    weight_func=None, t_limit=1, DEVICE="cuda"
):
    """
    Trains a diffusion model using the given instantiated model and SDE object.
    After every epoch, the model is saved under `MODEL_DIR/<run ID>/` and a
    `last_ckpt.pth` symlink in that directory is pointed at the newest
    checkpoint.
    Arguments:
        `model`: an instantiated score model which takes in x, t and predicts
            score; it must also provide a `loss` method. The model is not moved
            to `DEVICE` by this function
        `sde`: an SDE object
        `data_loader`: a DataLoader object that yields batches of data as a
            dictionary with keys: 1) "x" for the input; 2) "y" optionally for
            the class in a conditional model; 3) other keys to be passed as
            keyword arguments to the model. Its `dataset` must define
            `on_epoch_start()`, which is called at the start of every epoch
        `model_type`: either "conditional" or "unconditional"
        `num_epochs`: number of epochs to train for
        `learning_rate`: learning rate to use for training
        `_run`: the sacred run object; its `_id` names the output directory and
            its `log_scalar` receives the training losses
        `class_to_class_index`: for "conditional" model types, a function that
            takes in B-tensors of class and maps to a B-tensor of class indices
        `loss_weighting_type`: method for weighting the loss; can be "ml" to
            weight by g^2, "expected_norm" to weight by the expected mean
            magnitude of the loss, "empirical_norm" to weight by the observed
            true norm, or None to do no weighting at all
        `weight_func`: if given, a function mapping a batch of inputs x0 to a
            broadcastable tensor of weights which will multiply into the loss
            for each predicted feature (in addition to any loss weighting
            specified by `loss_weighting_type`)
        `t_limit`: training will occur between time 0 and `t_limit`
        `DEVICE`: device on which the batch tensors are placed
    Raises a ValueError if `loss_weighting_type` is not recognized, or if
    `model_type` is "conditional" but no `class_to_class_index` is given.
    """
    assert model_type in ("conditional", "unconditional")
    # Fail before any training work is done rather than in the first batch
    if loss_weighting_type not in _LOSS_WEIGHTING_TYPES:
        raise ValueError(
            "Unknown loss_weighting_type: %r" % (loss_weighting_type,)
        )
    is_conditional = model_type == "conditional"
    if is_conditional and class_to_class_index is None:
        raise ValueError(
            "class_to_class_index is required for conditional models"
        )

    run_num = _run._id
    output_dir = os.path.join(MODEL_DIR, str(run_num))
    # Make sure checkpoints have somewhere to go, whatever observers are set
    os.makedirs(output_dir, exist_ok=True)

    model.train()
    torch.set_grad_enabled(True)
    optim = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch_num in range(num_epochs):
        batch_losses = []
        data_loader.dataset.on_epoch_start()
        t_iter = tqdm.tqdm(data_loader)
        for batch in t_iter:
            x0 = batch["x"].to(DEVICE).float()
            if is_conditional:
                y = batch["y"].to(DEVICE)

            # Everything other than the input and class goes to the model (and
            # its loss) as keyword arguments
            batch_kwargs = {
                key: item.to(DEVICE) for key, item in batch.items()
                if key not in ("x", "y")
            }

            # Sample random times from 0 to t_limit
            t = (torch.rand(x0.shape[0]) * t_limit).to(DEVICE)

            # Run SDE forward to get xt and the true score at xt
            xt, true_score = sde.forward(x0, t)

            # Get model-predicted score
            if is_conditional:
                class_inds = class_to_class_index(y).long()
                pred_score = model(xt, t, class_inds, **batch_kwargs)
            else:
                pred_score = model(xt, t, **batch_kwargs)

            # Get weighting factor
            loss_weight = _compute_loss_weight(
                sde, loss_weighting_type, x0, xt, t, true_score
            )

            if weight_func is not None:
                # Division here, as `loss_weight` itself is the divisor
                loss_weight = loss_weight / weight_func(x0)

            # Compute loss
            loss = model.loss(
                pred_score, true_score, loss_weight, **batch_kwargs
            )
            loss_val = loss.item()
            t_iter.set_description("Loss: %.2f" % loss_val)

            # Skip the update (and the bookkeeping) for NaN/infinite losses so
            # that a single bad batch cannot corrupt the parameters
            if not np.isfinite(loss_val):
                continue

            optim.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optim.step()

            batch_losses.append(loss_val)

        # Every batch may have been skipped; avoid taking the mean of nothing
        if batch_losses:
            epoch_loss = np.mean(batch_losses)
        else:
            epoch_loss = float("nan")
        print("Epoch %d average Loss: %.2f" % (epoch_num + 1, epoch_loss))

        _run.log_scalar("train_epoch_loss", epoch_loss)
        _run.log_scalar("train_batch_losses", batch_losses)

        model_path = os.path.join(
            output_dir, "epoch_%d_ckpt.pth" % (epoch_num + 1)
        )
        link_path = os.path.join(output_dir, "last_ckpt.pth")

        # Save model
        util.save_model(model, model_path)

        # Create symlink to last epoch; the target is relative so that the
        # link stays valid if the run directory is moved
        if os.path.islink(link_path):
            os.remove(link_path)
        os.symlink(os.path.basename(model_path), link_path)


@train_ex.command
def train_conditional_model(
    model, sde, data_loader, num_epochs, learning_rate, _run,
    class_to_class_index, loss_weighting_type="empirical_norm", t_limit=1,
    DEVICE="cuda"
):
    """
    Trains a class-conditional diffusion model. This simply forwards all of
    its arguments to `train_model` with the model type fixed to "conditional";
    see `train_model` for the meaning of each argument.
    """
    train_model(
        model=model,
        sde=sde,
        data_loader=data_loader,
        model_type="conditional",
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        _run=_run,
        class_to_class_index=class_to_class_index,
        loss_weighting_type=loss_weighting_type,
        t_limit=t_limit,
        DEVICE=DEVICE,
    )


@train_ex.command
def train_unconditional_model(
    model, sde, data_loader, num_epochs, learning_rate, _run,
    loss_weighting_type="empirical_norm", t_limit=1, DEVICE="cuda"
):
    """
    Trains an unconditional diffusion model. This simply forwards all of its
    arguments to `train_model` with the model type fixed to "unconditional";
    see `train_model` for the meaning of each argument.
    """
    train_model(
        model=model,
        sde=sde,
        data_loader=data_loader,
        model_type="unconditional",
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        _run=_run,
        loss_weighting_type=loss_weighting_type,
        t_limit=t_limit,
        DEVICE=DEVICE,
    )
