import torch

from methods import learn


class LEARN_ablation(learn.LEARN):
    """Ablation variant of ``learn.LEARN`` with switchable training phases.

    The three per-batch phases - exploration, recall and refinement - can each
    be disabled through the ``is_exploring``, ``is_recalling`` and
    ``is_refining`` flags so their individual contribution can be measured.
    Models, optimizers, replay buffer and loss bookkeeping are inherited from
    ``learn.LEARN``.
    """

    def __init__(
        self,
        cl_type,
        is_exploring: bool,
        is_recalling: bool,
        is_refining: bool,
        model_class,
        mixing: float,
        eta: float,
        patience_threshold: float,
        prune_min: float,
        n_classes,
        n_tasks: int,
        lr: float,
        buffer_size: int,
        batch_size: int,
        device: torch.device,
        **kwargs
    ):
        """Build the parent learner and record which phases are enabled.

        Args:
            is_exploring: if False, ``exploration`` does nothing.
            is_recalling: if False, ``recall`` puts all prediction weight on
                the explorer (index 0) instead of the near-best models.
            is_refining: if False, ``refinement`` neither updates
                ``kn_weight`` nor trains any stored model.

        Every other argument is forwarded to ``learn.LEARN.__init__``.
        """
        # NOTE(review): the positional order passed here (..., lr, device,
        # buffer_size, batch_size) differs from this signature's order;
        # presumably it matches learn.LEARN.__init__ - confirm against parent.
        super().__init__(
            cl_type,
            model_class,
            mixing,
            eta,
            patience_threshold,
            prune_min,
            n_classes,
            n_tasks,
            lr,
            device,
            buffer_size,
            batch_size,
            **kwargs
        )
        self.is_exploring = is_exploring
        self.is_recalling = is_recalling
        self.is_refining = is_refining

    def _train_head_with_replay(self, head, head_opt, inputs, labels, scale=None):
        """Take one optimizer step on ``head`` and the shared feature extractor.

        The loss is ``criterion(head(feature(inputs)), labels)``. Once the
        replay buffer holds at least ``batch_size`` samples, the same loss on
        a replayed batch is added. The total is multiplied by ``scale`` when
        given, then back-propagated, and both ``head_opt`` and
        ``self.feature_opt`` are stepped.
        """
        head_opt.zero_grad()
        self.feature_opt.zero_grad()
        loss = self.criterion(head(self.feature(inputs)), labels)
        if len(self.buffer) >= self.batch_size:
            with torch.no_grad():
                inputs_replay, labels_replay = self.buffer.sample(self.batch_size)
            loss = loss + self.criterion(
                head(self.feature(inputs_replay)), labels_replay
            )
        if scale is not None:
            loss = loss * scale
        loss.backward()
        head_opt.step()
        self.feature_opt.step()

    def exploration(self, inputs, labels):
        """Exploration step: train the explorer ``self.expl`` and the feature
        extractor on the batch (plus replay). No-op unless ``is_exploring``."""
        if not self.is_exploring:
            return
        self.mode()
        self._train_head_with_replay(self.expl, self.expl_opt, inputs, labels)

    def recall(self, inputs, labels):
        """Recall step: update running losses and choose prediction weights.

        With gradients disabled, each candidate model - the explorer
        ``self.expl`` followed by every model in ``self.kn`` (if any) - is
        scored on the batch. Losses are capped at 10 and blended into
        ``self.avg_loss`` as an exponential moving average with rate
        ``self.mixing``. ``self.pred_weight`` is then reset and either spread
        uniformly over all models whose average loss is within 5% of the best
        (``is_recalling``) or placed entirely on the explorer (index 0).
        """
        with torch.no_grad():
            self.mode(False)
            feature = self.feature(inputs)
            models = [self.expl] if self.kn is None else [self.expl] + self.kn
            # Assumes criterion returns a scalar (0-d) loss per model.
            loss = torch.stack(
                [self.criterion(model(feature), labels) for model in models]
            ).to(self.device)
            # Cap each loss so one diverged model cannot dominate the average.
            loss.clamp_(max=10)
            self.avg_loss -= self.mixing * (self.avg_loss - loss)

            self.pred_weight *= 0

            if self.is_recalling:
                # "<=" guarantees the best model is always selected; with "<"
                # a best average loss of exactly 0 would select nothing and
                # leave every prediction weight at 0.
                indices = self.avg_loss <= self.avg_loss.min() * 1.05
                self.pred_weight[indices] = 1 / indices.sum()
            else:
                self.pred_weight[0] = 1

    def refinement(self, inputs, labels):
        """Refinement step: fine-tune the currently best stored model.

        Always switches to ``self.mode()``. When ``is_refining`` and stored
        models exist, ``kn_weight`` is moved toward ``pred_weight`` with step
        ``1 / self.count`` (presumably a running step count - confirm). If the
        model with the lowest ``avg_loss`` is a stored one (index > 0; index 0
        is the explorer), it and the feature extractor get one replay-augmented
        update whose loss is scaled by
        ``max(1 / sqrt(count * kn_weight[i] + 1), 0.5)``.
        """
        self.mode()
        if not self.is_refining or self.kn is None:
            return

        with torch.no_grad():
            self.kn_weight -= (1 / self.count) * (self.kn_weight - self.pred_weight)

        i = self.avg_loss.argmin()
        if i == 0:
            # The explorer is the best model; it is trained in exploration().
            return

        with torch.no_grad():
            ratio = max(1 / (self.count * self.kn_weight[i] + 1).sqrt(), 0.5)
        self._train_head_with_replay(
            self.kn[i - 1], self.kn_opt[i - 1], inputs, labels, scale=ratio
        )
