#usage
#python3 run_train.py +exp=spec2vec ++exp.runner.device=cuda ++exp.runner.multi_gpu=False ++exp.runner.num_workers=0 +data=timestamp_data +model=debug_finetune_model ++exp.task.name=debug_finetune_task ++exp.criterion.name=debug_finetune_criterion ++exp.runner.total_steps=1000 ++model.frozen_upstream=True ++exp.runner.checkpoint_step=-1
import logging
import numpy as np
import models
from torch.utils import data
import torch
from tasks import register_task
from tasks.base_task import BaseTask
from tasks.batch_utils import baseline_wav_collator
from util.tensorboard_utils import plot_tensorboard_line
from sklearn.metrics import roc_auc_score, f1_score
from datasets import build_dataset
import random

# Module-level logger named after this module.
log = logging.getLogger(__name__)

@register_task(name="nn_decoding_task")
class NNDecodingTask(BaseTask):
    """Task registered as ``nn_decoding_task``.

    Builds train/validation/test subsets of one dataset from explicit index
    lists, builds the model with the dataset's input dimension, and computes
    validation loss, ROC-AUC and F1.
    """

    def __init__(self, cfg):
        """Store the task configuration via ``BaseTask``."""
        super().__init__(cfg)

    def load_datasets(self, data_cfg, cached_seeg_data=None, cached_word_df=None, train_indices=None, val_indices=None, test_indices=None):
        """Build the full dataset and split it into train/val/test subsets.

        Sets ``self.train_set``, ``self.valid_set``, ``self.test_set`` and
        ``self.input_dim`` (taken from the dataset's ``get_input_dim()``).

        Args:
            data_cfg: Dataset configuration forwarded to ``build_dataset``.
            cached_seeg_data: Optional cache forwarded to ``build_dataset``.
            cached_word_df: Optional cache forwarded to ``build_dataset``.
            train_indices: Indices of the training examples (required).
            val_indices: Indices of the validation examples (required).
            test_indices: Indices of the test examples (required).

        Raises:
            ValueError: If any of the three index arguments is ``None``.
        """
        all_dataset = build_dataset(data_cfg, task_cfg=self.cfg, cached_seeg_data=cached_seeg_data, cached_word_df=cached_word_df)

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        required = {
            "train_indices": train_indices,
            "val_indices": val_indices,
            "test_indices": test_indices,
        }
        missing = [name for name, value in required.items() if value is None]
        if missing:
            raise ValueError(
                "load_datasets requires explicit split indices; missing: "
                + ", ".join(missing)
            )

        self.train_set = torch.utils.data.Subset(all_dataset, train_indices)
        self.valid_set = torch.utils.data.Subset(all_dataset, val_indices)
        self.test_set = torch.utils.data.Subset(all_dataset, test_indices)
        self.input_dim = all_dataset.get_input_dim()

    def build_model(self, cfg):
        """Build the model; requires ``load_datasets`` to have set ``self.input_dim``."""
        return models.build_model(cfg, input_dim=self.input_dim)

    @classmethod
    def setup_task(cls, cfg):
        """Alternate constructor: return a task instance built from ``cfg``."""
        return cls(cfg)

    def get_valid_outs(self, model, valid_loader, criterion, device):
        """Evaluate ``model`` on ``valid_loader`` and return aggregate metrics.

        The model is put in eval mode and is NOT switched back to train mode
        here; the caller is responsible for that.

        Args:
            model: Model to evaluate.
            valid_loader: Iterable of batches; each batch is a dict with at
                least ``"input"`` and ``"labels"`` entries.
            criterion: Callable invoked as
                ``criterion(model, batch, device, return_predicts=True)``;
                its second return value must provide ``"predicts"`` and
                ``"loss"``.
            device: Device the batch input is moved to.

        Returns:
            dict with ``"loss"`` (mean over batches), ``"roc_auc"`` and ``"f1"``.

        Raises:
            ValueError: If ``valid_loader`` yields no batches.
        """
        model.eval()
        all_outs = {"loss": 0}
        predicts, labels = [], []
        with torch.no_grad():
            for batch in valid_loader:
                batch["input"] = batch["input"].to(device)
                _, valid_outs = criterion(model, batch, device, return_predicts=True)

                predicts.append(valid_outs["predicts"])
                labels.append(batch["labels"])
                all_outs["loss"] += valid_outs["loss"]

        if not predicts:
            # Fail with a clear message instead of an opaque concatenate error.
            raise ValueError("get_valid_outs received an empty validation loader")

        # Flatten the per-batch label containers into one flat array.
        labels = np.array([x for y in labels for x in y])
        # A batch of size one may come back as a 0-d value; promote it to 1-d
        # so every entry can be concatenated.
        predicts = [np.array([p]) if len(p.shape) == 0 else p for p in predicts]
        predicts = np.concatenate(predicts)

        all_outs["loss"] /= len(predicts_batches := valid_loader) if False else len(valid_loader)
        all_outs["roc_auc"] = roc_auc_score(labels, predicts)
        # Hard labels via rounding; presumably predicts are scores in [0, 1] - TODO confirm.
        all_outs["f1"] = f1_score(labels, np.round(predicts))
        return all_outs

    def get_batch_iterator(self, dataset, batch_size, shuffle=True, **kwargs):
        """Return a ``DataLoader`` over ``dataset`` using ``baseline_wav_collator``.

        Args:
            dataset: Dataset to iterate over.
            batch_size: Number of examples per batch.
            shuffle: Whether to reshuffle the data every epoch. Ignored when a
                ``sampler`` or ``batch_sampler`` is given in ``kwargs``, since
                the sampler then owns the ordering.
            **kwargs: Extra keyword arguments forwarded to ``DataLoader``.
        """
        # DataLoader raises if shuffle=True is combined with a custom sampler.
        has_sampler = kwargs.get("sampler") is not None or kwargs.get("batch_sampler") is not None
        if has_sampler:
            shuffle = False
        return data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=baseline_wav_collator, **kwargs)

    def output_logs(self, train_logging_outs, val_logging_outs, writer, global_step):
        """Intentionally a no-op: this task emits no extra logs."""
        pass

