import os
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import pytorch_lightning as pl
import torch
from sklearn import metrics
from torch import Tensor, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from puupl.lib.architectures import get_architecture
from puupl.lib.data import get_dataset
from puupl.lib.losses import ECELoss, get_loss
from puupl.lib.postprocessing import TemperatureScaler, get_postprocessors
from puupl.lib.pseudolabel import get_pseudolabeler

# Alias for a mapping from field name to tensor; used below to annotate the
# per-step outputs that are collected and aggregated at the end of an epoch.
TensorDict = Dict[str, Tensor]


class Experiment(pl.LightningModule):
    # pylint: disable=arguments-differ

    def __init__(self, config: Dict[str, Any], run_name: str, experiment_name: str) -> None:
        """
        Build every component of the experiment from ``config``.

        The ``architecture``, ``loss``, ``exp_params`` and ``dataset`` sections of
        ``config`` are required, the ``pseudolabeler`` section is optional.
        ``run_name`` and ``experiment_name`` are only stored on the instance.
        """
        super().__init__()

        # lightning restores checkpoints from the recorded constructor arguments
        self.save_hyperparameters()

        # components created by the project factories
        self.architecture = get_architecture(config['architecture']).float()
        self.loss = get_loss(config['loss'])
        self.pseudolabeler = get_pseudolabeler(config.get('pseudolabeler'))

        exp_params = config['exp_params']
        self.params = exp_params
        self.run_name, self.experiment_name = run_name, experiment_name

        # manual epoch counter, used as the step of every logged metric
        self.epoch = 0

        datasets = get_dataset(config['dataset'])
        self.train_data, self.val_data, self.test_data = datasets

        # bookkeeping for pseudo-labeling and model selection
        self.should_reset_weights = False
        self._has_pseudolabels = False
        self._best_baseline_score: Optional[float] = None
        self._best_pl_score: Optional[float] = None

        self.post_processors = get_postprocessors(exp_params.get('post_processing'))

        # legacy flag, only honored for backward compatibility
        if exp_params.get('temperature_scale'):
            self.post_processors.append(TemperatureScaler())

    # --------------------------------------------------------------------------------------
    # --- PSEUDO-LABELING LOGIC

    def on_train_epoch_start(self) -> None:
        """
        Run a pseudo-labeling round at the start of the epoch when one is due.

        A round is due when a pseudo-labeler is configured, ``pseudolabel_every``
        is set, the current epoch is a multiple of it, and the warm-up plus one
        full period has elapsed. When no round is due, mixup is (re-)initialized
        and nothing else happens.
        """
        tqdm.write(f'starting epoch {self.epoch}')

        # the period is only looked up when a pseudo-labeler exists
        if self.pseudolabeler is None:
            round_is_due = False
        else:
            period = self.params['pseudolabel_every']
            round_is_due = not (
                period is None
                or self.epoch % period != 0
                or self.epoch < self.params.get('warmup', 0) + period
            )

        if not round_is_due:
            self.train_data.initialize_and_enable_mixup()
            return

        self.train_data.disable_mixup()
        with torch.no_grad():
            # predictions on the unshuffled training set, then post-processing
            raw_predictions = self.architecture.sample_predictions(
                self.train_dataloader(shuffle=False)
            )
            yhat = self._apply_post_processors(raw_predictions).sigmoid().float()

            # per-sample uncertainty and prediction (averaged over repetitions, if any)
            unc = self._compute_uncertainty(yhat)
            preds = yhat if len(yhat.shape) <= 1 else yhat.mean(dim=0)

        # the pseudo-labeler receives the label, positive and pseudo-label tensors
        dataset = self.train_gen.dataset
        self.pseudolabeler.pseudolabel(
            dataset.y,  # type: ignore
            preds, unc,
            dataset.p,  # type: ignore
            dataset.l,  # type: ignore
        )

        self._has_pseudolabels = torch.any(dataset.l).item()  # type: ignore

        # keep track of how many pseudo-labels are currently assigned
        n_pseudo_labels = int(torch.sum(dataset.l))  # type: ignore
        self.logger.experiment.log_metric(
            key='pseudo_labels',
            value=n_pseudo_labels,
            step=self.epoch,
            run_id=self.logger.run_id
        )

        # start over from the initial weights, optimizer and scheduler states
        if self.params.get('reset_weights_after_pl', True):
            self.architecture.reset_model()
            for optimizer, initial_state in zip(self.optims, self.optims_initial_states):
                optimizer.load_state_dict(initial_state)
            for scheduler, initial_state in zip(self.scheds, self.scheds_initial_states):
                scheduler.load_state_dict(initial_state)
            self.should_reset_weights = False

        # optionally use the uncertainty as per-sample weight
        weighting = self.params.get('sample_weights', 'uniform')
        if weighting == 'uncertainty':
            dataset.w = unc  # type: ignore
        elif weighting != 'uniform':
            raise ValueError('unknown sample weighting scheme')

        self.train_data.initialize_and_enable_mixup()

    def _apply_post_processors(self, yhat_train: Tensor) -> Tensor:
        if not self.post_processors:
            return yhat_train

        if self.params.get('fit_postprocessors_on_labeled_validation', True):
            # use true labels of the validation set to fit the post-processors
            yhats = self.architecture.sample_predictions(self.val_gen)
            ys = self.val_gen.dataset.y  # type: ignore
            ts = self.val_gen.dataset.t.float()  # type: ignore
        else:
            # use positives and assigned pseudo-labels to fit the post-processors
            mask = (
                self.train_gen.dataset.p | self.train_gen.dataset.l  # type: ignore
            )

            # note: we pretend that the pseudo-labels are the true labels
            ts = ys = self.train_gen.dataset.y[mask]  # type: ignore

            if not torch.any(ys < 0.5):
                tqdm.write('No negatives to fit post-processors, skipping...')
                return yhat_train

            yhats = self.architecture.sample_predictions(
                self.train_gen, 'cpu'
            )[..., mask]

        yhs = yhats.mean(dim=0) if len(yhats.shape) > 1 else yhats
        for proc in self.post_processors:
            proc.fit(yhs, ys, ts)
            yhat_train = proc.scale(yhat_train)

        return yhat_train

    def _compute_uncertainty(self, yhat: Tensor) -> Tensor:
        # samples on rows, repetitions on columns, mean not necessary for a single rep
        mean = lambda x: torch.mean(x, dim=0) if len(x.shape) > 1 else x
        entropy = lambda x: -x * torch.log(x) - (1 - x) * torch.log(1 - x)

        # following arxiv:1910.09457 eq. 21, 22 and 23
        if 'uncertainty_type' not in self.params:
            raise ValueError('uncertainty type must be defined for pseudo-labeling')
        if self.params['uncertainty_type'] in ('total', 'entropy_of_mean', 'eom'):
            unc = entropy(mean(yhat))
        elif self.params['uncertainty_type'] in ('aleatoric', 'mean_of_entropy', 'moe'):
            unc = mean(entropy(yhat))
        elif self.params['uncertainty_type'] in ('epistemic', 'eom_moe'):
            unc = entropy(mean(yhat)) - mean(entropy(yhat))
            if len(yhat.shape) == 1:
                raise RuntimeError('cannot compute total uncertainty with a single sample')
        elif self.params['uncertainty_type'] in ('none', 'predictions'):
            unc = 0.5 - torch.abs(mean(yhat) - 0.5)
        else:
            raise ValueError(f'unknown uncertainty type {self.params["uncertainty_type"]}')

        unc = torch.nan_to_num(unc, 0.0)  # type: ignore
        assert torch.all(unc >= -5e-6), \
            f'uncertainty must be non-negative, minimim: {unc.min().item()}'
        return unc

    # --------------------------------------------------------------------------------------
    # --- TRAINING AND VALIDATION LOGIC

    def forward(self, x: Tensor) -> Tensor:  # type: ignore
        preds = self.architecture(x)
        if len(preds.shape) > 1:
            # we assume samples are on the last dimension and average out
            # all other dimensions to get per-sample predictions
            preds = torch.mean(preds, dim=list(range(0, len(preds.shape) - 1)))
        return preds

    def _step(self, batch: Dict[str, Tensor], batch_idx: int) -> TensorDict:
        phat = self.forward(batch['x'])
        res = self.loss(phat, **batch)
        res['phat'] = phat
        for k, v in batch.items():
            if k != 'x':
                res[k] = v
        return res

    # Training, validation and test batches are all processed by `_step`;
    # the aliases below expose it under the step name of each phase.
    training_step = _step  # type: ignore

    validation_step = _step  # type: ignore

    test_step = _step  # type: ignore

    def training_epoch_end(self, outputs: List[TensorDict]) -> None:  # type: ignore
        """Evaluate the collected training-step outputs as the ``train`` part."""
        self.evaluate(outputs, part='train')

    def validation_epoch_end(self, outputs: List[TensorDict]) -> None:  # type: ignore
        """
        Score the validation outputs, track the best score and advance the epoch.

        Once ``self.epoch`` is positive, the best validation score is tracked
        separately for the phases without and with pseudo-labels and logged at
        every epoch. Whenever the score of the current phase improves, the test
        set is evaluated as well. The epoch counter is incremented on every call.
        """
        def _update_best(current_scores: Dict[str, float], best_score: Optional[float]
                         ) -> Tuple[bool, Optional[float]]:
            """
            If necessary, updates the best score with the current validation scores,
            where the score is either the accuracy on the true labels or the nnPU loss.
            Return whether the score was updated and the new best score.
            """

            validate_on_true_labels = self.params.get('validate_on_true_labels', True)
            if validate_on_true_labels:
                new_score = current_scores['val_accuracy']
            else:
                new_score = current_scores['val_loss']
                # read the scores that were passed in rather than relying on the
                # variable of the enclosing method
                if current_scores['val_pu_correction']:
                    # if we use the nnPU loss as criterion, do not update scores
                    # when the non-negativity correction was applied
                    return False, best_score

            if best_score is None:
                return True, new_score

            # accuracy is maximized, the loss is minimized
            if validate_on_true_labels:
                improved = new_score > best_score
            else:
                improved = new_score < best_score

            return improved, new_score if improved else best_score

        scores = self.evaluate(outputs, part='val')

        if self.epoch > 0:
            # the best score is tracked separately before and after the first
            # pseudo-labels are assigned
            if self._has_pseudolabels:
                evaluate_on_test, self._best_pl_score = _update_best(
                    scores, self._best_pl_score
                )
                best_score = self._best_pl_score
                metric_key = 'best_pl_validation_score'
                announcement = 'New best PL validation score'
            else:
                evaluate_on_test, self._best_baseline_score = _update_best(
                    scores, self._best_baseline_score
                )
                best_score = self._best_baseline_score
                metric_key = 'best_baseline_validation_score'
                announcement = 'New best baseline score'

            if evaluate_on_test:
                tqdm.write(f'{announcement} at epoch {self.epoch}')
            if best_score is not None:
                self.logger.experiment.log_metric(
                    key=metric_key,
                    value=float(best_score),
                    step=self.epoch, run_id=self.logger.run_id
                )

            if evaluate_on_test:
                # score the test set with the model that just improved on validation
                self.architecture.eval()
                test_outputs = [
                    {k: v.cpu() for k, v in self.test_step({  # type: ignore
                        k: v.to(self.architecture.device)
                        for k, v in batch.items()
                    }, i).items()}
                    for i, batch in enumerate(self.test_dataloader())
                ]
                self.evaluate(test_outputs, part='test')
                self.architecture.train()

        # track epochs as required by mlflow stepwise logger
        self.epoch += 1

    def test_epoch_end(self, outputs: Sequence[Dict[str, Tensor]]) -> None:  # type: ignore
        """
        Evaluate the test outputs, log the run identifiers and export the score.

        When the ``hyperband_output_dir`` parameter is set, the best validation
        score obtained with pseudo-labels is written to the file ``result`` in
        that directory: ``nan`` if there is no such score, and with flipped sign
        if the score is an accuracy.
        """
        self.evaluate(outputs, part='test')

        experiment, run_id = self.logger.experiment, self.logger.run_id
        experiment.log_param(key='run_name', value=self.run_name, run_id=run_id)
        experiment.log_param(
            key='experiment_name', value=self.experiment_name, run_id=run_id)

        # save accuracy for hyperband
        hb_outdir = self.params.get('hyperband_output_dir')
        if hb_outdir is None:
            return

        result_path = os.path.join(hb_outdir, 'result')
        with open(result_path, 'w') as result_file:
            best = self._best_pl_score
            if best is None:
                result = float('nan')
            elif self.params.get('validate_on_true_labels', True):
                # negative accuracy as hyperband minimizes
                result = -1 * best
            else:
                result = best
            result_file.write(str(result))

    def evaluate(self, outputs: Sequence[TensorDict], part: str) -> Dict[str, float]:
        """
        Compute, print and log the metrics for the outputs of one epoch.

        ``outputs`` are the dictionaries returned by the step function, one per
        batch; their entries are concatenated before computing the metrics.
        ``part`` (e.g. ``train``, ``val`` or ``test``) is used as prefix for the
        name of every metric. For the training part the metrics are computed both
        against the current labels ``y`` and the true labels ``t``, otherwise only
        against the true labels. The values of the loss on the whole part are
        added to the results.

        Printing and logging only happen once ``self.epoch`` is positive; the
        results are returned in any case.
        """
        def get_metrics(labels: Tensor, logits: Tensor) -> Dict[str, float]:
            # logits above zero count as positive predictions
            return {
                'accuracy': metrics.accuracy_score(labels, logits > 0.0),
                'auc': metrics.roc_auc_score(labels, logits),
                'aps': metrics.average_precision_score(labels, logits),
                'f1': metrics.f1_score(labels, logits > 0.0),
                'ece': ECELoss(order=1).forward(logits, labels).item(),
                'brier': ECELoss(order=2).forward(logits, labels).item(),
            }

        # gather the per-batch tensors by key
        tls: Dict[str, List[Tensor]] = {}
        for out in outputs:
            for k, v in out.items():
                # zero-dimensional tensors cannot be concatenated, so they get
                # a dimension of size one first
                tls.setdefault(k, []).append(v.cpu() if v.shape else v.view(1).cpu())
        ts = {k: torch.cat(v).detach() for k, v in tls.items()}

        if part == 'train':
            results = {
                # evaluate on pseudo-labels
                **{f'{part}_pseudo_{k}': v for k, v in get_metrics(
                    ts['y'] >= 0.5, ts['phat']).items()},

                # evaluate on true labels
                **{f'{part}_true_{k}': v for k, v in get_metrics(
                    ts['t'] >= 0.5, ts['phat']).items()},
            }
        else:
            # evaluate on true labels
            results = {f'{part}_{k}': v for k, v in get_metrics(
                ts['t'] >= 0.5, ts['phat']).items()}

        # log the currently used loss
        for key, value in self.loss(ts['phat'], ts['y'], ts['p'], ts['l'], ts['w']).items():
            results[f'{part}_{key}'] = value.item()

        if self.epoch > 0:
            # skip lightning's initial test run

            # every key starts with f'{part}_'; slice that prefix off instead of
            # replacing `part`, which would also alter names containing it elsewhere
            prefix_length = len(part) + 1
            tqdm.write(f'{part} Metrics: ' + ' - '.join(
                f'{k[prefix_length:]}: {v:.4f}' for k, v in results.items()
            ))

            for key, value in results.items():
                self.logger.experiment.log_metric(
                    key=key, value=float(value), step=self.epoch, run_id=self.logger.run_id
                )

        return results

    # --------------------------------------------------------------------------------------
    # --- CONFIGURATION LOGIC

    def configure_optimizers(self) -> Union[
            Sequence[optim.Optimizer],
            Tuple[Sequence[optim.Optimizer],
                  Sequence[optim.lr_scheduler._LRScheduler]]
    ]:
        self.optims = [optim.Adam(
            self.architecture.parameters(),
            lr=self.params['learning_rate'],
            weight_decay=self.params['weight_decay']
        )]
        self.optims_initial_states = [opt.state_dict() for opt in self.optims]

        if self.params.get('scheduler_gamma') is not None:
            self.scheds = [optim.lr_scheduler.ExponentialLR(
                self.optims[0],
                gamma=self.params['scheduler_gamma']
            )]
            self.scheds_initial_states = [sched.state_dict() for sched in self.scheds]
            return self.optims, self.scheds
        else:
            self.scheds = []
            self.scheds_initial_states = []
            return self.optims

    def train_dataloader(self, shuffle: bool = True) -> DataLoader:
        self.train_gen = DataLoader(dataset=self.train_data,
                                    batch_size=self.params['batch_size'],
                                    collate_fn=self.train_data.supervised_collate_fn,
                                    num_workers=self.params.get('dataloader_workers', 0),
                                    shuffle=shuffle)
        return self.train_gen

    def val_dataloader(self) -> DataLoader:
        self.val_gen = DataLoader(dataset=self.val_data,
                                  batch_size=self.params['batch_size'],
                                  collate_fn=self.val_data.supervised_collate_fn,
                                  num_workers=self.params.get('dataloader_workers', 0),
                                  shuffle=False)
        return self.val_gen

    def test_dataloader(self) -> DataLoader:
        self.test_gen = DataLoader(dataset=self.test_data,
                                   batch_size=self.params['batch_size'],
                                   collate_fn=self.test_data.supervised_collate_fn,
                                   num_workers=self.params.get('dataloader_workers', 0),
                                   shuffle=False)
        return self.test_gen
