import re
from sys import gettrace

from Modules.dn3.trainable.utils import _make_mask, _make_span_from_seeds
from Modules.dn3.utils import LabelSmoothedCrossEntropyLoss
from Modules.dn3.trainable.models import Classifier
from Modules.dn3.transforms.batch import BatchTransform

# Swap these two for Ipython/Jupyter
# import tqdm
# import tqdm.notebook as tqdm
import tqdm.auto as tqdm

import torch
# ugh the worst, why did they make this protected...
from torch.optim.lr_scheduler import _LRScheduler as Scheduler

import numpy as np
from pandas import DataFrame
from collections import OrderedDict
from torch.utils.data import DataLoader, WeightedRandomSampler


class BaseProcess(object):
 """
 Initialization of the Base Trainable object. Any learning procedure that leverages DN3atasets should subclass
 this base class.

 By default uses the SGD with momentum optimization.
 """

 def __init__(self, lr=0.001, metrics=None, evaluation_only_metrics=None, l2_weight_decay=0.01, cuda=None, **kwargs):
 """
 Initialization of the Base Trainable object. Any learning procedure that leverages DN3atasets should subclass
 this base class.

 By default uses the SGD with momentum optimization.

 Parameters
 ----------
 cuda : bool, string, None
 If boolean, sets whether to enable training on the GPU, if a string, specifies can be used to specify
 which device to use. If None (default) figures it out automatically.
 lr : float
 The learning rate to use, this will probably something that should be tuned for each application.
 Start with multiplying or dividing by values of 2, 5 or 10 to seek out a good number.
 metrics : dict, list
 A dictionary of named (keys) metrics (values) or some iterable set of metrics that will be identified
 by their class names.
 evaluation_only_metrics : list
 A list of names of metrics that will be used for evaluation only (not calculated or
 reported during training steps).
 l2_weight_decay : float
 One of the simplest and most common regularizing techniques. If you find a model rapidly
 reaching high training accuracy (and not validation) increase this. If having trouble fitting
 the training data, decrease this.
 kwargs : dict
 Arguments that will be used by the processes' :py:meth:`BaseProcess.build_network()` method.
 """
 if cuda is None:
 cuda = torch.cuda.is_available()
 if cuda:
 tqdm.tqdm.write("GPU(s) detected: training and model execution will be performed on GPU.")
 if isinstance(cuda, bool):
 cuda = "cuda" if cuda else "cpu"
 assert isinstance(cuda, str)
 self.cuda = cuda
 self.device = torch.device(cuda)
 self._eval_metrics = list() if evaluation_only_metrics is None else list(evaluation_only_metrics).copy()
 self.metrics = OrderedDict()
 if metrics is not None:
 if isinstance(metrics, (list, tuple)):
 metrics = {m.__class__.__name__: m for m in metrics}
 if isinstance(metrics, dict):
 self.add_metrics(metrics)

 _before_members = set(self.__dict__.keys())
 self.build_network(**kwargs)
 new_members = set(self.__dict__.keys()).difference(_before_members)
 self._training = False
 self._trainables = list()
 for member in new_members:
 if isinstance(self.__dict__[member], (torch.nn.Module, torch.Tensor)):
 if not (isinstance(self.__dict__[member], torch.Tensor) and not self.__dict__[member].requires_grad):
 self._trainables.append(member)
 self.__dict__[member] = self.__dict__[member].to(self.device)

 self.optimizer = torch.optim.SGD(self.parameters(), weight_decay=l2_weight_decay, lr=lr, nesterov=True,
 momentum=0.9)
 self.scheduler = None
 self.scheduler_after_batch = False
 self.epoch = None
 self.lr = lr
 self.weight_decay = l2_weight_decay

 self._batch_transforms = list()
 self._eval_transforms = list()

 def set_optimizer(self, optimizer):
 assert isinstance(optimizer, torch.optim.Optimizer)
 del self.optimizer
 self.optimizer = optimizer
 self.lr = float(self.optimizer.param_groups[0]['lr'])

 def set_scheduler(self, scheduler, step_every_batch=False):
 """
 This allow the addition of a learning rate schedule to the process. By default, a linear warmup with cosine
 decay will be used. Any scheduler that is an instance of :any:`Scheduler` (pytorch's schedulers, or extensions
 thereof) can be set here. Additionally, a string keywords can be used including:
 - "constant"

 Parameters
 ----------
 scheduler: str, Scheduler
 step_every_batch: bool
 Whether to call step after every batch (if `True`), or after every epoch (`False`)

 """
 if isinstance(scheduler, str):
 if scheduler.lower() == 'constant':
 scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lambda e: 1.0)
 else:
 raise ValueError("Scheduler {} is not supported.".format(scheduler))
 # This is the most common one that needs this, force this to be true
 elif isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR):
 self.scheduler_after_batch = True
 else:
 self.scheduler_after_batch = step_every_batch
 self.scheduler = scheduler

 def add_metrics(self, metrics: dict, evaluation_only=False):
 self.metrics.update(**metrics)
 if evaluation_only:
 self._eval_metrics += list(metrics.keys())

 def _optimize_dataloader_kwargs(self, num_worker_cap=6, **loader_kwargs):
 loader_kwargs.setdefault('pin_memory', self.cuda == 'cuda')
 # Use multiple worker processes when NOT DEBUGGING
 if gettrace() is None:
 try:
 # Find number of cpus available (taken from second answer):
 #
 m = re.search(r'(?m)^Cpus_allowed:\s*(.*)$',
 open('/proc/self/status').read())
 nw = bin(int(m.group(1).replace(',', ''), 16)).count('1')
 # Cap the number of workers at 6 (actually 4) to avoid pummeling disks too hard
 nw = min(num_worker_cap, nw)
 except FileNotFoundError:
 # Fallback for when proc/self/status does not exist
 nw = 2
 else:
 # 0 workers means not extra processes are spun up
 nw = 2
 loader_kwargs.setdefault('num_workers', int(nw - 2))
 print("Loading data with {} additional workers".format(loader_kwargs['num_workers']))
 return loader_kwargs

 def _get_batch(self, iterator):
 batch = [x.to(self.device, non_blocking=self.cuda == 'cuda') for x in next(iterator)]
 xforms = self._batch_transforms if self._training else self._eval_transforms
 for xform in xforms:
 if xform.only_trial_data:
 batch[0] = xform(batch[0])
 else:
 batch = xform(batch)
 return batch

 def add_batch_transform(self, transform: BatchTransform, training_only=True):
 self._batch_transforms.append(transform)
 if not training_only:
 self._eval_transforms.append(transform)

 def clear_batch_transforms(self):
 self._batch_transforms = list()
 self._eval_transforms = list()

 def build_network(self, **kwargs):
 """
 This method is used to add trainable modules to the process. Rather than placing objects for training
 in the __init__ method, they should be placed here.

 By default any arguments that propagate unused from __init__ are included here.
 """
 self.__dict__.update(**kwargs)

 def parameters(self):
 """
 All the trainable parameters in the Trainable. This includes any architecture parameters and meta-parameters.

 Returns
 -------
 params :
 An iterator of parameters
 """
 for member in self._trainables:
 yield from self.__dict__[member].parameters()

 def forward(self, *inputs):
 """
 Given a batch of inputs, return the outputs produced by the trainable module.

 Parameters
 ----------
 inputs :
 Tensors needed for underlying module.

 Returns
 -------
 outputs :
 Outputs of module

 """
 raise NotImplementedError

 def calculate_loss(self, inputs, outputs):
 """
 Given the inputs to and outputs from underlying modules, calculate the loss.

 Returns
 -------
 Loss :
 Single loss quantity to be minimized.
 """
 if isinstance(outputs, (tuple, list)):
 device = outputs[0].device
 else:
 device = outputs.device
 loss_fn = self.loss
 if hasattr(self.loss, 'to'):
 loss_fn = loss_fn.to(device)
 return loss_fn(outputs, inputs[-1])

 def calculate_metrics(self, inputs, outputs):
 """
 Given the inputs to and outputs from the underlying module. Return tracked metrics.

 Parameters
 ----------
 inputs :
 Input tensors.
 outputs :
 Output tensors.

 Returns
 -------
 metrics : OrderedDict, None
 Dictionary of metric quantities.
 """
 metrics = OrderedDict()
 for met_name, met_fn in self.metrics.items():
 if self._training and met_name in self._eval_metrics:
 continue
 try:
 metrics[met_name] = met_fn(inputs, outputs)
 # I know its super broad, but basically if metrics fail during training, I want to just ignore them...
 except:
 continue
 return metrics

 def backward(self, loss):
 self.optimizer.zero_grad()
 loss.backward()

 def train(self, mode=True):
 self._training = mode
 for member in self._trainables:
 self.__dict__[member].train(mode=mode)

 def train_step(self, *inputs):
 self.train(True)
 outputs = self.forward(*inputs)
 loss = self.calculate_loss(inputs, outputs)
 self.backward(loss)

 self.optimizer.step()
 if self.scheduler is not None and self.scheduler_after_batch:
 self.scheduler.step()

 train_metrics = self.calculate_metrics(inputs, outputs)
 train_metrics.setdefault('loss', loss.item())

 return train_metrics

 def evaluate(self, dataset, **loader_kwargs):
 """
 Calculate and return metrics for a dataset

 Parameters
 ----------
 dataset: DN3ataset, DataLoader
 The dataset that will be used for evaluation, if not a DataLoader, one will be constructed
 loader_kwargs: dict
 Args that will be passed to the dataloader, but `shuffle` and `drop_last` will be both be
 forced to `False`

 Returns
 -------
 metrics : OrderedDict
 Metric scores for the entire
 """
 self.train(False)
 inputs, outputs = self.predict(dataset, **loader_kwargs)
 metrics = self.calculate_metrics(inputs, outputs)
 metrics['loss'] = self.calculate_loss(inputs, outputs).item()
 return metrics

 def predict(self, dataset, **loader_kwargs):
 """
 Determine the outputs for all loaded data from the dataset

 Parameters
 ----------
 dataset: DN3ataset, DataLoader
 The dataset that will be used for evaluation, if not a DataLoader, one will be constructed
 loader_kwargs: dict
 Args that will be passed to the dataloader, but `shuffle` and `drop_last` will be both be
 forced to `False`

 Returns
 -------
 inputs : Tensor
 The exact inputs used to calculate the outputs (in case they were stochastic and need saving)
 outputs : Tensor
 The outputs from each run of :function:`forward`
 """
 self.train(False)
 loader_kwargs.setdefault('batch_size', 1)
 dataset = self._make_dataloader(dataset, **loader_kwargs)

 pbar = tqdm.trange(len(dataset), desc="Predicting")
 data_iterator = iter(dataset)

 inputs = list()
 outputs = list()

 with torch.no_grad():
 for iteration in pbar:
 input_batch = self._get_batch(data_iterator)
 output_batch = self.forward(*input_batch)

 inputs.append([tensor.cpu() for tensor in input_batch])
 if isinstance(output_batch, torch.Tensor):
 outputs.append(output_batch.cpu())
 else:
 outputs.append([tensor.cpu() for tensor in output_batch])

 def package_multiple_tensors(batches: list):
 if isinstance(batches[0], torch.Tensor):
 return torch.cat(batches)
 elif isinstance(batches[0], (tuple, list)):
 return [torch.cat(b) for b in zip(*batches)]

 return package_multiple_tensors(inputs), package_multiple_tensors(outputs)

 @classmethod
 def standard_logging(cls, metrics: dict, start_message="End of Epoch"):
 if start_message.rstrip()[-1] != '|':
 start_message = start_message.rstrip() + " |"
 for m in metrics:
 if 'acc' in m.lower() or 'pct' in m.lower():
 start_message += " {}: {:.2%} |".format(m, metrics[m])
 elif m == 'lr':
 start_message += " {}: {:.3e} |".format(m, metrics[m])
 else:
 start_message += " {}: {:.3f} |".format(m, metrics[m])
 tqdm.tqdm.write(start_message)

 def save_best(self):
 """
 Create a snapshot of what is being currently trained for re-laoding with the :py:meth:`load_best()` method.

 Returns
 -------
 best : Any
 Whatever format is needed for :py:meth:`load_best()`, will be the argument provided to it.
 """
 return [{k: v.cpu() for k, v in self.__dict__[m].state_dict().items()} for m in self._trainables]

 def load_best(self, best):
 """
 Load the parameters as saved by :py:meth:`save_best()`.

 Parameters
 ----------
 best: Any
 """
 for m, state_dict in zip(self._trainables, best):
 self.__dict__[m].load_state_dict({k: v.to(self.device) for k, v in state_dict.items()})

 def _retain_best(self, old_checkpoint, metrics_to_check: dict, retain_string: str):
 if retain_string is None:
 return old_checkpoint
 best_checkpoint = old_checkpoint

 def found_best():
 tqdm.tqdm.write("Best {}. Retaining checkpoint...".format(retain_string))
 self.best_metric = metrics_to_check[retain_string]
 return self.save_best()

 if retain_string not in metrics_to_check.keys():
 tqdm.tqdm.write("No metric {} found in recorded metrics. Not saving best.")
 if self.best_metric is None:
 best_checkpoint = found_best()
 elif retain_string == 'loss' and metrics_to_check[retain_string] <= self.best_metric:
 best_checkpoint = found_best()
 elif retain_string != 'loss' and metrics_to_check[retain_string] >= self.best_metric:
 best_checkpoint = found_best()

 return best_checkpoint

 @staticmethod
 def _dataloader_args(dataset, training=False, **loader_kwargs):
 # Only shuffle and drop last when training
 loader_kwargs.setdefault('shuffle', training)
 loader_kwargs.setdefault('drop_last', training)

 return loader_kwargs

 def _make_dataloader(self, dataset, training=False, **loader_kwargs):
 """Any args that make more sense as a convenience function to be set"""
 if isinstance(dataset, DataLoader):
 return dataset

 return DataLoader(dataset, **self._dataloader_args(dataset, training, **loader_kwargs))

 def fit(self, training_dataset, epochs=1, validation_dataset=None, step_callback=None,
 resume_epoch=None, resume_iteration=None, log_callback=None, validation_callback=None,
 epoch_callback=None, batch_size=8, warmup_frac=0.2, retain_best='loss',
 validation_interval=None, train_log_interval=None, **loader_kwargs):
 """
 sklearn/keras-like convenience method to simply proceed with training across multiple epochs of the provided
 dataset

 Parameters
 ----------
 training_dataset : DN3ataset, DataLoader
 validation_dataset : DN3ataset, DataLoader
 epochs : int
 Total number of epochs to fit
 resume_epoch : int
 The starting epoch to train from. This will likely only be used to resume training at a certain
 point.
 resume_iteration : int
 Similar to start epoch but specified in batches. This can either be used alone, or in
 conjunction with `start_epoch`. If used alone, the start epoch is the floor of
 `start_iteration` divided by batches per epoch. In other words this specifies cumulative
 batches if start_epoch is not specified, and relative to the current epoch otherwise.
 step_callback : callable
 Function to run after every training step that has signature: fn(train_metrics) -> None
 log_callback : callable
 Function to run after every log interval that has signature: fn(train_metrics) -> None
 validation_callback : callable
 Function to run after every time the validation dataset is run through. This typically has the
 result of this and the `epoch_callback` called at the end of the epoch, but this is also called
 after `validation_interval` batches.
 This callback has the signature: fn(validation_metrics) -> None
 epoch_callback : callable
 Function to run after every epoch that has signature: fn(validation_metrics) -> None
 batch_size : int
 The batch_size to be used for the training and validation datasets. This is ignored if they are
 provided as `DataLoader`.
 warmup_frac : float
 The fraction of iterations that will be spent *increasing* the learning rate under the default
 1cycle policy (with cosine annealing). Value will be automatically clamped values between [0, 0.5]
 retain_best : (str, None)
 **If `validation_dataset` is provided**, which model weights to retain. If 'loss' (default), will
 retain the model at the epoch with the lowest validation loss. If another string, will assume that
 is the metric to monitor for the *highest score*. If None, the final model is used.
 validation_interval: int, None
 The number of batches between checking the validation dataset
 train_log_interval: int, None
 The number of batches between persistent logging of training metrics, if None (default) happens
 at the end of every epoch.
 loader_kwargs :
 Any remaining keyword arguments will be passed as such to any DataLoaders that are automatically
 constructed. If both training and validation datasets are provided as `DataLoaders`, this will be
 ignored.

 Notes
 -----
 If the datasets above are provided as DN3atasets, automatic optimizations are performed to speed up loading.
 These include setting the number of workers = to the number of CPUs/system threads - 1, and pinning memory for
 rapid CUDA transfer if leveraging the GPU. Unless you are very comfortable with PyTorch, it's probably better
 to not provide your own DataLoader, and let this be done automatically.

 Returns
 -------
 train_log : Dataframe
 Metrics after each iteration of training as a pandas dataframe
 validation_log : Dataframe
 Validation metrics after each epoch of training as a pandas dataframe
 """
 loader_kwargs.setdefault('batch_size', batch_size)
 loader_kwargs = self._optimize_dataloader_kwargs(**loader_kwargs)
 training_dataset = self._make_dataloader(training_dataset, training=True, **loader_kwargs)

 if resume_epoch is None:
 if resume_iteration is None or resume_iteration < len(training_dataset):
 resume_epoch = 1
 else:
 resume_epoch = resume_iteration // len(training_dataset)
 resume_iteration = 1 if resume_iteration is None else resume_iteration % len(training_dataset)

 _clear_scheduler_after = self.scheduler is None
 if _clear_scheduler_after:
 last_epoch_workaround = len(training_dataset) * (resume_epoch - 1) + resume_iteration
 last_epoch_workaround = -1 if last_epoch_workaround <= 1 else last_epoch_workaround
 self.set_scheduler(
 torch.optim.lr_scheduler.OneCycleLR(self.optimizer, self.lr, epochs=epochs,
 steps_per_epoch=len(training_dataset),
 pct_start=warmup_frac,
 last_epoch=last_epoch_workaround)
 )

 validation_log = list()
 train_log = list()
 self.best_metric = None
 best_model = self.save_best()

 train_log_interval = len(training_dataset) if train_log_interval is None else train_log_interval
 metrics = OrderedDict()

 def update_metrics(new_metrics: dict, iterations):
 if len(metrics) == 0:
 return metrics.update(new_metrics)
 else:
 for m in new_metrics:
 try:
 metrics[m] = (metrics[m] * (iterations - 1) + new_metrics[m]) / iterations
 except KeyError:
 metrics[m] = new_metrics[m]

 def print_training_metrics(epoch, iteration=None):
 if iteration is not None:
 self.standard_logging(metrics, "Training: Epoch {} - Iteration {}".format(epoch, iteration))
 else:
 self.standard_logging(metrics, "Training: End of Epoch {}".format(epoch))

 def _validation(epoch, iteration=None):
 _metrics = self.evaluate(validation_dataset, **loader_kwargs)
 if iteration is not None:
 self.standard_logging(_metrics, "Validation: Epoch {} - Iteration {}".format(epoch, iteration))
 else:
 self.standard_logging(_metrics, "Validation: End of Epoch {}".format(epoch))
 _metrics['epoch'] = epoch
 validation_log.append(_metrics)
 if callable(validation_callback):
 validation_callback(_metrics)
 return _metrics

 epoch_bar = tqdm.trange(resume_epoch, epochs + 1, desc="Epoch", unit='epoch', initial=resume_epoch, total=epochs)
 for epoch in epoch_bar:
 self.epoch = epoch
 pbar = tqdm.trange(resume_iteration, len(training_dataset) + 1, desc="Iteration", unit='batches',
 initial=resume_iteration, total=len(training_dataset))
 data_iterator = iter(training_dataset)
 for iteration in pbar:
 inputs = self._get_batch(data_iterator)
 train_metrics = self.train_step(*inputs)
 train_metrics['lr'] = self.optimizer.param_groups[0]['lr']
 if 'momentum' in self.optimizer.defaults:
 train_metrics['momentum'] = self.optimizer.param_groups[0]['momentum']
 update_metrics(train_metrics, iteration+1)
 pbar.set_postfix(metrics)
 train_metrics['epoch'] = epoch
 train_metrics['iteration'] = iteration
 train_log.append(train_metrics)
 if callable(step_callback):
 step_callback(train_metrics)

 if iteration % train_log_interval == 0 and pbar.total != iteration:
 print_training_metrics(epoch, iteration)
 train_metrics['epoch'] = epoch
 train_metrics['iteration'] = iteration
 if callable(log_callback):
 log_callback(metrics)
 metrics = OrderedDict()

 if isinstance(validation_interval, int) and (iteration % validation_interval == 0)\
 and validation_dataset is not None:
 _m = _validation(epoch, iteration)
 best_model = self._retain_best(best_model, _m, retain_best)

 # Make epoch summary
 metrics = DataFrame(train_log)
 metrics = metrics[metrics['epoch'] == epoch]
 metrics = metrics.mean().to_dict()
 metrics.pop('iteration', None)
 print_training_metrics(epoch)

 if validation_dataset is not None:
 metrics = _validation(epoch)
 best_model = self._retain_best(best_model, metrics, retain_best)

 if callable(epoch_callback):
 epoch_callback(metrics)
 metrics = OrderedDict()
 # All future epochs should not start offset in iterations
 resume_iteration = 1

 if not self.scheduler_after_batch and self.scheduler is not None:
 tqdm.tqdm.write(f"Step {self.scheduler.get_last_lr()} {self.scheduler.last_epoch}")
 self.scheduler.step()

 if _clear_scheduler_after:
 self.set_scheduler(None)
 self.epoch = None

 if retain_best is not None and validation_dataset is not None:
 tqdm.tqdm.write("Loading best model...")
 self.load_best(best_model)

 return DataFrame(train_log), DataFrame(validation_log)


