from torch_dreams import Dreamer
import torch.nn as nn
import os
from torch_dreams.batched_objective import BatchedObjective
from torch_dreams.batched_image_param import BatchedAutoImageParam
from tqdm import tqdm


def get_chunks(lst, chunk_size):
    """Yield consecutive slices of ``lst``, each at most ``chunk_size`` long.

    Slices start at offsets ``0, chunk_size, 2 * chunk_size, ...`` and are
    produced lazily. The final slice is shorter when ``len(lst)`` is not a
    multiple of ``chunk_size``.

    Args:
        lst: Any object supporting ``len()`` and slicing.
        chunk_size: Step between slice starts, also the maximum slice length.

    Yields:
        ``lst[start : start + chunk_size]`` for every start offset.
    """
    offsets = range(0, len(lst), chunk_size)
    yield from (lst[start : start + chunk_size] for start in offsets)


class ChannelObjective:
    """Callable objective built from the summed activation of one channel.

    Calling an instance picks ``layer_outputs[layer_index]``, selects
    ``channel_index`` along the second axis (the indexing pattern
    ``[:, c, :, :]`` requires a 4-axis output) and sums the selection.

    Attributes are stored as given:
        layer_index: Position of the layer inside ``layer_outputs``.
        channel_index: Channel selected along axis 1.
        flipped: When exactly ``True`` the raw sum is returned; otherwise the
            negated sum is returned.
    """

    def __init__(self, layer_index, channel_index, flipped: bool = False):
        self.flipped = flipped
        self.channel_index = channel_index
        self.layer_index = layer_index

    def __call__(self, layer_outputs):
        """Return the (possibly negated) channel activation sum."""
        layer_activation = layer_outputs[self.layer_index]
        total = layer_activation[:, self.channel_index, :, :].sum()
        # Identity check on purpose: only the literal ``True`` keeps the sign,
        # any other value (including truthy ones) negates the sum.
        # NOTE(review): the negated form is presumably meant to be minimised
        # by the optimiser -- confirm against Dreamer.render.
        return total if self.flipped is True else -total


class NeuronObjective:
    """Callable objective built from the summed activation of one neuron.

    Calling an instance picks ``layer_outputs[layer_index]``, selects
    ``neuron_index`` along the second axis and sums the selection.

    Attributes are stored as given:
        layer_index: Position of the layer inside ``layer_outputs``.
        neuron_index: Neuron selected along axis 1.
        flipped: When exactly ``True`` the raw sum is returned; otherwise the
            negated sum is returned.
    """

    def __init__(self, layer_index: int, neuron_index: int, flipped: bool = False):
        self.flipped = flipped
        self.neuron_index = neuron_index
        self.layer_index = layer_index

    def __call__(self, layer_outputs):
        """Return the (possibly negated) neuron activation sum."""
        layer_activation = layer_outputs[self.layer_index]
        total = layer_activation[:, self.neuron_index].sum()
        # Identity check on purpose: only the literal ``True`` keeps the sign.
        if self.flipped is True:
            return total
        return -total


class ConvLayerFeaturevisGenerator:
    """Render and save one feature visualization per output channel of a layer.

    One ``ChannelObjective`` is created for every index in
    ``range(target_layer.out_channels)``. Objectives (and their indices) are
    split into chunks of at most ``batch_size``; each chunk is rendered with a
    single ``Dreamer.render`` call.
    """

    def __init__(
        self,
        model: nn.Module,
        target_layer: nn.Module,
        batch_size: int,
        render_kwargs: dict,
        width: int = 224,
        height: int = 224,
        standard_deviation: float = 0.01,
        device="cuda:0",
        quiet=False,
    ):
        """Build the dreamer and pre-compute the objective/index chunks.

        Args:
            model: Model handed to ``Dreamer``.
            target_layer: Layer to visualize; must expose ``out_channels``.
            batch_size: Maximum number of channels rendered per batch.
            render_kwargs: Extra keyword arguments forwarded to
                ``Dreamer.render`` for every batch.
            width: Forwarded to ``BatchedAutoImageParam``.
            height: Forwarded to ``BatchedAutoImageParam``.
            standard_deviation: Forwarded to ``BatchedAutoImageParam``.
            device: Device passed to both ``Dreamer`` and
                ``BatchedAutoImageParam``.
            quiet: Forwarded to ``Dreamer``.
        """
        self.dreamer = Dreamer(model=model, device=device, quiet=quiet)

        # layer_index is 0 because ``generate`` renders with a single layer.
        all_objectives = [
            ChannelObjective(layer_index=0, channel_index=i, flipped=False)
            for i in range(target_layer.out_channels)
        ]

        self.objective_chunks = list(get_chunks(all_objectives, chunk_size=batch_size))
        self.index_chunks = list(
            get_chunks(list(range(len(all_objectives))), chunk_size=batch_size)
        )
        self.target_layer = target_layer
        self.render_kwargs = render_kwargs

        self.width = width
        self.height = height
        self.standard_deviation = standard_deviation
        self.batch_size = batch_size
        self.device = device

    def generate(self, output_folder: str):
        """Render every channel batch and save the images as ``<index>.png``.

        A batch is skipped when the file for its last index already exists in
        ``output_folder``; files within a batch are written in index order, so
        the last file is the final one produced for that batch.

        Args:
            output_folder: Existing directory that receives the PNG files.

        Returns:
            List of file paths, one per channel index in ascending order,
            covering both freshly rendered and skipped batches.

        Raises:
            AssertionError: If ``output_folder`` does not exist.
        """
        if not os.path.exists(output_folder):
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O``; the exception type is unchanged.
            raise AssertionError(f"Expected output_folder to exist: {output_folder}")

        filenames = []
        # The context manager closes the progress bar on normal completion and
        # when rendering or saving raises.
        with tqdm(
            total=len(self.index_chunks),
            desc=f"Generating feature visualizations with batch size {self.batch_size}:",
        ) as progress:
            for batch_indices, objective_batch in zip(
                self.index_chunks, self.objective_chunks
            ):
                batch_filenames = [
                    os.path.join(output_folder, f"{channel_index}.png")
                    for channel_index in batch_indices
                ]

                if not os.path.exists(batch_filenames[-1]):
                    batched_objective = BatchedObjective(objectives=objective_batch)

                    bap = BatchedAutoImageParam(
                        batch_size=len(objective_batch),
                        width=self.width,
                        height=self.height,
                        standard_deviation=self.standard_deviation,
                        device=self.device,
                    )

                    image_param = self.dreamer.render(
                        image_parameter=bap,
                        layers=[self.target_layer],
                        custom_func=batched_objective,
                        **self.render_kwargs,
                    )

                    for batch_index, filename in enumerate(batch_filenames):
                        image_param[batch_index].save(filename)

                filenames.extend(batch_filenames)
                progress.update(1)

        return filenames


