import torch
from tqdm import tqdm
import collections

from src.data.datasets.tensor_dict_dataset import TensorDictDataset


class ProbeTrainer(object):
    def update_feat_buffer(self, model_dict, **kwargs):
        if self.probe_specs is None:
            return

        # ____ Get all features required to train probes. ____
        probe_inputs = [(probe_dict["inputs"]) for probe_dict in self.probe_specs]
        probe_outputs = [(probe_dict["outputs"]) for probe_dict in self.probe_specs]
        buffer_feat_keys = list(set(probe_inputs)) + list(set(probe_outputs))

        # ____ Add other features that will be useful. ____
        for kwargs_key in kwargs.keys():
            if kwargs_key.endswith("ys_true"):
                buffer_feat_keys.append(kwargs_key)

        # ____ Add relevant features to the buffer. ____
        for k in buffer_feat_keys:
            k = k.split("/")[-1]  # If necessary, remove the prefix from the key.
            # Initialize the key in the feature buffer dict.
            if k not in self.feat_buffer_dict:
                self.feat_buffer_dict[k] = collections.deque([], maxlen=self.max_num_batches_in_buffer)

            # Add the features to the buffer dictionary.
            k_in_model_dict = k in model_dict.keys()  # If key in model_dict.
            k_in_kwargs = not all([not kwargs_key.endswith(k) for kwargs_key in kwargs.keys()])  # If key in kwargs.
            if k_in_model_dict:
                self.feat_buffer_dict[k].append(model_dict[k].clone().detach())
            elif k_in_kwargs:
                for kwargs_key in kwargs.keys():
                    if kwargs_key.endswith(k):
                        self.feat_buffer_dict[k].append(kwargs[kwargs_key].clone().detach())
            else:
                assert False, f"key {k} not found in {model_dict.keys()} or {kwargs.keys()}"

    def tensorize_feat_buffer(self):
        if self.probe_specs is None:
            return

        for k, v in self.feat_buffer_dict.items():
            self.feat_buffer_dict[k] = torch.cat(list(self.feat_buffer_dict[k]), dim=0).to(self.device)

    def train_probe_and_log(self):
        if self.probe_specs is None:
            return

        # Free GPU memory if possible.
        try:
            torch.cuda.empty_cache()
        except:
            print(f"Couldn't clear GPU cache for some reason (maybe not training on GPU). Moving on. ")

        # Initialize the probe networks and get the training loader.
        for probe_spec in self.probe_specs:
            probe_spec["model"] = probe_spec["model_init"]().to(self.device)

        # Set up the probe optimizer.
        probe_parameters = [probe_spec["model"].parameters() for probe_spec in self.probe_specs]
        probe_optimizer = self.probe_optimizer_init(*probe_parameters)

        # Form dataloader to be used for probe training.
        self.tensorize_feat_buffer()
        probe_dataset = TensorDictDataset(tensors_dict=self.feat_buffer_dict)
        probe_dataloader = self.probe_dataloader(dataset=probe_dataset)

        # Establish the logging backbone.
        track_for_n_batches = self.track_last_n_batch_probe_outputs
        for probe_spec in self.probe_specs:
            probe_spec["probe_losses"] = list()
            probe_spec["probe_losses_0"] = list()
            probe_spec["probe_losses_1"] = list()
            probe_spec["probe_preds_last_batches"] = collections.deque([], maxlen=track_for_n_batches)
            probe_spec["probe_targets_last_batches"] = collections.deque([], maxlen=track_for_n_batches)
            probe_spec["probe_idx0_last_batches"] = collections.deque([], maxlen=track_for_n_batches)
            probe_spec["probe_idx1_last_batches"] = collections.deque([], maxlen=track_for_n_batches)

    # Train.
        curr_batch_count = 0
        while curr_batch_count < self.max_probe_training_batches:
            for batch_idx, batch in tqdm(enumerate(probe_dataloader), total=self.max_probe_training_batches,
                                         desc="Training probes. "):
                # Update batch counter.
                curr_batch_count += 1

                # Perform forward and backward passes for all probes.
                total_probe_loss = 0.
                probe_optimizer.zero_grad()
                for probe_spec in self.probe_specs:
                    # Get the inputs and outputs for probe training.
                    probe_inputs = batch[probe_spec["inputs"]].to(self.device)
                    probe_targets = batch[probe_spec["outputs"]].to(self.device)

                    # Perform forward pass.
                    probe_preds = probe_spec["model"](probe_inputs)

                    # Compute loss and add it to total probe loss.
                    idx_0 = batch["ys_true"] == torch.zeros_like(batch["ys_true"])
                    idx_1 = batch["ys_true"] == torch.ones_like(batch["ys_true"])
                    probe_loss_0 = probe_spec["loss_fn"](probe_preds[idx_0], probe_targets[idx_0])
                    probe_loss_1 = probe_spec["loss_fn"](probe_preds[idx_1], probe_targets[idx_1])
                    probe_loss = probe_spec["loss_fn"](probe_preds, probe_targets)
                    total_probe_loss = total_probe_loss + probe_loss

                    # Log.
                    probe_spec["probe_losses"].append(float(probe_loss))
                    probe_spec["probe_losses_0"].append(float(probe_loss_0))
                    probe_spec["probe_losses_1"].append(float(probe_loss_1))
                    probe_spec["probe_preds_last_batches"].append(probe_preds)
                    probe_spec["probe_targets_last_batches"].append(probe_targets)
                    probe_spec["probe_idx0_last_batches"].append(idx_0)
                    probe_spec["probe_idx1_last_batches"].append(idx_1)

                # Backpropagate and update.
                total_probe_loss.backward()
                probe_optimizer.step()

                # Exit loop.
                if not (curr_batch_count < self.max_probe_training_batches):
                    break

        # Log.
        for probe_spec in self.probe_specs:
            for logging_fn in probe_spec.logging_fns:
                logging_fn(logger=self.logger, probe_spec=probe_spec, current_epoch=self.current_epoch,
                           global_step=self.global_step, game_step=self.game_step)

        # We're done training the probe. We can now clear the feature buffer.
        self.feat_buffer_dict = dict()
