import logging


import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import numpy as np


# Module-level logger, named after this module so records can be filtered by origin.
_logger = logging.getLogger(__name__)


class Memory:
    """Fixed-budget exemplar memory with herding-based selection.

    The memory may hold at most ``total_classes * exemplars_per_class``
    samples. Every call to :meth:`update` first shrinks the classes that are
    already stored to the new per-class quota and then adds herding-selected
    exemplars for the classes found in the given loader.
    """

    def __init__(self, total_classes, exemplars_per_class):
        """Create an empty memory.

        Args:
            total_classes: Number of classes used to size the budget.
            exemplars_per_class: Per-class quota used to size the budget.
        """
        self.total_classes = total_classes
        self.exemplars_per_class = exemplars_per_class
        # Total sample budget shared by all observed classes.
        self.size = total_classes * exemplars_per_class
        _logger.info(f"Memory size: {self.size}")

        # Number of samples currently stored / classes accounted for so far.
        self.current_size = 0
        self.observed_classes = 0

        # Stored images, their labels, and each sample's position inside the
        # loader it was taken from. They become NumPy arrays after the first
        # call to ``update``.
        self.x = []
        self.y = []
        self.indices = []
        # TensorDataset view over (x, y); ``None`` until the first ``update``.
        self.dataset = None

    @staticmethod
    def _model_device(model):
        """Return the device of the model's first parameter.

        Falls back to CUDA when available (CPU otherwise) for models that
        expose no parameters.
        """
        try:
            return next(model.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def extract_features(self, model, loader, amp_autocast, unit_test=False):
        """Run ``model`` over ``loader`` and collect pre-logit features.

        Args:
            model: Callable accepting ``(image, pre_logits=True)``; must be in
                eval mode (``model.training`` is False).
            loader: Iterable yielding ``(image, target)`` tensor batches.
            amp_autocast: Zero-argument callable returning a context manager
                that wraps the forward pass.
            unit_test: Currently unused; kept for interface compatibility.

        Returns:
            Tuple ``(features, targets)`` of NumPy arrays, concatenated over
            all batches in loader order.
        """
        assert not model.training

        # Follow the model's device instead of assuming a GPU is present.
        device = self._model_device(model)

        feature_batches = []
        target_batches = []
        with torch.no_grad():
            for i, (image, target) in enumerate(loader):
                image = image.to(device)

                with amp_autocast():
                    batch_features = model(image, pre_logits=True).detach()

                # Autocast may yield fp16/bf16 output: NumPy has no bf16 and
                # fp16 is too coarse for herding, so upcast those dtypes only.
                if batch_features.dtype in (torch.float16, torch.bfloat16):
                    batch_features = batch_features.float()

                feature_batches.append(batch_features.cpu().numpy())
                target_batches.append(target.cpu().numpy())

                if i % 50 == 0:
                    _logger.info(f"Extracted features for {i} batches")

        features = np.concatenate(feature_batches, axis=0)
        targets = np.concatenate(target_batches, axis=0)

        return features, targets

    def remove(self, icarl_m):
        """Shrink every stored class to at most ``icarl_m`` exemplars.

        The first ``icarl_m`` stored samples of each class are kept, so the
        storage order within a class decides which exemplars survive.
        ``current_size`` and ``dataset`` are refreshed to match.

        Args:
            icarl_m: Maximum number of exemplars to keep per class.
        """
        if self.current_size <= 0:
            return

        kept_positions = np.concatenate(
            [
                np.flatnonzero(self.y == class_idx)[:icarl_m]
                for class_idx in np.unique(self.y)
            ]
        )

        prev_size = self.current_size
        self.x = self.x[kept_positions]
        self.y = self.y[kept_positions]
        self.indices = self.indices[kept_positions]
        # Keep the bookkeeping consistent when remove() is used on its own.
        self.current_size = len(self.x)
        self.dataset = TensorDataset(torch.from_numpy(self.x), torch.from_numpy(self.y))
        _logger.info(f"Removed {prev_size - self.current_size} samples from memory")

    def _gather_exemplars(self, loader, task_indices, img_size):
        """Collect the images/labels located at ``task_indices`` in ``loader``.

        ``task_indices`` are positions in iteration order of ``loader``, so the
        loader must yield samples in the same order as during feature
        extraction. Row ``k`` of the result holds the sample whose position is
        ``task_indices[k]``.

        Returns:
            Tuple ``(x, y)``: images of shape ``(len(task_indices), 3, H, W)``
            and int64 labels.
        """
        # NOTE(review): np.zeros defaults to float64, so stored images (and the
        # resulting TensorDataset) are float64 whatever the loader dtype is.
        # Kept for compatibility; float32 would halve the memory footprint.
        x = np.zeros((len(task_indices), 3) + img_size)  # B, C, H, W
        y = np.zeros((len(task_indices),), dtype=np.int64)

        # Destination row of every selected loader position.
        slot_of = {int(index): slot for slot, index in enumerate(task_indices)}

        num_samples = 0
        for i, (images, batch_targets) in enumerate(loader):
            batch_size = images.shape[0]
            global_indices = np.arange(num_samples, num_samples + batch_size)
            # Advance the running offset for every batch, including batches
            # that contribute nothing, so later positions stay correct.
            num_samples += batch_size

            kept = np.flatnonzero(np.isin(global_indices, task_indices))
            if kept.size > 0:
                slots = [slot_of[int(index)] for index in global_indices[kept]]
                batch_positions = kept.tolist()
                x[slots] = images[batch_positions].cpu().numpy()
                y[slots] = batch_targets[batch_positions].cpu().numpy()

            if i % 50 == 0:
                _logger.info(f"Updated {i} batches")

        return x, y

    def update(self, img_size, model, num_classes, loader, amp_autocast, unit_test=False):
        """Add exemplars for the classes in ``loader`` and rebalance the memory.

        ``loader`` is iterated twice (feature extraction, then image
        collection) and must yield samples in the same order both times.

        Args:
            img_size: Image height/width as an int (square) or a 2-sequence.
            model: Feature extractor, see :meth:`extract_features`.
            num_classes: Number of classes this task adds to the running
                ``observed_classes`` count.
            loader: Iterable yielding ``(images, targets)`` tensor batches.
            amp_autocast: Zero-argument callable returning a context manager.
            unit_test: Currently unused; forwarded to ``extract_features``.
        """
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        img_size = tuple(img_size)  # also accept lists

        features, targets = self.extract_features(model, loader, amp_autocast, unit_test=unit_test)
        ys = np.unique(targets)
        _logger.info(f"Expected {num_classes} classes, got {len(ys)}")
        observed_classes = self.observed_classes + num_classes

        # New per-class quota: the budget is split over all observed classes.
        icarl_m = self.size // observed_classes
        _logger.info(f"Number of exemplars per class: {icarl_m}")
        self.remove(icarl_m)

        # Per class, pick loader positions in herding priority order.
        task_indices = []
        for y in ys:
            class_indices = np.where(targets == y)[0]
            selected = self.icarl_selection(features[class_indices], icarl_m)
            task_indices.append(class_indices[selected])
        task_indices = np.concatenate(task_indices, axis=0)

        # Go through the loader again and fetch the selected images.
        x, y = self._gather_exemplars(loader, task_indices, img_size)

        # Labels seen in the second pass must match the first pass; a mismatch
        # means the loader did not replay the same order.
        if not np.array_equal(y, targets[task_indices]):
            _logger.warning(
                "Exemplar labels differ between the two loader passes; "
                "the loader must iterate in a fixed order (no shuffling)."
            )

        if len(self.x) == 0:
            self.x = x
            self.y = y
            self.indices = task_indices
        else:
            self.x = np.concatenate([self.x, x], axis=0)
            self.y = np.concatenate([self.y, y], axis=0)
            self.indices = np.concatenate([self.indices, task_indices], axis=0)

        self.dataset = TensorDataset(torch.from_numpy(self.x), torch.from_numpy(self.y))

        self.current_size = len(self.x)
        _logger.info(f"Memory size: {self.current_size}")
        self.observed_classes = observed_classes

    # From DyTox
    def icarl_selection(self, features, nb_examplars):
        """Rank samples by herding and return the best ``nb_examplars``.

        Args:
            features: 2-D array with one row per sample.
            nb_examplars: Number of exemplars requested.

        Returns:
            Array of at most ``nb_examplars`` row indices into ``features``,
            earliest-selected first.
        """
        # Columns of D are the L2-normalised sample features.
        D = features.T
        D = D / (np.linalg.norm(D, axis=0) + 1e-8)

        mu = np.mean(D, axis=1)
        # herding_matrix[i] is the 1-based selection rank of sample i (0 = not selected).
        herding_matrix = np.zeros((features.shape[0],))

        w_t = mu
        iter_herding, iter_herding_eff = 0, 0

        # Stop once enough samples are ranked, or after 1000 iterations.
        while not (
            np.sum(herding_matrix != 0) == min(nb_examplars, features.shape[0])
        ) and iter_herding_eff < 1000:
            tmp_t = np.dot(w_t, D)
            ind_max = np.argmax(tmp_t)
            iter_herding_eff += 1
            if herding_matrix[ind_max] == 0:
                herding_matrix[ind_max] = 1 + iter_herding
                iter_herding += 1

            w_t = w_t + mu - D[:, ind_max]

        # Unselected samples get a large rank so they sort after selected ones.
        herding_matrix[np.where(herding_matrix == 0)[0]] = 10000

        return herding_matrix.argsort()[:nb_examplars]

    def create_dataloader(self, batch_size):
        """Return a shuffling ``DataLoader`` over the stored exemplars."""
        return DataLoader(self.dataset, batch_size=batch_size, shuffle=True)


class iCARLWrapper(nn.Module):
    """Pair a backbone with a single classification head.

    The head is built once by ``head_factory(backbone)`` and produces the
    logits for all tasks. ``iter_backbone(backbone)`` must yield
    ``(name, parameter)`` pairs; each yielded parameter is made trainable.
    """

    def __init__(self, backbone, iter_backbone, head_factory):
        """Store the components and unfreeze the selected backbone parameters.

        Args:
            backbone: Object exposing ``forward_features`` and ``forward_head``.
            iter_backbone: Callable mapping the backbone to an iterable of
                ``(name, parameter)`` pairs.
            head_factory: Callable mapping the backbone to the head.
        """
        super().__init__()
        self.backbone = backbone
        self.iter_backbone = iter_backbone
        # One head holding the logits of every task.
        self.head = head_factory(backbone)

        # Turn gradients on for each parameter the iterator yields
        # (the original comment refers to these as the "cheem layer").
        for _, parameter in self.iter_backbone(self.backbone):
            parameter.requires_grad = True

    def forward(self, x, pre_logits=False):
        """Return pre-logit features when ``pre_logits`` is set, else head output."""
        tokens = self.backbone.forward_features(x)
        # The backbone's own head is always asked for pre-logit features.
        embedding = self.backbone.forward_head(tokens, pre_logits=True)
        return embedding if pre_logits else self.head(embedding)
