#!/usr/bin/env python3

import os
from argparse import ArgumentParser
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import sys
sys.path.append('./src')
from utils import *

import numpy as np
import torch
from torchvision import models, transforms
from lucent.optvis import objectives, render, param

# Default torch device for this module: the first CUDA GPU when one is
# available, otherwise the CPU.
if torch.cuda.is_available():
    _default_device = torch.device("cuda:0")
else:
    _default_device = torch.device("cpu")


def activation_optimization(
        model,
        params,
        output_folder,

        optim_obj = None,
        maintain_obj = None,
        vis_obj = None,

        callback = None,
        image_side_length = 224,
        nsteps = 100,
        save_interval = 10,

        alpha = 0.01,
        optimizer = torch.optim.Adam,
        device = _default_device,
        use_tqdm = True,
        save_image = True):
    """Optimize ``params`` under a weighted pair of objectives, checkpointing periodically.

    Each step zeroes the gradients and calls ``optim_obj(rate=alpha)`` and
    ``maintain_obj(rate=1 - alpha)``. It then takes one optimizer step.
    This function never calls ``.backward()`` itself; the objectives are
    expected to do so. Their return values are only recorded and printed.

    On every step ``i`` with ``i % save_interval == 0`` it:

    * saves each of ``model.named_parameters()`` to
      ``<output_folder>/parameter_checkpoints/<i>.<name>.pt``;
    * renders ``vis_obj`` with ``lucent.optvis.render.render_vis``, using an
      image of side ``image_side_length``, with ``image_name`` set to
      ``<output_folder>/optimal_image_checkpoints/<i>.png``;
    * records and prints both objective values and their weighted sum;
    * records and prints ``callback(model.activations)`` as a numpy array.

    Args:
        model: torch module. ``model.activations`` is read on each checkpoint
            and passed to ``callback``. It is presumably populated by
            ``set_model_activations`` (from ``utils``); TODO confirm.
        params: parameters to optimize, passed as ``optimizer(params)``.
        output_folder: directory that receives the checkpoint subdirectories.
            The subdirectories are created if missing.
        optim_obj: callable accepting keyword ``rate``. Defaults to a
            zero-valued placeholder.
        maintain_obj: callable accepting keyword ``rate``, called with
            ``rate=1 - alpha``. Defaults to a zero-valued placeholder.
        vis_obj: objective forwarded to ``render.render_vis``.
        callback: callable taking ``model.activations`` and returning a tensor.
            Defaults to a zero-valued placeholder.
        image_side_length: side length passed to ``param.image``.
        nsteps: number of optimization steps.
        save_interval: checkpoint every this many steps. Must be >= 1.
        alpha: weight of the optim objective; the maintain objective gets
            ``1 - alpha``.
        optimizer: optimizer class (not instance), called as ``optimizer(params)``.
        device: device on which the default placeholder tensors are created.
        use_tqdm: wrap the step loop in a ``tqdm`` progress bar.
        save_image: forwarded to ``render.render_vis``.

    Returns:
        Tuple ``(optimal_images, optim_objective_values, callback_outputs,
        maintain_objective_values)``. Each list has one entry per checkpoint.

    Raises:
        ValueError: if ``save_interval`` is smaller than 1.
    """
    if save_interval < 1:
        raise ValueError(f"save_interval must be >= 1, got {save_interval}")

    def _zero(*_args, **_kwargs):
        # Placeholder objective/callback. It accepts both the keyword ``rate=``
        # used for objectives and the positional activations used for
        # callbacks, and always returns a zero tensor.
        return torch.zeros(1, device=device)

    if optim_obj is None:
        optim_obj = _zero
    if maintain_obj is None:
        maintain_obj = _zero
    if callback is None:
        callback = _zero

    # The checkpoint writers do not create parent directories themselves.
    param_dir = os.path.join(output_folder, "parameter_checkpoints")
    image_dir = os.path.join(output_folder, "optimal_image_checkpoints")
    os.makedirs(param_dir, exist_ok=True)
    if save_image:
        os.makedirs(image_dir, exist_ok=True)

    set_model_activations(model, require_grad=True)
    optimizer = optimizer(params)

    # Image parameterization factory for lucent; loop-invariant.
    def param_f():
        return param.image(image_side_length)

    # Initialize outputs
    optimal_images = []
    optim_objective_values = []
    callback_outputs = []
    maintain_objective_values = []

    steps = tqdm(range(nsteps)) if use_tqdm else range(nsteps)
    for i in steps:
        model.zero_grad()
        optimizer.zero_grad()

        # The objectives are assumed to run their own .backward() calls;
        # their returned values are only logged.
        optim_objective_value = optim_obj(rate=alpha)
        maintain_objective_value = maintain_obj(rate=1 - alpha)
        objective_value = alpha * optim_objective_value + (1 - alpha) * maintain_objective_value

        optimizer.step()

        if i % save_interval != 0:
            continue

        # Save parameters
        for name, par in model.named_parameters():
            torch.save(par, os.path.join(param_dir, f"{i}.{name}.pt"))

        # Save optimal image
        optimal_image = render.render_vis(
            model,
            vis_obj,
            param_f,
            save_image=save_image,
            image_name=os.path.join(image_dir, f"{i}.png"),
            progress=False,
        )
        optimal_images.append(optimal_image)

        # Save objective values
        optim_objective_values.append(optim_objective_value)
        maintain_objective_values.append(maintain_objective_value)
        print(f"(Step {i}) Optim objective:", optim_objective_values[-1])
        print(f"(Step {i}) Maintain objective:", maintain_objective_values[-1])
        print(f"(Step {i}) Objective:", objective_value)

        # Save the output of the callback function
        callback_outputs.append(callback(model.activations).detach().cpu().numpy())
        print(f"(Step {i}) Callback:", callback_outputs[-1])

    return optimal_images, optim_objective_values, callback_outputs, maintain_objective_values

            