class StandardClassification(BaseProcess):

 def __init__(self, classifier: torch.nn.Module, loss_fn=None, cuda=None, metrics=None, learning_rate=0.01,
 label_smoothing=None, **kwargs):
 if isinstance(metrics, dict):
 metrics.setdefault('Accuracy', self._simple_accuracy)
 else:
 metrics = dict(Accuracy=self._simple_accuracy)
 super(StandardClassification, self).__init__(cuda=cuda, lr=learning_rate, classifier=classifier,
 metrics=metrics, **kwargs)
 if label_smoothing is not None and isinstance(label_smoothing, float) and (0 < label_smoothing < 1):
 self.loss = LabelSmoothedCrossEntropyLoss(self.classifier.targets, smoothing=label_smoothing).\
 to(self.device)
 elif loss_fn is None:
 self.loss = torch.nn.CrossEntropyLoss().to(self.device)
 else:
 self.loss = loss_fn.to(self.device)
 self.best_metric = None

 @staticmethod
 def _simple_accuracy(inputs, outputs):
 if isinstance(outputs, (list, tuple)):
 outputs = outputs[0]
 # average over last dimensions
 while len(outputs.shape) >= 3:
 outputs = outputs.mean(dim=-1)
 return (inputs[-1] == outputs.argmax(dim=-1)).float().mean().item()

 def forward(self, *inputs):
 if isinstance(self.classifier, Classifier) and self.classifier.return_features:
 prediction, _ = self.classifier(*inputs[:-1])
 else:
 prediction = self.classifier(*inputs[:-1])
 return prediction

 def calculate_loss(self, inputs, outputs):
 inputs = list(inputs)

 def expand_for_strided_loss(factors):
 inputs[-1] = inputs[-1].unsqueeze(-1).expand(-1, *factors)

 check_me = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
 if len(check_me.shape) >= 3:
 expand_for_strided_loss(check_me.shape[2:])

 return super(StandardClassification, self).calculate_loss(inputs, outputs)

 def fit(self, training_dataset, epochs=1, validation_dataset=None, step_callback=None, epoch_callback=None,
 batch_size=8, warmup_frac=0.2, retain_best='loss', balance_method=None, **loader_kwargs):
 """
 sklearn/keras-like convenience method to simply proceed with training across multiple epochs of the provided
 dataset

 Parameters
 ----------
 training_dataset : DN3ataset, DataLoader
 validation_dataset : DN3ataset, DataLoader
 epochs : int
 step_callback : callable
 Function to run after every training step that has signature: fn(train_metrics) -> None
 epoch_callback : callable
 Function to run after every epoch that has signature: fn(validation_metrics) -> None
 batch_size : int
 The batch_size to be used for the training and validation datasets. This is ignored if they are
 provided as `DataLoader`.
 warmup_frac : float
 The fraction of iterations that will be spent *increasing* the learning rate under the default
 1cycle policy (with cosine annealing). Value will be automatically clamped values between [0, 0.5]
 retain_best : (str, None)
 **If `validation_dataset` is provided**, which model weights to retain. If 'loss' (default), will
 retain the model at the epoch with the lowest validation loss. If another string, will assume that
 is the metric to monitor for the *highest score*. If None, the final model is used.
 balance_method : (None, str)
 If and how to balance training samples when training. `None` (default) will simply randomly
 sample all training samples equally. 'undersample' will sample each class N_min times
 where N_min is equal to the number of examples in the minority class. 'oversample' will sample
 each class N_max times, where N_max is the number of the majority class.
 loader_kwargs :
 Any remaining keyword arguments will be passed as such to any DataLoaders that are automatically
 constructed. If both training and validation datasets are provided as `DataLoaders`, this will be
 ignored.

 Notes
 -----
 If the datasets above are provided as DN3atasets, automatic optimizations are performed to speed up loading.
 These include setting the number of workers = to the number of CPUs/system threads - 1, and pinning memory for
 rapid CUDA transfer if leveraging the GPU. Unless you are very comfortable with PyTorch, it's probably better
 to not provide your own DataLoader, and let this be done automatically.

 Returns
 -------
 train_log : Dataframe
 Metrics after each iteration of training as a pandas dataframe
 validation_log : Dataframe
 Validation metrics after each epoch of training as a pandas dataframe
 """
 return super(StandardClassification, self).fit(training_dataset, epochs=epochs, step_callback=step_callback,
 epoch_callback=epoch_callback, batch_size=batch_size,
 warmup_frac=warmup_frac, retain_best=retain_best,
 validation_dataset=validation_dataset,
 balance_method=balance_method,
 **loader_kwargs)

 BALANCE_METHODS = ['undersample', 'oversample', 'ldam']
 def _make_dataloader(self, dataset, training=False, **loader_kwargs):
 if isinstance(dataset, DataLoader):
 return dataset

 loader_kwargs = self._dataloader_args(dataset, training=training, **loader_kwargs)

 if training and loader_kwargs.get('sampler', None) is None and loader_kwargs.get('balance_method', None) \
 is not None:
 method = loader_kwargs.pop('balance_method')
 assert method.lower() in self.BALANCE_METHODS
 if not hasattr(dataset, 'get_targets'):
 print("Failed to create dataloader with {} balancing. {} does not support `get_targets()`.".format(
 method, dataset
 ))
 elif method.lower() != 'ldam':
 sampler = balanced_undersampling(dataset) if method.lower() == 'undersample' \
 else balanced_oversampling(dataset)
 # Shuffle is implied by the balanced sampling
 # loader_kwargs['shuffle'] = None
 loader_kwargs['sampler'] = sampler
 else:
 self.loss = create_ldam_loss(dataset)

 if loader_kwargs.get('sampler', None) is not None:
 loader_kwargs['shuffle'] = None

 # Make sure balance method is not passed to DataLoader at this point.
 loader_kwargs.pop('balance_method', None)
 loader_kwargs.pop('_d', None)
 print(loader_kwargs)
 return DataLoader(dataset, **loader_kwargs)


