from copy import deepcopy

import torch
from torch import optim

from methods.base import AdaptiveCL
from utils import Reservoir


class LEARN(AdaptiveCL):
    def __init__(
        self,
        cl_type,
        model_class,
        mixing: float,
        eta: float,
        patience_threshold: float,
        prune_min: float,
        n_classes,
        n_tasks: int,
        lr: float,
        device: torch.device,
        buffer_size: int,
        batch_size: int,
        **kwargs,
    ):
        """Build the shared backbone, the exploration head and the empty knowledge set.

        Args:
            cl_type: Forwarded unchanged to the base-class constructor.
            model_class: Callable building the network; the result must expose
                a ``classifier`` attribute, which is split off as the head.
            mixing: Step size of the running loss average used in ``recall``.
            eta: Stored on the instance; only rescaled by ``before_fewshot_test``.
            patience_threshold: Accumulated patience above which a new
                knowledge head is added.
            prune_min: Stored on the instance; not read anywhere in this class.
            n_classes: Forwarded unchanged to the base-class constructor.
            n_tasks: Forwarded to the base-class constructor.
            lr: Learning rate of every SGD optimiser created by this class.
            device: Device for the model and all bookkeeping tensors.
            buffer_size: Capacity handed to the ``Reservoir`` replay buffer.
            batch_size: Number of replay samples drawn per training step.
            **kwargs: Accepted and ignored.
        """
        super().__init__(cl_type, n_classes, n_tasks, device)

        # Plain hyper-parameters.
        self.mixing = mixing
        self.lr = lr
        self.eta = eta
        self.batch_size = batch_size
        self.buffer_size = buffer_size

        # Backbone: detach its classifier so the backbone only emits features.
        # NOTE(review): n_class_each_task is presumably set by the base class.
        self.feature = model_class(self.n_class_each_task).to(device)
        self.fc = deepcopy(self.feature.classifier)
        self.feature.classifier = torch.nn.Identity()
        self.feature_opt = optim.SGD(self.feature.parameters(), lr=lr)

        # Exploration head starts as a copy of the detached classifier.
        self.expl = deepcopy(self.fc)
        self.expl_opt = optim.SGD(self.expl.parameters(), lr=self.lr)

        # Knowledge heads and their optimisers (index k here <-> index k + 1
        # in the weight / loss tensors below, slot 0 being the exploration head).
        self.kn = []
        self.kn_opt = []
        self.kn_weight = torch.tensor([1.0], device=device)
        self.pred_weight = torch.tensor([1.0], device=device)

        self.patience_threshold = patience_threshold
        self.patience_omit = 1 - 2 * mixing
        self.patience = 0
        self.prune_min = prune_min
        self.method_name = "LEARN"
        self.avg_loss = torch.ones(1, device=device)
        self.buffer = Reservoir(buffer_size, device)
        print(
            f"mixing {self.mixing}, eta {self.eta}, patience_threshold {self.patience_threshold}, prune_min {self.prune_min}"
        )

    def predict(self, inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        inputs = self.feature(inputs)
        """Makes predictions on given inputs using the exploration and knowledge models."""
        with torch.no_grad():
            predictions = self.pred_weight[0] * self.expl(inputs).softmax(1)
            if len(self.kn) > 0:
                for i, kn_ in enumerate(self.kn):
                    if self.pred_weight[i + 1] > 0:
                        predictions += self.pred_weight[i + 1] * (
                            kn_(inputs).softmax(1)
                        )
        return predictions.log()

    def update(
        self, inputs: torch.Tensor, labels: torch.Tensor, task_index: int, test=False
    ):
        """Run one online step on a batch and store the batch in the buffer.

        ``recall`` always runs. Unless ``test`` is true it is followed by the
        training steps (``exploration``, ``refinement``), then ``prune`` and
        ``patience_detect``. ``task_index`` is not used by this method.
        """
        self.recall(inputs, labels)

        if not test:
            for train_step in (self.exploration, self.refinement):
                train_step(inputs, labels)
            self.prune()
            self.patience_detect()

        # NOTE(review): count and get_results are not defined in this class;
        # presumably maintained by the base class - confirm.
        if self.count % 1000 == 0:
            avg_acc, _second, _third, _fourth = self.get_results()
            print(avg_acc[-1])

        samples = zip(inputs, labels.view(-1, 1))
        self.buffer.add(samples)

    def exploration(self, inputs, labels):
        """Exploration step."""
        self.mode()
        self.expl_opt.zero_grad()
        self.feature_opt.zero_grad()
        feature = self.feature(inputs)
        loss = self.criterion(self.expl(feature), labels)
        if len(self.buffer) >= self.batch_size:
            with torch.no_grad():
                inputs_replay, labels_replay = self.buffer.sample(self.batch_size)
            feature_replay = self.feature(inputs_replay)
            # with torch.no_grad():
            preds_r = self.expl(feature_replay)
            loss += self.criterion(preds_r, labels_replay)
        loss.backward()
        self.expl_opt.step()
        self.feature_opt.step()

    def recall(self, inputs, labels):
        """Recall step."""
        with torch.no_grad():
            self.mode(False)
            # temp = self.pred_weight[0]
            # self.pred_weight -= self.mixing * (self.pred_weight - self.kn_weight)
            # self.pred_weight -= self.mixing * (
            #     self.pred_weight - 1 / len(self.pred_weight)
            # )
            # self.pred_weight[0] = temp
            inputs = self.feature(inputs)
            loss = torch.tensor(
                [
                    self.criterion(model(inputs), labels)
                    for model in (
                        [self.expl] if self.kn is None else [self.expl] + self.kn
                    )
                ]
            ).to(self.device)
            loss[loss > 10] = 10
            self.avg_loss -= self.mixing * (self.avg_loss - loss)
            # loss -= loss.min()
            # perf = (-loss/self.avg_loss.sqrt()).exp()
            # self.pred_weight *= perf
            self.pred_weight *= 0  # self.pred_weight.sum()
            indices = self.avg_loss < self.avg_loss.min() * 1.05
            self.pred_weight[indices] = 1 / sum(indices)

    def refinement(self, inputs, labels):
        """Refinement step."""
        self.mode()
        if self.kn is not None:
            with torch.no_grad():
                self.kn_weight -= (1 / self.count) * (self.kn_weight - self.pred_weight)

            # for i, (kn_, kn_opt_) in enumerate(zip(self.kn, self.kn_opt)):
            #     if self.pred_weight[i + 1] > self.mixing:
            #         with torch.no_grad():
            #             ratio = max(self.pred_weight[i + 1] / (self.kn_weight[i + 1] * self.count).sqrt(),
            #                         0.1 * self.lr)
            #         if ratio > self.mixing:
            #             kn_opt_.zero_grad()
            #             loss = ratio * self.criterion(kn_(inputs), labels)
            #             loss.backward()
            #             kn_opt_.step()
            # vals, indices = torch.sort(self.pred_weight, descending=True)
            # if self.pred_weight.max() > 0.5:
            i = self.avg_loss.argmin()
            if i > 0:
                kn_, kn_opt_ = self.kn[i - 1], self.kn_opt[i - 1]
                kn_opt_.zero_grad()
                self.feature_opt.zero_grad()
                with torch.no_grad():
                    ratio = max(1 / (self.count * self.kn_weight[i] + 1).sqrt(), 0.5)
                feature = self.feature(inputs)
                loss = self.criterion(kn_(feature), labels)
                if len(self.buffer) >= self.batch_size:
                    with torch.no_grad():
                        inputs_replay, labels_replay = self.buffer.sample(
                            self.batch_size
                        )
                    feature_replay = self.feature(inputs_replay)
                    # with torch.no_grad():
                    preds_r = kn_(feature_replay)
                    loss += self.criterion(preds_r, labels_replay)
                loss *= ratio
                loss.backward()
                kn_opt_.step()
                self.feature_opt.step()

    def patience_detect(self):
        """Detects if a new model should be added."""
        with torch.no_grad():
            self.patience += max(-0.5 * self.mixing, self.pred_weight[0])
            self.patience = max(self.patience, 0)
        if self.patience > self.patience_threshold:
            self.patience = 0
            self.add_model()

    def add_model(self):
        """Adds a new model to the knowledge set."""
        new_model = deepcopy(self.expl)
        self.kn.append(new_model)
        self.kn_opt.append(optim.SGD(new_model.parameters(), lr=self.lr))
        self.pred_weight = torch.cat(
            (self.pred_weight, torch.zeros(1, device=self.device))
        )
        self.kn_weight = torch.cat((self.kn_weight, torch.zeros(1, device=self.device)))
        self.kn_weight[-1] = self.kn_weight[0]
        self.kn_weight[0] = 0
        with torch.no_grad():
            self.avg_loss = torch.cat(
                (self.avg_loss, self.avg_loss.min() * torch.ones(1, device=self.device))
            )

    def prune(self):
        """Prunes the knowledge set."""
        if len(self.kn_weight) > self.n_tasks * 2:
            i = self.kn_weight[1:].argmin()
            del self.kn[i]
            del self.kn_opt[i]
            self.kn_weight = self.delete_i(self.kn_weight, i + 1)
            self.avg_loss = self.delete_i(self.avg_loss, i + 1)
            self.kn_weight /= self.kn_weight.sum()
            self.pred_weight = self.delete_i(self.pred_weight, i + 1)
            self.pred_weight /= self.pred_weight.sum()

    def delete_i(self, tensor, i):
        """Deletes the ith element of a tensor."""
        return torch.cat((tensor[:i], tensor[i + 1 :]))

    def before_fewshot_test(self):
        self.eta *= 20
        self.mixing *= 10

    def get_models(self):
        """Returns the exploration and knowledge models."""
        return {"classifier": self.kn, "feature": self.feature}

    def mode(self, is_train=True):
        """Sets the mode of the exploration and knowledge models."""

        if is_train:
            for md in [self.expl] + self.kn + [self.feature]:
                md.train()
        else:
            for md in [self.expl] + self.kn + [self.feature]:
                md.eval()
