import copy
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# from .buffer import Buffer
from .base import *
from .ring_buffer import RingBuffer
# NOTE(review): this expression only constructs a Warning instance and
# immediately discards it -- nothing is printed, logged or raised at import
# time. Presumably `warnings.warn(...)` was intended; confirm before changing,
# since that would introduce an import-time side effect.
Warning('This is a work in progress. The code is not yet functional.')
class HAL(ContinualLearning):
    """Anchor-regularised continual-learning method backed by a ring buffer.

    The class keeps three pieces of state besides the encoder:

    * ``buffer``  -- a ``RingBuffer`` of past examples that is replayed
      together with every incoming batch;
    * ``anchors`` -- one synthetic input per class seen so far, learned in
      ``get_anchors`` at the end of each task;
    * ``phi``     -- an exponential moving average of the encoder features
      of the current task, used as a regulariser while learning anchors.

    A separate ``spare_model`` (a deep copy of the encoder) is used as a
    scratch network when learning anchors so that the live encoder is never
    modified by that procedure.

    NOTE(review): "HAL" presumably stands for Hindsight Anchor Learning;
    the module itself flags this implementation as work in progress.
    """

    NAME = 'hal'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self,
                 encoder: nn.Module,
                 loss_func: nn.Module,
                 buffer_size: int,
                 hal_lambda: float = 0.1,
                 beta: float = 0.5,
                 gamma: float = 0.1,
                 device: str = 'cuda',
                 n_tasks: int = 1) -> None:
        """Build the method around ``encoder``.

        Args:
            encoder: Network being trained. It must expose ``input_shape``
                and accept ``returnt='features'`` in its forward call.
            loss_func: Criterion applied to ``(outputs, labels)``.
            buffer_size: Capacity handed to the ring buffer.
            hal_lambda: Weight of the anchor loss in ``compute_loss``.
            beta: Decay of the moving average stored in ``phi``.
            gamma: Weight of the feature-distance term used when
                optimising anchors.
            device: Device on which models, anchors and ``phi`` live.
            n_tasks: Number of tasks, forwarded to the ring buffer.
        """
        super(HAL, self).__init__(encoder)
        # The device must be known before anything is moved onto it.
        self.device = device
        # Stored explicitly: the base constructor only receives the encoder.
        self.loss_func = loss_func
        self.task_number = 0
        self.buffer = RingBuffer(capacity=buffer_size, n_tasks=n_tasks, device=device)
        self.hal_lambda = hal_lambda
        self.beta = beta
        self.gamma = gamma
        self.anchor_optimization_steps = 100
        self.finetuning_epochs = 1

        # Keep the encoder on the target device, then take an independent
        # copy as scratch model so anchor learning cannot disturb the
        # encoder's weights, gradients or normalisation statistics.
        self.encoder.to(self.device)
        self.spare_model = copy.deepcopy(self.encoder)
        self.spare_opt = torch.optim.SGD(self.spare_model.parameters(), lr=0.01)

        self.anchors = torch.zeros(0, *self.encoder.input_shape).to(device)
        self.phi = None

    @staticmethod
    def _snapshot_params(model: nn.Module) -> dict:
        """Return detached copies of every named parameter of ``model``."""
        return {name: param.detach().clone()
                for name, param in model.named_parameters()}

    @staticmethod
    def _restore_params(model: nn.Module, snapshot: dict) -> None:
        """Copy the values in ``snapshot`` back into ``model`` in place.

        Only parameters are touched, so this also works for models that own
        buffers (a strict ``load_state_dict`` on a parameter-only dict would
        reject those).
        """
        with torch.no_grad():
            for name, param in model.named_parameters():
                param.copy_(snapshot[name])

    def end_task(self, dataset) -> None:
        """Advance the task counter and learn anchors for the finished task.

        Args:
            dataset: Object exposing ``N_CLASSES_PER_TASK`` and
                ``train_loader.dataset.targets``.
        """
        self.task_number += 1
        if self.task_number > self.buffer.task_number:
            self.buffer.num_seen_examples = 0
            self.buffer.task_number = self.task_number

        if len(self.anchors) < self.task_number * dataset.N_CLASSES_PER_TASK:
            self.get_anchors(dataset)
            # Reset (rather than delete) so that the ``is None`` check in
            # ``compute_loss`` rebuilds the average for the next task.
            self.phi = None

    def get_anchors(self, dataset):
        """Learn one anchor per class of the current task.

        The scratch model is synchronised with the encoder (weights
        ``theta_t``), fine-tuned on the buffer (weights ``theta_m``), and
        then, for every class, a random input is optimised by SGD to have
        high loss under ``theta_m``, low loss under ``theta_t`` and features
        close to ``phi``. The resulting inputs are appended to
        ``self.anchors``.

        ``self.phi`` must already have been initialised by ``compute_loss``.
        """
        # Start from the encoder's current weights and buffers.
        self.spare_model.load_state_dict(self.encoder.state_dict())
        theta_t = self._snapshot_params(self.spare_model)

        # Fine-tune the scratch model on the replay buffer.
        for _ in range(self.finetuning_epochs):
            # Only the first two fields are needed; tolerate extra ones.
            inputs, labels = self.buffer.get_data(self.buffer.capacity, transform=None)[:2]
            self.spare_opt.zero_grad()
            out = self.spare_model(inputs)
            loss = self.loss_func(out, labels)
            loss.backward()
            self.spare_opt.step()

        theta_m = self._snapshot_params(self.spare_model)

        classes_for_this_task = np.unique(dataset.train_loader.dataset.targets)

        for a_class in classes_for_this_task:
            target = torch.tensor([a_class]).to(self.device)
            e_t = torch.rand(self.encoder.input_shape, requires_grad=True, device=self.device)
            e_t_opt = torch.optim.SGD([e_t], lr=0.01)
            for _ in range(self.anchor_optimization_steps):
                e_t_opt.zero_grad()

                # Maximise the loss under the fine-tuned weights ...
                self._restore_params(self.spare_model, theta_m)
                loss = -torch.sum(self.loss_func(self.spare_model(e_t.unsqueeze(0)), target))
                loss.backward()

                # ... minimise it under the current-task weights ...
                self._restore_params(self.spare_model, theta_t)
                loss = torch.sum(self.loss_func(self.spare_model(e_t.unsqueeze(0)), target))
                loss.backward()

                # ... and stay close to the running feature mean.
                loss = torch.sum(self.gamma * (self.spare_model(e_t.unsqueeze(0), returnt='features') - self.phi) ** 2)
                loss.backward()

                # The three backward passes accumulated into e_t.grad.
                e_t_opt.step()

            self.anchors = torch.cat((self.anchors, e_t.detach().unsqueeze(0)))

        # Drop the gradients the anchor search left on the scratch model.
        self.spare_model.zero_grad()

    def compute_loss(self,
                     inputs: torch.Tensor,
                     labels: torch.Tensor,
                     not_aug_inputs: torch.Tensor) -> torch.Tensor:
        """Compute the replay loss plus the anchor loss and back-propagate.

        Args:
            inputs: Current batch fed to the encoder.
            labels: Targets for ``inputs``.
            not_aug_inputs: Examples stored in the buffer alongside
                ``labels``.

        Returns:
            The classification loss, plus the weighted anchor loss when
            anchors exist.

        NOTE(review): ``backward()`` is already called on both terms inside
        this method, so the graph behind the returned tensor has been
        consumed -- callers presumably must not back-propagate it again;
        confirm against the training loop.
        """
        real_batch_size = inputs.shape[0]
        if self.phi is None:
            # Only the output shape is needed, so skip graph construction.
            with torch.no_grad():
                features = self.encoder(inputs[0].unsqueeze(0), returnt='features')
            self.phi = torch.zeros_like(features, requires_grad=False).to(self.device)

        if not self.buffer.is_empty():
            # Only the first two fields are needed; tolerate extra ones.
            buf_inputs, buf_labels = self.buffer.get_data(self.buffer.capacity, transform=None)[:2]
            inputs = torch.cat((inputs, buf_inputs))
            labels = torch.cat((labels, buf_labels))

        old_weights = self._snapshot_params(self.encoder)

        self.encoder.zero_grad()
        outputs = self.encoder(inputs)

        loss = self.loss_func(outputs, labels)
        loss.backward()

        first_loss = loss

        if len(self.anchors) > 0:
            # NOTE(review): no optimiser step happens in this method between
            # the snapshot above and the restore below, so both anchor
            # predictions use the same weights and their difference is
            # expected to be ~0 unless the weights are updated elsewhere.
            with torch.no_grad():
                pred_anchors = self.encoder(self.anchors)

            self._restore_params(self.encoder, old_weights)
            pred_anchors -= self.encoder(self.anchors)
            anchor_loss = self.hal_lambda * (pred_anchors ** 2).mean()
            anchor_loss.backward()

            loss = first_loss + anchor_loss

        # Update the moving average of the current batch's mean features.
        with torch.no_grad():
            self.phi = self.beta * self.phi + (1 - self.beta) * self.encoder(inputs[:real_batch_size], returnt='features').mean(0)

        self.buffer.add_data(examples=not_aug_inputs, labels=labels[:real_batch_size])

        return loss
