#!/usr/bin/env python3

from tqdm import tqdm
import sys
sys.path.insert(0, "/src")
from utils import *

import torch
from lucent.optvis import render, param

# Device used when a caller does not pass one explicitly: the first CUDA GPU
# if PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    _default_device = torch.device("cuda:0")
else:
    _default_device = torch.device("cpu")


def activation_optimization(optim_obj,
                            vis_obj,
                            images,
                            model,
                            params,
                            nsteps,
                            save_interval,
                            image_side_length,
                            output_folder,
                            save_image=True,
                            optimizer=torch.optim.Adam,
                            callback=None, maintain_obj = None,
                            alpha=0.1, device=_default_device, use_tqdm = True, activations = None):
    """Optimise ``params`` against a weighted pair of activation objectives.

    Each step runs ``model(images)``, evaluates both objectives on
    ``activations``, and back-propagates
    ``alpha * optim_obj(activations) + (1 - alpha) * maintain_obj(activations)``
    before stepping the optimizer. Every ``save_interval`` steps (including
    step 0) the model parameters are checkpointed, a feature visualisation is
    rendered with lucent, and the objective and callback values are recorded
    and printed.

    Args:
        optim_obj: Callable taking ``activations`` and returning a tensor;
            weighted by ``alpha`` in the combined objective.
        vis_obj: Objective handed to ``render.render_vis`` when rendering the
            checkpoint visualisation.
        images: Input passed to ``model`` on every step.
        model: Model that is run on ``images`` and whose named parameters are
            checkpointed.
        params: Parameters handed to the ``optimizer`` constructor.
        nsteps: Number of optimisation steps.
        save_interval: A checkpoint is written whenever
            ``step % save_interval == 0``. Must be non-zero.
        image_side_length: Size passed to ``param.image`` for the rendered
            visualisation.
        output_folder: Root folder for checkpoints. Parameters go to
            ``<output_folder>/parameter_checkpoints`` and images to
            ``<output_folder>/optimal_image_checkpoints``; both are created
            if missing.
        save_image: Forwarded to ``render.render_vis``.
        optimizer: Optimizer constructor, called as ``optimizer(params)``.
        callback: Optional callable taking ``activations`` and returning a
            tensor, recorded at each checkpoint. Defaults to a function
            returning a single zero.
        maintain_obj: Optional callable taking ``activations`` and returning a
            tensor; weighted by ``1 - alpha``. Defaults to a function
            returning a single zero.
        alpha: Weight of ``optim_obj`` in the combined objective.
        device: Device of the zero tensor produced by the default ``callback``
            and ``maintain_obj``.
        use_tqdm: Whether to wrap the step loop in a ``tqdm`` progress bar.
        activations: Object passed to the objectives and the callback. When
            ``None`` it is obtained from
            ``get_model_activations(model, require_grad=True)``.

    Returns:
        A tuple ``(optimal_images, optim_objective_values, callback_outputs,
        maintain_objective_values)`` of lists with one entry per checkpointed
        step. The images are the return values of ``render.render_vis``; the
        other entries are NumPy arrays.

    Raises:
        ZeroDivisionError: If ``save_interval`` is 0 and ``nsteps`` is positive.
    """
    # Imported locally so path handling does not rely on `os` happening to be
    # re-exported by the star import from `utils`.
    import os

    def _zero_objective(_activations):
        # Neutral objective: contributes nothing to the loss or the log.
        return torch.zeros(1).to(device)

    if callback is None:
        callback = _zero_objective

    if maintain_obj is None:
        maintain_obj = _zero_objective

    if activations is None:
        activations = get_model_activations(model, require_grad=True)
    optimizer = optimizer(params)

    # Make sure the checkpoint folders exist; otherwise the first save fails
    # on a fresh output folder.
    parameter_dir = os.path.join(output_folder, "parameter_checkpoints")
    image_dir = os.path.join(output_folder, "optimal_image_checkpoints")
    os.makedirs(parameter_dir, exist_ok=True)
    if save_image:
        os.makedirs(image_dir, exist_ok=True)

    # Loop-invariant image parameterisation factory for the visualisation.
    def param_f():
        return param.image(image_side_length)

    # Initialize outputs
    optimal_images = []
    optim_objective_values = []
    callback_outputs = []
    maintain_objective_values = []

    steps = range(nsteps)
    if use_tqdm:
        steps = tqdm(steps)

    for i in steps:

        # Zero out the gradients
        model.zero_grad()
        optimizer.zero_grad()

        # Run the model over our selected images
        model(images)

        # Calculate the two objective functions
        optim_objective_value = optim_obj(activations)
        maintain_objective_value = maintain_obj(activations)

        # Combine them into the main objective function and backpropagate
        objective_value = alpha * optim_objective_value + (1 - alpha) * maintain_objective_value
        objective_value.backward()
        optimizer.step()

        if i % save_interval != 0:
            continue

        # Save parameters (these already include this step's update)
        for name, par in model.named_parameters():
            torch.save(par, os.path.join(parameter_dir, f"{i}.{name}.pt"))

        # Evaluate the callback BEFORE rendering: render_vis runs its own
        # forward passes through `model`, so if `activations` is refreshed by
        # forward hooks (presumably - confirm against get_model_activations)
        # evaluating it afterwards would read the visualisation image's
        # activations instead of those of `images`.
        callback_output = callback(activations).detach().cpu().numpy()

        # Save optimal image
        optimal_image = render.render_vis(
            model,
            vis_obj,
            param_f,
            save_image=save_image,
            image_name=os.path.join(image_dir, f"{i}.png"),
            progress=False,
        )
        optimal_images.append(optimal_image)

        # Save objective values (computed from the forward pass that preceded
        # this step's parameter update)
        optim_objective_values.append(optim_objective_value.detach().cpu().numpy())
        maintain_objective_values.append(maintain_objective_value.detach().cpu().numpy())
        print(f"(Step {i}) Optim objective:", optim_objective_values[-1])
        print(f"(Step {i}) Maintain objective:", maintain_objective_values[-1])
        print(f"(Step {i}) Objective:", objective_value.detach().cpu().numpy())

        # Save the output of the callback function
        callback_outputs.append(callback_output)
        print(f"(Step {i}) Callback:", callback_outputs[-1])

    return optimal_images, optim_objective_values, callback_outputs, maintain_objective_values

