import utils
from optimization_requirements import *
import copy
def ensure_dir(directory):
    """Create *directory* (and any missing parent directories) if it does not exist.

    ``exist_ok=True`` replaces the separate ``os.path.exists`` check: with a
    check-then-create sequence, another process could create the directory in
    between and ``os.makedirs`` would then raise ``FileExistsError``.

    Args:
        directory: Path of the directory to create.
    """
    os.makedirs(directory, exist_ok=True)

def generate_artificial_top_images(title, model, output_folder, vis_obj_list, image_side_length,
                                   do_fft=True, save_image=True):
    """Render one Lucent feature-visualisation image per objective.

    Args:
        title: Suffix used in each image file name (``channel_{i}_{title}.png``).
        model: Network to visualise. It is deep-copied once; only the copy is
            switched to eval mode and rendered.
        output_folder: Directory joined onto each image file name passed to
            ``render.render_vis``.
        vis_obj_list: Sequence of Lucent objectives (e.g. as returned by
            ``prep_for_artificial_images``). If None, nothing is rendered.
        image_side_length: Size argument forwarded to ``param.image``.
        do_fft: Forwarded as ``fft=`` to ``param.image``.
        save_image: Forwarded to ``render.render_vis``.

    Returns:
        None if ``vis_obj_list`` is None. Otherwise a tensor made by
        concatenating every render's first output along dim 0 and adding a
        leading dim. Presumably shaped
        ``[1, num_objectives, 3, image_side_length, image_side_length]``
        (TODO confirm against the Lucent version in use).
    """
    if vis_obj_list is None:
        return None

    print('Generating Artificial Optimal Images')
    print(f'save_image: {save_image}')

    def param_f():
        return param.image(image_side_length, fft=do_fft)

    # Copy the model once and render only from the copy, so the optimisation
    # run by Lucent does not leave gradients or a mode change on the caller's model.
    model_copy = copy.deepcopy(model)
    model_copy.eval()

    rendered = []
    for channel_num, vis in enumerate(tqdm(vis_obj_list)):
        # Clear any gradients a previous render may have left on the model being
        # rendered. This previously called model.zero_grad(), which touched the
        # caller's model rather than model_copy.
        model_copy.zero_grad()
        image = render.render_vis(model_copy, vis, param_f, save_image=save_image,
                                  show_image=False,
                                  image_name=os.path.join(output_folder,
                                                          f"channel_{channel_num}_{title}.png"),
                                  progress=False)
        # Move the last axis to position 1 (presumably channels-last ->
        # channels-first) so the renders can be concatenated along dim 0.
        rendered.append(torch.from_numpy(image[0]).permute([0, 3, 1, 2]))

    images = torch.vstack(rendered).unsqueeze(0)
    print(f'artificial image tensor shape: {images.shape}')
    return images
    
def prep_for_artificial_images(model, args_channel, args_layer):
    """Expand channel specs into Lucent channel objectives.

    ``args_channel`` and ``args_layer`` are parallel lists: entry ``i`` requests
    channel ``args_channel[i]`` of layer ``args_layer[i]``. A channel value of
    ``"a"`` means every output channel of that layer. Any other value is
    converted with ``int()``.

    Explicit channels keep their original relative order. Each ``"a"`` expands
    to ``0 .. out_channels - 1``, and these expansions are appended after the
    explicit channels in the order the ``"a"`` entries appear. This matches the
    ordering the previous implementation produced for the inputs it handled
    correctly.

    Args:
        model: Network whose layers are looked up via ``get_model_layers``.
            Layers named by an ``"a"`` entry must expose ``out_channels``.
        args_channel: List of channel specs. Rewritten in place with the
            expanded integer channels (kept for backward compatibility).
        args_layer: List of layer names parallel to ``args_channel``.
            Rewritten in place to stay aligned with ``args_channel``.

    Returns:
        ``(vis_obj_list, args_channel, args_layer)``, where ``vis_obj_list[i]``
        is ``objectives.channel(args_layer[i], args_channel[i])``.

    Raises:
        ValueError: If ``args_channel`` and ``args_layer`` differ in length.
    """
    if len(args_channel) != len(args_layer):
        raise ValueError(
            f'args_channel ({len(args_channel)}) and args_layer ({len(args_layer)}) '
            'must have the same length')

    layers = get_model_layers(model)

    explicit_channels, explicit_layers = [], []
    expanded_channels, expanded_layers = [], []
    # Build new lists instead of popping/appending while enumerating the inputs.
    # Mutating a list during iteration skipped entries: an item following an
    # "a" was never int-converted, and a second "a" was never expanded.
    for channel, layer in zip(args_channel, args_layer):
        if channel == "a":
            print('layer: ', layer)
            num_channels = layers[layer].out_channels
            expanded_channels.extend(range(num_channels))
            expanded_layers.extend([layer] * num_channels)
        else:
            explicit_channels.append(int(channel))
            explicit_layers.append(layer)

    # Update the caller's lists in place, as the original implementation did.
    args_channel[:] = explicit_channels + expanded_channels
    args_layer[:] = explicit_layers + expanded_layers

    # Pair each channel with its own layer. The old code used the channel *value*
    # as a list index, which was only correct for channels 0..n-1 in order.
    vis_obj_list = [
        objectives.channel(layer, channel)
        for layer, channel in zip(args_layer, args_channel)
    ]
    return vis_obj_list, args_channel, args_layer

