import warnings
from typing import Iterable, List

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Optimizer

from .Buffer import Buffer
from .base import *
# Calling Warning(...) only builds an exception object and discards it; emit a
# real UserWarning on import so users are told the module is unfinished.
warnings.warn('This is a work in progress. The code is not yet functional.')
def project(gxy: torch.Tensor, ger: torch.Tensor) -> torch.Tensor:
    """
    Project the gradient `gxy` onto the subspace orthogonal to `ger`.

    Computes ``gxy - (gxy . ger / ger . ger) * ger``, whose dot product with
    `ger` is zero.

    Args:
        gxy: Flat (1-D) gradient to be corrected.
        ger: Flat (1-D) reference gradient of the same length.

    Returns:
        A new tensor holding the projected gradient. If `ger` is the zero
        vector, its orthogonal complement is the whole space, so a copy of
        `gxy` is returned instead of dividing by zero (which would give NaN).
    """
    ref_sq_norm = torch.dot(ger, ger)
    if ref_sq_norm.item() == 0:
        return gxy.clone()
    corr = torch.dot(gxy, ger) / ref_sq_norm
    return gxy - corr * ger

class AGem(ContinualLearning):
    """A-GEM style continual-learning strategy backed by a replay buffer.

    On each training step the gradient of the current batch is compared with
    the gradient of a batch replayed from the buffer. When they conflict
    (negative dot product), the current gradient is projected so that it has
    zero dot product with the replay gradient before being written back into
    the encoder's ``.grad`` fields.
    """

    NAME = 'agem'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self,
                 encoder: nn.Module,
                 loss_func: nn.Module,
                 buffer_size: int = 2000,
                 device: str = 'cuda') -> None:
        """
        Args:
            encoder: Network being trained; its parameters (in iteration
                order) define the layout of the flattened gradient vectors.
            loss_func: Criterion called as ``loss_func(outputs, labels)``.
            buffer_size: Capacity of the replay buffer.
            device: Device for the buffer and the gradient vectors.
        """
        # NOTE(review): assumes the base class stores `encoder` as self.encoder.
        super().__init__(encoder)
        self.loss_func = loss_func
        self.buffer = Buffer(capacity=buffer_size, device=device)
        self.device = device

        # One fixed slice per parameter, in self.encoder.parameters() order.
        self.grad_dims = [param.data.numel() for param in self.encoder.parameters()]
        # Builtin sum keeps the size an int even for a parameter-less encoder
        # (np.sum([]) would return a float and break torch.zeros).
        total = sum(self.grad_dims)
        self.grad_xy = torch.zeros(total, device=self.device)  # current-batch gradient
        self.grad_er = torch.zeros(total, device=self.device)  # replay-batch gradient

    def store_grad(self, parameters: Iterable[nn.Parameter], grad_vec: torch.Tensor) -> None:
        """
        Flatten the gradients of `parameters` into `grad_vec`.

        Every parameter owns a fixed slice of `grad_vec` whether or not it
        currently has a gradient; slices of parameters without one are zero.
        This keeps vectors filled from different backward passes aligned
        element-for-element.

        Args:
            parameters: Parameters, in the same order used to size `grad_vec`.
            grad_vec: Flat buffer, overwritten in place.
        """
        grad_vec.zero_()  # drop values left over from a previous step
        offset = 0
        for param in parameters:
            numel = param.numel()
            if param.grad is not None:
                # reshape (not view) also handles non-contiguous gradients.
                grad_vec[offset:offset + numel].copy_(param.grad.reshape(-1))
            offset += numel

    def overwrite_grad(self, parameters: Iterable[nn.Parameter], new_grad: torch.Tensor) -> None:
        """
        Overwrite the current gradients with slices of `new_grad`.

        Uses the same fixed per-parameter layout as `store_grad`. Parameters
        whose ``.grad`` is None are skipped but still advance the offset.

        Args:
            parameters: Parameters, in the same order used to fill `new_grad`.
            new_grad: Flat gradient vector to copy from.
        """
        offset = 0
        for param in parameters:
            numel = param.numel()
            if param.grad is not None:
                param.grad.copy_(new_grad[offset:offset + numel].view_as(param.grad))
            offset += numel

    def compute_loss(self,
                     inputs: torch.Tensor,
                     labels: torch.Tensor,
                     not_aug_inputs: torch.Tensor) -> torch.Tensor:
        """
        Back-propagate the current batch and apply the A-GEM correction.

        After this call the encoder's ``.grad`` fields hold either the plain
        current-batch gradient or its projection (see `project`).
        NOTE(review): backward() is already called here, so the caller should
        presumably only step the optimizer, not call ``loss.backward()`` again.

        Args:
            inputs: Current training batch.
            labels: Targets for `inputs`.
            not_aug_inputs: Not used by this method.

        Returns:
            The (already back-propagated) loss on the current batch.
        """
        self.encoder.zero_grad()
        outputs = self.encoder(inputs)
        loss = self.loss_func(outputs, labels)
        loss.backward()

        if self.buffer.is_empty():
            return loss

        # Save the current-task gradient before the replay pass overwrites it.
        self.store_grad(self.encoder.parameters(), self.grad_xy)

        # NOTE(review): assumes get_data returns (inputs, labels, <extra>) — confirm with Buffer.
        buf_inputs, buf_labels, _ = self.buffer.get_data(batch_size=inputs.size(0), transform=None)
        # Zero in place rather than setting grads to None, so parameters that
        # received a current-batch gradient still have a .grad tensor for
        # overwrite_grad even if the replay pass gives them no gradient.
        self.encoder.zero_grad(set_to_none=False)
        buf_outputs = self.encoder(buf_inputs)
        penalty = self.loss_func(buf_outputs, buf_labels)
        penalty.backward()
        self.store_grad(self.encoder.parameters(), self.grad_er)

        if torch.dot(self.grad_xy, self.grad_er).item() < 0:
            # Conflict: remove the component of the current gradient along the replay gradient.
            corrected = project(gxy=self.grad_xy, ger=self.grad_er)
        else:
            corrected = self.grad_xy
        self.overwrite_grad(self.encoder.parameters(), corrected)
        return loss

    def end_task(self, dataset) -> None:
        """
        Store examples from the current task into the buffer.

        Collects up to ``buffer.capacity // dataset.N_TASKS`` examples (at
        least one) from ``dataset.train_loader``, in loader order, and adds
        them to the buffer in one call. Nothing is added if the loader is empty.

        Args:
            dataset: Object exposing ``N_TASKS`` and an iterable
                ``train_loader``. NOTE(review): each batch is assumed to be a
                sequence starting with (inputs, labels) — confirm.
        """
        # Clamp to 1 so a buffer smaller than N_TASKS still gets an example per task.
        samples_per_task = max(1, self.buffer.capacity // dataset.N_TASKS)
        chunks_x, chunks_y = [], []
        collected = 0
        for batch in dataset.train_loader:
            cur_x, cur_y = batch[0], batch[1]
            take = min(samples_per_task - collected, cur_x.size(0))
            chunks_x.append(cur_x[:take])
            chunks_y.append(cur_y[:take])
            collected += take
            if collected >= samples_per_task:
                break
        if not chunks_x:
            return
        self.buffer.add_data(
            examples=torch.cat(chunks_x).to(self.device),
            labels=torch.cat(chunks_y).to(self.device),
        )