def get_label_balance(dataset):
 """
 Given a dataset, return the proportion of each target class and the counts of each class type

 Parameters
 ----------
 dataset

 Returns
 -------
 sample_weights, counts
 """
 assert hasattr(dataset, 'get_targets')
 labels = dataset.get_targets()
 counts = np.bincount(labels)
 train_weights = 1. / torch.tensor(counts, dtype=torch.float)
 sample_weights = train_weights[labels]
 class_freq = counts/counts.sum()
 if len(counts) < 10:
 tqdm.tqdm.write('Class frequency: {}'.format(' | '.join('{:.2f}'.format(c) for c in class_freq)))
 else:
 tqdm.tqdm.write("Class frequencies range from {:.2e} to {:.2e}".format(class_freq.min(), class_freq.max()))
 return sample_weights, counts


def balanced_undersampling(dataset, replacement=False):
 tqdm.tqdm.write("Undersampling for balanced distribution.")
 sample_weights, counts = get_label_balance(dataset)
 return WeightedRandomSampler(sample_weights, len(counts) * int(counts.min()), replacement=replacement)


def balanced_oversampling(dataset, replacement=True):
 tqdm.tqdm.write("Oversampling for balanced distribution.")
 sample_weights, counts = get_label_balance(dataset)
 return WeightedRandomSampler(sample_weights, len(counts) * int(counts.max()), replacement=replacement)


