import torch, time
import numpy as np
from algorithms.algorithm import Algorithm
from utils import update_average
from scheduler import step_scheduler
from wilds.common.utils import get_counts, numel


class GroupAlgorithm(Algorithm):
    """
    Parent class for algorithms with group-wise logging.
    Also handles schedulers.

    The class keeps a running, count-weighted average of every logged
    statistic in ``self.log_dict``. Keys that start with ``self.group_prefix``
    hold one value per group; all other keys hold a single number.

    NOTE(review): ``self.log_dict``, ``self._has_log`` and
    ``self.sanitize_dict`` are not defined in this class; they are presumably
    provided by ``Algorithm`` -- confirm against the base class.
    """

    def __init__(
        self,
        device,
        grouper,
        logged_metrics,
        logged_fields,
        schedulers,
        scheduler_metric_names,
        no_group_logging,
        **kwargs,
    ):
        """
        Args:
            - device: torch device
            - grouper (Grouper): defines groups for which we compute/log stats for
            - logged_metrics (list of Metric): metrics computed on every batch
              from results["y_pred"] / results["y_true"]
            - logged_fields (list of str): keys of the results dictionary that
              are copied into the log as-is
            - schedulers (list): schedulers to step; entries may be None
            - scheduler_metric_names (list of str): metric name used by the
              scheduler at the same position in `schedulers` (may be None)
            - no_group_logging (bool): if True, all group-wise statistics are skipped
            - **kwargs: accepted for call-site compatibility and ignored here
        """
        super().__init__(device)
        self.grouper = grouper
        self.group_prefix = "group_"
        self.count_field = "count"
        self.group_count_field = f"{self.group_prefix}{self.count_field}"

        self.logged_metrics = logged_metrics
        self.logged_fields = logged_fields

        self.schedulers = schedulers
        self.scheduler_metric_names = scheduler_metric_names
        self.no_group_logging = no_group_logging

    def update_log(self, results):
        """
        Updates the internal log, Algorithm.log_dict
        Args:
            - results (dictionary): must contain "y_pred" and "y_true" when
              metrics are logged, "g" when group logging is enabled, and every
              key listed in self.logged_fields
        """
        results = self.sanitize_dict(results, to_out_device=False)
        # check all the fields exist
        for field in self.logged_fields:
            assert field in results, f"field {field} missing"

        log_groups = not self.no_group_logging

        # compute statistics for the current batch
        batch_log = {}
        group_counts = None
        with torch.no_grad():
            for m in self.logged_metrics:
                if log_groups:
                    # the third return value (worst-group metric) is not logged
                    group_metrics, group_counts, _ = m.compute_group_wise(
                        results["y_pred"],
                        results["y_true"],
                        results["g"],
                        self.grouper.n_groups,
                        return_dict=False,
                    )
                    batch_log[f"{self.group_prefix}{m.name}"] = group_metrics
                batch_log[m.agg_metric_field] = m.compute(
                    results["y_pred"], results["y_true"], return_dict=False
                ).item()
            if log_groups and group_counts is None:
                # No metric produced group counts (logged_metrics is empty),
                # so derive them from the group ids directly; otherwise the
                # bookkeeping below would hit an unbound `group_counts`.
                group_counts = get_counts(results["g"], self.grouper.n_groups)
            count = numel(results["y_true"])

        # transfer other statistics in the results dictionary
        for field in self.logged_fields:
            if field.startswith(self.group_prefix) and self.no_group_logging:
                continue
            v = results[field]
            if isinstance(v, torch.Tensor) and v.numel() == 1:
                batch_log[field] = v.item()
            else:
                if isinstance(v, torch.Tensor):
                    assert (
                        v.numel() == self.grouper.n_groups
                    ), "Current implementation deals only with group-wise statistics or a single-number statistic"
                    assert field.startswith(self.group_prefix)
                batch_log[field] = v

        # update the log dict with the current batch
        if not self._has_log:
            # since it is the first log entry, just save the current log
            self.log_dict = batch_log
            if log_groups:
                self.log_dict[self.group_count_field] = group_counts
            self.log_dict[self.count_field] = count
        else:
            # take a running average across batches otherwise
            for k, v in batch_log.items():
                if k.startswith(self.group_prefix):
                    if self.no_group_logging:
                        continue
                    # group-wise values are weighted by per-group counts
                    self.log_dict[k] = update_average(
                        self.log_dict[k],
                        self.log_dict[self.group_count_field],
                        v,
                        group_counts,
                    )
                else:
                    # aggregate values are weighted by the total example count
                    self.log_dict[k] = update_average(
                        self.log_dict[k], self.log_dict[self.count_field], v, count
                    )
            # counts are bumped only after all averages used the old totals
            if log_groups:
                self.log_dict[self.group_count_field] += group_counts
            self.log_dict[self.count_field] += count
        self._has_log = True

    def _group_count_key(self, g):
        """
        Returns the sanitized-log key under which the count of group g is stored.
        Uses self.loss.group_count_field(g) when the instance has a `loss`
        attribute (this class does not define one itself), and otherwise falls
        back to "{count_field}_group:{g}", the naming used for every other
        group-wise field.
        """
        loss = getattr(self, "loss", None)
        if loss is not None:
            return loss.group_count_field(g)
        return f"{self.count_field}_group:{g}"

    def get_log(self):
        """
        Sanitizes the internal log (Algorithm.log_dict) and outputs it.
        Group-wise entries are flattened into one scalar entry per group;
        values of groups with a zero count are reported as NaN (the count
        entries themselves are always reported).
        Output:
            - sanitized_log (dictionary): maps names to python scalars
        """
        sanitized_log = {}
        for k, v in self.log_dict.items():
            if not k.startswith(self.group_prefix):
                assert not isinstance(v, torch.Tensor)
                sanitized_log[k] = v
                continue

            field = k[len(self.group_prefix) :]
            is_count_entry = k == self.group_count_field
            for g in range(self.grouper.n_groups):
                # set relevant values to NaN depending on the group count
                count = self.log_dict[self.group_count_field][g].item()
                if count == 0 and not is_count_entry:
                    outval = np.nan
                else:
                    outval = v[g].item()
                # add to dictionary with an appropriate name
                # in practice, it is saving each value as {field}_group:{g}
                added = False
                for m in self.logged_metrics:
                    if field == m.name:
                        sanitized_log[m.group_metric_field(g)] = outval
                        added = True
                if is_count_entry:
                    sanitized_log[self._group_count_key(g)] = outval
                elif not added:
                    sanitized_log[f"{field}_group:{g}"] = outval
        return sanitized_log

    def step_schedulers(self, is_epoch, metrics=None, log_access=False):
        """
        Steps every scheduler whose stepping frequency matches this call.
        If a scheduler is updated based on a metric, the metric is first looked
        up in `metrics` and then in the internal log (self.log_dict) if
        log_access is True.
        Args:
            - is_epoch (bool): True for the end-of-epoch call, False for the
                               per-batch call; schedulers with step_every_batch
                               set are stepped only on per-batch calls and all
                               others only on end-of-epoch calls
            - metrics (dictionary): metrics available for scheduler updates
                                    (defaults to an empty dictionary)
            - log_access (bool): whether the scheduler metric can be fetched from internal log
                                 (self.log_dict)
        """
        if metrics is None:
            metrics = {}
        for scheduler, metric_name in zip(self.schedulers, self.scheduler_metric_names):
            if scheduler is None:
                continue
            # skip schedulers whose frequency does not match this call
            if bool(is_epoch) == bool(scheduler.step_every_batch):
                continue
            self._step_specific_scheduler(
                scheduler=scheduler,
                metric_name=metric_name,
                metrics=metrics,
                log_access=log_access,
            )

    def _step_specific_scheduler(self, scheduler, metric_name, metrics, log_access):
        """
        Helper function for updating scheduler
        Args:
            - scheduler: scheduler to update
            - metric_name (str): name of the metric (key in metrics or log dictionary) to use for updates
            - metrics (dict): a dictionary of metrics that can be used for scheduler updates
            - log_access (bool): whether metrics from self.get_log() can be used to update schedulers
        Raises:
            - ValueError: if the scheduler needs a metric that is found neither
                          in `metrics` nor (when log_access is True) in the log
        """
        if not scheduler.use_metric or metric_name is None:
            metric = None
        elif metric_name in metrics:
            metric = metrics[metric_name]
        else:
            # the log is only consulted (and sanitized) when explicitly allowed
            sanitized_log_dict = self.get_log() if log_access else {}
            if metric_name not in sanitized_log_dict:
                raise ValueError("scheduler metric not recognized")
            metric = sanitized_log_dict[metric_name]
        step_scheduler(scheduler, metric)

    def get_pretty_log_str(self):
        """
        Formats the sanitized log as human-readable text: one line per
        aggregate field/metric, followed by one line per non-empty group
        (or a single blank line when group logging is disabled).
        Output:
            - results_str (str)
        """
        # Get sanitized log dict
        log = self.get_log()

        # Pieces are collected and joined once instead of repeated `+=`
        pieces = []

        # Process aggregate logged fields
        for field in self.logged_fields:
            if field.startswith(self.group_prefix):
                continue
            pieces.append(f"{field}: {log[field]:.3f}\n")

        # Process aggregate logged metrics
        for metric in self.logged_metrics:
            pieces.append(
                f"{metric.agg_metric_field}: {log[metric.agg_metric_field]:.3f}\n"
            )

        if self.no_group_logging:
            pieces.append("\n")
            return "".join(pieces)

        # Process logs for each group
        prefix_len = len(self.group_prefix)
        for g in range(self.grouper.n_groups):
            # same key that get_log() used when writing the group count
            group_count = log[self._group_count_key(g)]
            if group_count <= 0:
                continue

            pieces.append(
                f"  {self.grouper.group_str(g)}  [n = {group_count:6.0f}]:\t"
            )

            # Process grouped logged fields
            for field in self.logged_fields:
                if field.startswith(self.group_prefix):
                    field_suffix = field[prefix_len:]
                    log_key = f"{field_suffix}_group:{g}"
                    pieces.append(f"{field_suffix}: {log[log_key]:5.3f}\t")

            # Process grouped metric fields
            for metric in self.logged_metrics:
                pieces.append(
                    f"{metric.name}: {log[metric.group_metric_field(g)]:5.3f}\t"
                )
            pieces.append("\n")

        return "".join(pieces)