# Runs to visualise, one tuple per run: (results directory, layer, descriptor).
# Uncomment entries to process additional runs. results_directories, layers and
# descriptors are all derived from this single table, so they always have the
# same length and stay aligned for the zip() in the __main__ block.
_RUNS = [
    # ("/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f0_dataset_top10_to_zero", 'features_0', 'conv1_push-down'),
    # ("/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f3_dataset_top10_to_zero_a01", 'features_3', 'conv2_push-down'),
    # ("/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f6_dataset_top10_to_zero", 'features_6', 'conv3_push-down'),
    # ("/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f8_dataset_top10_to_zero", 'features_8', 'conv4_push-down'),
    ("/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f10_dataset_top10_to_zero",
     'features_10', 'conv5_push-down'),
    ("/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f10_dataset_refs_to_top",
     'features_10', 'conv5_push-up'),
]

results_directories = [run_dir for run_dir, _run_layer, _run_desc in _RUNS]
layers = [run_layer for _run_dir, run_layer, _run_desc in _RUNS]
descriptors = [run_desc for _run_dir, _run_layer, run_desc in _RUNS]
if __name__ == '__main__':
    # Script entry point: for each configured run, render one image per channel
    # of the run's layer, for both the initial and the final model.
    parser = ArgumentParser()
    parser.add_argument("--arch", default="alexnet", choices= ["alexnet", "resnet50-rob", "resnet18-rob", "efficientnet"])
    parser.add_argument("--output", type = str, default = 'artifical_images')
    # NOTE(review): --vis-obj is parsed but not used below; objectives are always
    # per-channel (see prep_for_artificial_images).
    parser.add_argument("--vis-obj", default='channel', choices=["center-neuron", "channel",])
    args = parser.parse_args()
    model_arch = args.arch
    # Honour --output. It was previously parsed but ignored in favour of a
    # hard-coded folder; the flag's default is that same folder name.
    output = args.output
    ensure_dir(output)

    # zip() would silently drop runs if the parallel config lists were misaligned.
    if not (len(results_directories) == len(descriptors) == len(layers)):
        raise ValueError('results_directories, descriptors and layers must have the same length')

    for results_directory, descriptor, layer in zip(results_directories, descriptors, layers):
        init_model, final_model = utils.get_init_final_models(results_directory, model_arch)

        vis_obj_list, _, _ = prep_for_artificial_images(init_model, ["a"], [layer])
        generate_artificial_top_images(f'init_{descriptor}', init_model, output, vis_obj_list,
                                       image_side_length=224, do_fft=True)

        vis_obj_list, _, _ = prep_for_artificial_images(final_model, ["a"], [layer])
        # Render the final model. This previously passed init_model, so the
        # "final_" images were just a second rendering of the initial model.
        generate_artificial_top_images(f'final_{descriptor}', final_model, output, vis_obj_list,
                                       image_side_length=224, do_fft=True)