class LDAMLoss(torch.nn.Module):
 # September 2020 - Originally taken from:
 # October 2020 - Modified to support non-cuda devices and a switch to activate drw

 def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30):
 super(LDAMLoss, self).__init__()
 self._cls_nums = cls_num_list
 m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list))
 m_list = m_list * (max_m / np.max(m_list))
 m_list = torch.FloatTensor(m_list)
 self.m_list = m_list
 assert s > 0
 self.s = s
 self.weight = weight

 def _determine_drw_weights(self, beta=0.9999):
 effective_num = 1.0 - np.power(beta, self._cls_nums)
 per_cls_weights = (1.0 - beta) / np.array(effective_num)
 return torch.from_numpy(per_cls_weights / np.sum(per_cls_weights) * len(self._cls_nums)).float()

 def drw(self, on=True, beta=0.9999):
 self.weight = self._determine_drw_weights(beta=beta) if on else None

 def forward(self, x, target):
 index = torch.zeros_like(x, dtype=torch.uint8)
 index.scatter_(1, target.data.view(-1, 1), 1)

 index_float = index.float()
 batch_m = torch.matmul(self.m_list[None, :].to(index.device), index_float.transpose(0, 1))
 batch_m = batch_m.view((-1, 1))
 x_m = x - batch_m

 output = torch.where(index, x_m, x)
 w = self.weight.to(index.device) if self.weight is not None else None
 return torch.nn.functional.cross_entropy(self.s * output, target, weight=w)


