import os
from datetime import datetime
from typing import List

from pytz import timezone

# Module-level pytz timezone for Europe/Berlin.
# NOTE(review): nothing in this chunk reads this name (get_exp_name creates
# its own timezone object); it may be used by importers — confirm before removing.
tz = timezone("Europe/Berlin")

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter


class DummyClass(object):
    """
    Null object: any ordinary attribute access yields a no-op callable.

    Calling ``obj.anything(...)`` accepts arbitrary arguments and returns
    None. NOTE(review): presumably used as a stand-in for a writer/logger
    when logging is disabled — confirm against callers.

    Dunder names (``__xxx__``) are deliberately NOT faked. This keeps Python
    protocol probes (``copy.deepcopy``'s ``__deepcopy__`` lookup,
    ``hasattr(obj, '__setstate__')``, ...) from mistaking the no-op for a
    real implementation. Previously that made ``copy.deepcopy`` return None.
    """

    @staticmethod
    def do_nothing(*args, **kwargs):
        """Accept any positional/keyword arguments and return None."""
        return None

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {name!r}')
        return self.do_nothing


class RSquared:
    """
    Per-factor coefficient of determination, R^2 = 1 - MSE / Var(labels).

    The per-factor label variance is computed once from the reference labels
    (population variance, i.e. mean of squared deviations along axis 0) and
    reused for every call.
    """

    def __init__(self, normalized_labels: np.ndarray, device: str = 'cpu'):
        """
        Args:
            normalized_labels: Array-like of reference labels; factors are
                assumed to lie along the last axis and samples along axis 0
                (TODO confirm against callers). Lists and other array-likes
                are accepted via ``np.asarray``.
            device: Torch device on which the variance tensor is stored.
        """
        labels = np.asarray(normalized_labels)
        centered = labels - labels.mean(axis=0, keepdims=True)
        variance_per_factor = (centered ** 2).mean(axis=0)
        self.variance_per_factor = torch.tensor(variance_per_factor).to(device)

    def __call__(self, predictions: torch.Tensor,
                 targets: torch.Tensor) -> torch.Tensor:
        """
        Return R^2 for each factor (column) of a 2-D prediction batch.

        A factor whose reference variance is zero yields ``-inf`` (or ``nan``
        when its MSE is also zero), because of the division below.

        Raises:
            AssertionError: if shapes differ or the inputs are not 2-D.
        """
        assert predictions.shape == targets.shape, (
            f'shape mismatch: {tuple(predictions.shape)} vs '
            f'{tuple(targets.shape)}')
        assert len(targets.shape) == 2, 'expected 2-D (batch, factor) inputs'
        mse_loss_per_factor = (predictions - targets).pow(2).mean(dim=0)
        return 1 - mse_loss_per_factor / self.variance_per_factor


def collect_per_factor(per_factors, epoch: int, name: str, mode: str,
                       factor_names: List[str], writer,
                       aggregate_fct=torch.mean):
    """
    Average per-factor values over their list and log them to ``writer``.

    ``per_factors`` is stacked along a new leading dimension and averaged
    over it. Each per-factor value is then paired with a name and written
    under ``'{mode}_{name}/{factor_name}'``. Finally ``aggregate_fct`` of the
    averaged values is written under ``'{mode}/{name}'``. Pairing stops at
    the shorter of the averaged values and ``factor_names``.

    NOTE(review): each entry of ``per_factors`` is presumably a 1-D tensor
    with one value per factor (e.g. one entry per batch) — confirm.
    """
    averaged = torch.stack(per_factors, dim=0).mean(dim=0)
    for factor_name, value in zip(factor_names, averaged):
        writer.add_scalar(f'{mode}_{name}/{factor_name}', value, epoch)
    summary = aggregate_fct(averaged)
    writer.add_scalar(f'{mode}/{name}', summary, epoch)


