import warnings
import higher
import torch
import tqdm
import copy

# Muting the scheduler warning.
# NOTE: ``simplefilter("ignore")`` is called without a category restriction, so it
# silences every warning raised after this module is imported, not only the
# scheduler one.
warnings.simplefilter("ignore")


def _draw_batch(loader, device):
    """Return one ``(inputs, targets)`` batch from *loader*, moved to *device*."""
    # A fresh iterator is built on every call, so with a shuffling loader this is
    # always the first batch of a newly shuffled pass over the dataset.
    # NOTE(review): keeping a single iterator alive across steps would be cheaper,
    # but it would change which batches are seen (e.g. a short final batch), so
    # the original sampling behaviour is kept.
    inputs, targets = next(iter(loader))
    return inputs.to(device), targets.to(device)


def unrolled_differentiation(meta_model, meta_optimizer, base_model, base_optimizer, training, validation,
                             gradient_steps, inner_gradient_steps, batch_size, task_loss_fn, performance_metric,
                             verbose, device, offline=False):
    """Train a learned loss function (``meta_model``) by unrolled differentiation.

    Every outer step builds a differentiable copy of the base model and its
    optimizer with ``higher``, takes ``inner_gradient_steps`` updates of the
    base weights using ``meta_model`` as the loss, evaluates the updated model
    with ``task_loss_fn``, backpropagates that task loss and steps
    ``meta_optimizer``.

    The two modes differ as follows:

    * ``offline=True``: ``base_model.reset()`` is called at the start of every
      outer step, the task loss is evaluated on the last training batch, and
      ``meta_model`` is snapshotted every 100 outer steps.
    * ``offline=False``: exactly one inner step is taken (the given
      ``inner_gradient_steps`` is ignored), the updated weights are copied back
      into ``base_model``, the task loss is evaluated on a validation batch,
      and ``meta_model`` is snapshotted every 1000 outer steps.

    Args:
        meta_model: Learned loss; called as ``meta_model(predictions, targets)``.
        meta_optimizer: Optimizer that is stepped after the task loss is
            backpropagated (expected to hold ``meta_model``'s parameters).
        base_model: Model being trained with the learned loss. Must provide a
            ``reset()`` method when ``offline`` is true.
        base_optimizer: Optimizer of ``base_model``; wrapped by ``higher`` into
            a differentiable optimizer.
        training: Dataset wrapped in a shuffling ``DataLoader``; batches are
            unpacked as ``(inputs, targets)``.
        validation: Dataset wrapped the same way; only sampled when online.
        gradient_steps: Number of outer (meta) steps.
        inner_gradient_steps: Number of unrolled base updates per outer step
            (forced to 1 when ``offline`` is false).
        batch_size: Batch size of both data loaders.
        task_loss_fn: Called as ``task_loss_fn(predictions, targets)``; its
            result is backpropagated.
        performance_metric: Called as ``performance_metric(predictions,
            targets)``; ``.item()`` of its result is logged each outer step.
        verbose: The progress bar is shown when ``verbose >= 1``.
        device: Device every sampled batch is moved to.
        offline: Selects the offline or online behaviour described above.

    Returns:
        Tuple ``(meta_model, base_model, training_history,
        meta_model_history)``, where ``training_history`` holds one metric
        value per outer step and ``meta_model_history`` holds deep copies of
        ``meta_model`` taken at the snapshot interval.

    Raises:
        ValueError: If ``offline`` is true, at least one outer step is
            requested and ``inner_gradient_steps`` is smaller than 1.
    """
    if not offline:
        # Online adaptation is constrained to a single inner step, whatever was passed in.
        inner_gradient_steps = 1
    elif gradient_steps > 0 and inner_gradient_steps < 1:
        # Without an inner step there is no training batch or prediction to
        # evaluate, which previously surfaced as an UnboundLocalError.
        raise ValueError("inner_gradient_steps must be at least 1 when offline=True")

    # Creating PyTorch dataloader objects for generating batches.
    training = torch.utils.data.DataLoader(training, batch_size=batch_size, shuffle=True)
    validation = torch.utils.data.DataLoader(validation, batch_size=batch_size, shuffle=True)

    # Lists for keeping track of the learning history.
    training_history, meta_model_history = [], []

    # The learned loss is snapshotted every 100 steps offline and every 1000 steps online.
    snapshot_interval = 100 if offline else 1000

    progress = tqdm.tqdm(range(gradient_steps), desc="Offline Progression" if offline else "Online Progression",
                         position=0, dynamic_ncols=True, disable=not (verbose >= 1), leave=False)

    for step in progress:

        if offline:  # Offline initialization restarts the base model on every outer step.
            base_model.reset()

        # Clearing the gradient cache.
        meta_optimizer.zero_grad()

        # Caching the current state of the learned loss function.
        if step % snapshot_interval == 0:
            meta_model_history.append(copy.deepcopy(meta_model))

        # Creating a differentiable optimizer and stateless model via PyTorch higher.
        with higher.innerloop_ctx(base_model, base_optimizer, copy_initial_weights=False) as (fmodel, diffopt):

            # Taking a predetermined number of inner steps before the meta update.
            for _ in range(inner_gradient_steps):
                X_train, y_train = _draw_batch(training, device)

                yp_train = fmodel(X_train)  # Base network predictions.
                base_loss = meta_model(yp_train, y_train)  # Loss wrt. the learned loss.
                diffopt.step(base_loss)  # Update base network weights (theta).

            if offline:
                # When initializing, the task loss is evaluated on the last training batch.
                X_valid, y_valid = X_train, y_train
            else:
                # Online: synchronize the base model with the patched module's
                # weights, then sample a validation batch from the current task.
                base_model.load_state_dict(fmodel.state_dict())
                X_valid, y_valid = _draw_batch(validation, device)

            # Predictions with the updated weights.
            yp_valid = fmodel(X_valid)

            # Computing the task loss and updating the meta weights.
            task_loss = task_loss_fn(yp_valid, y_valid)
            task_loss.backward()  # Accumulates gradients wrt. the meta parameters.
            meta_optimizer.step()  # Update meta-loss network weights (phi).

            # The metric is only logged, so autograd recording is switched off for it.
            with torch.no_grad():
                training_history.append(performance_metric(yp_train, y_train).item())

    return meta_model, base_model, training_history, meta_model_history