def create_ldam_loss(training_dataset):
 sample_weights, counts = get_label_balance(training_dataset)
 return LDAMLoss(counts)


def _make_span_from_seeds(seeds, span, total=None):
 inds = list()
 for seed in seeds:
 for i in range(seed, seed + span):
 if total is not None and i >= total:
 break
 elif i not in inds:
 inds.append(int(i))
 return np.array(inds)


class BendingCollegeWav2Vec(BaseProcess):
 """
 A more wav2vec 2.0 style of constrastive self-supervision, more inspired-by than exactly like it.
 """
 def __init__(self, encoder, context_fn, mask_rate=0.1, mask_span=6, learning_rate=0.01, temp=0.5,
 permuted_encodings=False, permuted_contexts=False, enc_feat_l2=0.001, multi_gpu=False,
 l2_weight_decay=1e-4, unmasked_negative_frac=0.25, encoder_grad_frac=1.0,
 num_negatives=100, **kwargs):
 self.predict_length = mask_span
 self._enc_downsample = encoder.downsampling_factor
 if multi_gpu:
 encoder = torch.nn.DataParallel(encoder)
 context_fn = torch.nn.DataParallel(context_fn)
 if encoder_grad_frac < 1:
 encoder.register_backward_hook(lambda module, in_grad, out_grad:
 tuple(encoder_grad_frac * ig for ig in in_grad))
 super(BendingCollegeWav2Vec, self).__init__(encoder=encoder, context_fn=context_fn,
 loss_fn=torch.nn.CrossEntropyLoss(), lr=learning_rate,
 l2_weight_decay=l2_weight_decay,
 metrics=dict(Accuracy=self._contrastive_accuracy,
 Mask_pct=self._mask_pct), **kwargs)
 self.best_metric = None
 self.mask_rate = mask_rate
 self.mask_span = mask_span
 self.temp = temp
 self.permuted_encodings = permuted_encodings
 self.permuted_contexts = permuted_contexts
 self.beta = enc_feat_l2
 self.start_token = getattr(context_fn, 'start_token', None)
 self.unmasked_negative_frac = unmasked_negative_frac
 self.num_negatives = num_negatives

 def description(self, sequence_len):
 encoded_samples = self._enc_downsample(sequence_len)
 desc = "{} samples | mask span of {} at a rate of {} => E[masked] ~= {}".format(
 encoded_samples, self.mask_span, self.mask_rate,
 int(encoded_samples * self.mask_rate * self.mask_span))
 return desc

 def _generate_negatives(self, z):
 """Generate negative samples to compare each sequence location against"""
 batch_size, feat, full_len = z.shape
 z_k = z.permute([0, 2, 1]).reshape(-1, feat)
 negative_inds = torch.empty(batch_size, full_len, self.num_negatives).long()
 ind_weights = torch.ones(full_len, full_len) - torch.eye(full_len)
 with torch.no_grad():
 # candidates = torch.arange(full_len).unsqueeze(-1).expand(-1, self.num_negatives).flatten()
 for i in range(batch_size):
 negative_inds[i] = torch.multinomial(ind_weights, self.num_negatives) + i*full_len
 # From wav2vec 2.0 implementation, I don't understand
 # negative_inds[negative_inds >= candidates] += 1

 z_k = z_k[negative_inds.view(-1)].view(batch_size, full_len, self.num_negatives, feat)
 return z_k, negative_inds

 def _calculate_similarity(self, z, c, negatives):
 c = c[..., 1:].permute([0, 2, 1]).unsqueeze(-2)
 z = z.permute([0, 2, 1]).unsqueeze(-2)

 # In case the contextualizer matches exactly, need to avoid divide by zero errors
 negative_in_target = (c == negatives).all(-1)
 targets = torch.cat([z, negatives], dim=-2)

 logits = torch.nn.functional.cosine_similarity(c, targets, dim=-1) / self.temp
 if negative_in_target.any():
 logits[1:][negative_in_target] = float("-inf")

 return logits.view(-1, logits.shape[-1])

 def forward(self, *inputs):
 z = self.encoder(inputs[0])

 if self.permuted_encodings:
 z = z.permute([1, 2, 0])

 unmasked_z = z.clone()

 batch_size, feat, samples = z.shape

 if self._training:
 mask = _make_mask((batch_size, samples), self.mask_rate, samples, self.mask_span)
 else:
 mask = torch.zeros((batch_size, samples), requires_grad=False, dtype=torch.bool)
 half_avg_num_seeds = max(1, int(samples * self.mask_rate * 0.5))
 if samples <= self.mask_span * half_avg_num_seeds:
 raise ValueError("Masking the entire span, pointless.")
 mask[:, _make_span_from_seeds((samples // half_avg_num_seeds) * np.arange(half_avg_num_seeds).astype(int),
 self.mask_span)] = True

 c = self.context_fn(z, mask)

 # Select negative candidates and generate labels for which are correct labels
 negatives, negative_inds = self._generate_negatives(unmasked_z)

 # Prediction -> batch_size x predict_length x predict_length
 logits = self._calculate_similarity(unmasked_z, c, negatives)
 return logits, unmasked_z, mask, c

 @staticmethod
 def _mask_pct(inputs, outputs):
 return outputs[2].float().mean().item()

 @staticmethod
 def _contrastive_accuracy(inputs, outputs):
 logits = outputs[0]
 labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
 return StandardClassification._simple_accuracy([labels], logits)

 def calculate_loss(self, inputs, outputs):
 logits = outputs[0]
 # The 0'th index is the correct position
 labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)

 # Note that loss_fn here integrates the softmax as per the normal classification pipeline (leveraging logsumexp)
 return self.loss_fn(logits, labels) + self.beta * outputs[1].pow(2).mean()