def save_checkpoint(model, optimizer, args, epoch: int, save_folder: str,
                    name: str,
                    ):
    """
    Serialize model/optimizer state and run metadata with ``torch.save``.

    Writes to ``os.path.join(save_folder, name)`` in binary mode, overwriting
    any existing file. The payload is a dict with keys ``'model'`` and
    ``'optimizer'`` (each object's ``state_dict()``), ``'epoch'``, and
    ``'args'`` (stored exactly as given).
    """
    payload = dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        epoch=epoch,
        args=args,
    )
    destination = os.path.join(save_folder, name)
    with open(destination, 'wb') as stream:
        torch.save(payload, stream)


class Tracker:
    """
    Buffer scalar metrics across steps and flush their means to a writer.

    :meth:`track` appends each value to a per-name list. :meth:`write` logs
    ``np.mean`` of every list via ``writer.add_scalar`` and then clears all
    buffers.
    """

    def __init__(self, writer: SummaryWriter):
        # Target for add_scalar calls made by write().
        self.writer = writer
        # metric name -> list of values collected since the last write().
        self.accumulated_infos = {}

    def track(self, infos):
        """Append every value in the ``infos`` mapping to the buffer of its key."""
        buffers = self.accumulated_infos
        for metric, value in infos.items():
            buffers.setdefault(metric, []).append(value)

    def write(self, epoch, prefix='train/'):
        """
        Log the mean of each buffered metric at step ``epoch``, then reset.

        Each metric is logged under the tag ``prefix + name``.
        """
        for metric, values in self.accumulated_infos.items():
            self.writer.add_scalar(prefix + metric, np.mean(values), epoch)
        self.accumulated_infos = {}


def _normalize_bool_strings(values):
    """Replace 'true'/'True'/'false'/'False' string values with bools, in place."""
    for key, val in values.items():
        # Only strings can match; skipping other types avoids an ambiguous
        # elementwise ``==`` on array-valued arguments.
        if not isinstance(val, str):
            continue
        if val in ('false', 'False'):
            values[key] = False
        elif val in ('true', 'True'):
            values[key] = True


def _format_exp_value(key, val):
    """Render one differing argument as a compact ``<key><value>-`` token."""
    label = key[:15]
    if isinstance(val, float):
        text = f'{val:.3f}'
        # Three decimals collapse small non-zero floats (e.g. 1e-4) to
        # '0.000', making distinct settings look identical. Fall back to
        # three significant digits in that case.
        if val != 0 and float(text) == 0:
            text = f'{val:.3g}'
    elif isinstance(val, str):
        text = val[:5]
    else:
        text = str(val)
    return f'{label}{text}-'


def get_exp_name(args_old, args_new, ignoring=('dataset', 'modification',
                                               'name', 'save_path',
                                               'model',
                                               'skip_dislib_eval', 'seed',
                                               'learning_rate',
                                               'writer')):
    """
    Build a TensorBoard experiment name from the arguments that differ from
    their defaults.

    Every non-ignored argument whose parsed value differs from its default is
    encoded into the name. A Europe/Berlin timestamp is appended.

    Args:
        args_old: Mapping of argument name -> default value. Its keys (in
            iteration order) decide which arguments are compared.
        args_new: Object whose ``__dict__`` holds the parsed arguments
            (presumably an ``argparse.Namespace``). As a side effect, string
            values 'true'/'True'/'false'/'False' are converted to bools
            IN PLACE on this object.
        ignoring: Argument names that are never encoded into the name.

    Returns:
        ``'<name>_<key><value>-...--YYYY-mm-dd-HH-MM-SS'``. Keys are cut to 15
        characters and string values to 5. Floats are shown with 3 decimals,
        or 3 significant digits when 3 decimals would round a non-zero value
        to zero.

    Raises:
        KeyError: if ``args_new`` lacks ``'name'`` or a non-ignored key of
            ``args_old``.
    """
    new_values = args_new.__dict__
    _normalize_bool_strings(new_values)

    parts = [new_values['name'] + '_']
    for key in args_old:
        if key in ignoring:
            continue
        if args_old[key] != new_values[key]:
            parts.append(_format_exp_value(key, new_values[key]))

    berlin = timezone("Europe/Berlin")
    timestamp = datetime.now(tz=berlin).strftime("%Y-%m-%d-%H-%M-%S")
    return ''.join(parts) + f'--{timestamp}'