class LinearLayerFeaturevisGenerator:
    """Render and save one feature visualization per output neuron of a layer.

    One ``NeuronObjective`` is created for every index in
    ``range(target_layer.out_features)``. Objectives (and their indices) are
    split into chunks of at most ``batch_size``; each chunk is rendered with a
    single ``Dreamer.render`` call.
    """

    def __init__(
        self,
        model: nn.Module,
        target_layer: nn.Module,
        batch_size: int,
        render_kwargs: dict,
        width: int = 224,
        height: int = 224,
        standard_deviation: float = 0.01,
        device="cuda:0",
        quiet=False,
    ):
        """Build the dreamer and pre-compute the objective/index chunks.

        Args:
            model: Model handed to ``Dreamer``.
            target_layer: Layer to visualize; must expose ``out_features``.
            batch_size: Maximum number of neurons rendered per batch.
            render_kwargs: Extra keyword arguments forwarded to
                ``Dreamer.render`` for every batch.
            width: Forwarded to ``BatchedAutoImageParam``.
            height: Forwarded to ``BatchedAutoImageParam``.
            standard_deviation: Forwarded to ``BatchedAutoImageParam``.
            device: Device passed to both ``Dreamer`` and
                ``BatchedAutoImageParam``.
            quiet: Forwarded to ``Dreamer``.
        """
        self.dreamer = Dreamer(model=model, device=device, quiet=quiet)

        # layer_index is 0 because ``generate`` renders with a single layer.
        all_objectives = [
            NeuronObjective(layer_index=0, neuron_index=i, flipped=False)
            for i in range(target_layer.out_features)
        ]

        self.objective_chunks = list(get_chunks(all_objectives, chunk_size=batch_size))
        self.index_chunks = list(
            get_chunks(list(range(len(all_objectives))), chunk_size=batch_size)
        )
        self.target_layer = target_layer
        self.render_kwargs = render_kwargs

        self.width = width
        self.height = height
        self.standard_deviation = standard_deviation
        self.batch_size = batch_size
        self.device = device

    def generate(self, output_folder: str):
        """Render every neuron batch and save the images as ``<index>.png``.

        A batch is skipped when the file for its last index already exists in
        ``output_folder``; files within a batch are written in index order, so
        the last file is the final one produced for that batch.

        Args:
            output_folder: Existing directory that receives the PNG files.

        Returns:
            List of file paths, one per neuron index in ascending order,
            covering both freshly rendered and skipped batches.

        Raises:
            AssertionError: If ``output_folder`` does not exist.
        """
        if not os.path.exists(output_folder):
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O``; the exception type is unchanged.
            raise AssertionError(f"Expected output_folder to exist: {output_folder}")

        filenames = []
        # The context manager closes the progress bar on normal completion and
        # when rendering or saving raises.
        with tqdm(
            total=len(self.index_chunks),
            desc=f"Generating feature visualizations with batch size {self.batch_size}:",
        ) as progress:
            for batch_indices, objective_batch in zip(
                self.index_chunks, self.objective_chunks
            ):
                batch_filenames = [
                    os.path.join(output_folder, f"{neuron_index}.png")
                    for neuron_index in batch_indices
                ]

                if not os.path.exists(batch_filenames[-1]):
                    batched_objective = BatchedObjective(objectives=objective_batch)

                    bap = BatchedAutoImageParam(
                        batch_size=len(objective_batch),
                        width=self.width,
                        height=self.height,
                        standard_deviation=self.standard_deviation,
                        device=self.device,
                    )

                    image_param = self.dreamer.render(
                        image_parameter=bap,
                        layers=[self.target_layer],
                        custom_func=batched_objective,
                        **self.render_kwargs,
                    )

                    for batch_index, filename in enumerate(batch_filenames):
                        image_param[batch_index].save(filename)

                filenames.extend(batch_filenames)
                progress.update(1)

        return filenames