class BendingCollegeClassification(BendingCollegeWav2Vec, StandardClassification):

 def __init__(self, bendr_model, mask_rate=0.1, mask_span=6, learning_rate=0.01, temp=0.5,
 permuted_encodings=False, permuted_contexts=False, enc_feat_l2=0.001, multi_gpu=False,
 l2_weight_decay=1e-4, unmasked_negative_frac=0.25, encoder_grad_frac=1.0,
 num_negatives=100, max_reconstruction_loss_frac=0.2, **kwargs):
 StandardClassification.__init__(self, bendr_model.classifier,
 metrics={
 'Accuracy': lambda i, o: self._simple_accuracy(i, o[-1]),
 'Contrast-Accuracy':self._contrastive_accuracy,
 'Mask-pct':self._mask_pct
 },
 encoder=bendr_model.encoder,
 context_fn=bendr_model.contextualizer)
 if isinstance(bendr_model.encoder, torch.nn.DataParallel):
 encoder = bendr_model.encoder.module
 contextualizer = bendr_model.contextualizer.module
 else:
 encoder = bendr_model.encoder
 contextualizer = bendr_model.contextualizer

 self.predict_length = mask_span
 self._enc_downsample = encoder.downsampling_factor
 if encoder_grad_frac < 1:
 encoder.register_backward_hook(lambda module, in_grad, out_grad:
 tuple(encoder_grad_frac * ig for ig in in_grad))
 self.best_metric = None
 self.mask_rate = mask_rate
 self.mask_span = mask_span
 self.temp = temp
 self.permuted_encodings = permuted_encodings
 self.permuted_contexts = permuted_contexts
 self.beta = enc_feat_l2
 self.start_token = getattr(contextualizer, 'start_token', None)
 self.unmasked_negative_frac = unmasked_negative_frac
 self.num_negatives = num_negatives
 self.r_lambda = max_reconstruction_loss_frac

 def forward(self, *inputs):
 logits, unmasked_z, mask, c = BendingCollegeWav2Vec.forward(self, *inputs)
 prediction = self.classifier(c[..., 0])
 return logits, unmasked_z, mask, prediction

 def calculate_loss(self, inputs, outputs):
 logits = outputs[0]
 # The 0'th index is the correct position
 correct_idx = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)

 mlm_loss = self.loss(logits, correct_idx) + self.beta * outputs[1].pow(2).mean()
 cls_loss = StandardClassification.calculate_loss(self, inputs, outputs[-1])

 return (1 - self.r_lambda) * cls_loss + self.r_lambda * mlm_loss

