import os
import random
from pathlib import Path

import torch
from tqdm import tqdm


def get_folder_names(directory, suffix=""):
    """Return the names of the immediate sub-directories of ``directory``.

    Each name has ``suffix`` appended to it. Regular files are skipped. Order
    follows ``os.listdir``, which is arbitrary.
    """
    folders = []
    for entry in os.listdir(directory):
        if not os.path.isdir(os.path.join(directory, entry)):
            continue
        folders.append(entry + suffix)
    return folders


class RecordVals:
    """Append-only log of values, kept in insertion order in ``self.record``."""

    def __init__(self):
        self.record = []

    def add(self, val):
        """Append ``val`` to the end of ``self.record``."""
        self.record.extend((val,))


def get_task_indices(dataloader, n_class_per_task):
    """Return a tensor holding one task id per batch of ``dataloader``.

    The dataloader yields ``(inputs, labels)`` pairs. The task id of a batch is
    computed from its FIRST label only, as ``labels[0] // n_class_per_task``.
    This presumes each batch belongs to a single task and that classes are
    numbered contiguously per task. That is an assumption; confirm it against
    the caller.
    """
    per_batch = []
    for batch in dataloader:
        _, labels = batch
        per_batch.append(labels[0] // n_class_per_task)
    return torch.tensor(per_batch)


def get_adapt_weights(task_indices, gamma):
    """Build per-step weights that decay geometrically within each task run.

    The first step, and every step whose task index differs from the previous
    step, gets weight 1. Each step that repeats the previous task index gets
    the previous weight multiplied by ``gamma``.
    Returns a float tensor of length ``task_indices.shape[0]``.
    """
    count = task_indices.shape[0]
    weights = torch.ones(count)
    for idx in range(1, count):
        same_task = task_indices[idx] == task_indices[idx - 1]
        if same_task:
            weights[idx] = gamma * weights[idx - 1]
    return weights


def calc_adapt(acc_list, adapt_weights):
    """Return the ``adapt_weights``-weighted mean of ``acc_list`` (Adaptiveness).

    The weights are copied onto ``acc_list``'s device first, so the caller's
    tensor is left untouched.
    """
    weights = adapt_weights.clone().to(acc_list.device)
    weighted_total = (acc_list * weights).sum()
    return weighted_total / weights.sum()


def verbose_iterator(iterable, verbose):
    """Wrap ``iterable`` in a ``tqdm`` progress bar when ``verbose`` is truthy.

    Otherwise the iterable is returned unchanged (same object).
    """
    return tqdm(iterable) if verbose else iterable


class Reservoir:
    """
    Fixed-capacity replay buffer filled by reservoir sampling (Algorithm R).

    Until ``buffer_size`` examples are stored, every example is appended. After
    that, the example with 0-based arrival index ``N`` overwrites a uniformly
    chosen slot with probability ``buffer_size / (N + 1)``. Stored tensors stay
    on the device they arrived on; ``sample`` moves the drawn batch to
    ``device``.
    """

    def __init__(self, buffer_size: int, device: torch.device):
        self.buffer_size = buffer_size
        self.device = device
        # Stored examples stacked along dim 0; both stay None until the first add().
        self.buffer = {"inputs": None, "labels": None}
        # 0-based index of the most recent example seen by add(); -1 before any.
        self.N = -1

    def __len__(self):
        stored = self.buffer["labels"]
        return 0 if stored is None else len(stored)

    def _append(self, row_x, row_y):
        """Grow the buffer by one row; the first row initialises the tensors."""
        if len(self) == 0:
            self.buffer["inputs"] = row_x
            self.buffer["labels"] = row_y
            return
        self.buffer["inputs"] = torch.cat((self.buffer["inputs"], row_x), 0)
        self.buffer["labels"] = torch.cat((self.buffer["labels"], row_y), 0)

    def add(self, new_data):
        """Stream ``(x, y)`` pairs from ``new_data`` into the reservoir.

        Each ``x`` / ``y`` gains a leading dimension of size 1 before storage.
        """
        for x, y in new_data:
            self.N += 1
            row_x = x.unsqueeze(0)
            row_y = y.unsqueeze(0)
            if len(self) < self.buffer_size:
                self._append(row_x, row_y)
                continue
            # Buffer full: draw from [0, N] inclusive; only hits below capacity replace.
            slot = random.randint(0, self.N)
            if slot >= self.buffer_size:
                continue
            self.buffer["inputs"][slot, :] = row_x
            self.buffer["labels"][slot] = row_y

    def sample(self, batch_size: int):
        """Draw ``batch_size`` distinct stored examples, moved to ``self.device``.

        Returns ``(inputs, labels)``; labels are flattened to 1-D.
        Raises ValueError when fewer than ``batch_size`` examples are stored.
        """
        if len(self) < batch_size:
            raise ValueError(
                "Batch size is larger than the current size of the reservoir."
            )
        with torch.no_grad():
            picks = random.sample(range(len(self)), k=batch_size)
            inputs = self.buffer["inputs"][picks].to(self.device)
            labels = self.buffer["labels"][picks].view(-1).to(self.device)
        return inputs, labels


class GreedyBalancingSampler:
    """
    Greedy class-balancing replay buffer of capacity ``buffer_size``.

    The logic appears to follow the greedy sampler from GDumb. An incoming
    example of class ``c`` is kept only while ``c`` holds fewer than its quota
    ``buffer_size / number_of_classes_seen`` slots. When the buffer is full, a
    random stored example of the currently largest class is overwritten. Labels
    are presumed to be non-negative integer class ids; confirm against callers.
    """

    def __init__(self, buffer_size: int, device: torch.device):
        self.buffer = {"inputs": None, "labels": None}
        self.buffer_size = buffer_size
        self.device = device
        # counter[c] = number of stored examples whose label is c.
        # The list grows as new class ids appear, and sum(counter) == len(self).
        self.counter = [0]

    def __len__(self):
        return len(self.buffer["labels"]) if self.buffer["labels"] is not None else 0

    def add(self, new_data):
        """Greedily insert ``(x, y)`` pairs from ``new_data`` while balancing classes.

        Each ``x`` / ``y`` gains a leading dimension of size 1 before storage.
        The buffer never holds more than ``buffer_size`` examples.
        """
        for x, y in new_data:
            x, y = x.unsqueeze(0), y.unsqueeze(0)
            label = int(y)
            if label >= len(self.counter):
                # Extend the per-class counter to cover a newly seen class id.
                self.counter = self.counter + [0] * (label + 1 - len(self.counter))
            # Per-class quota: an equal share of capacity among classes seen so far.
            k_c = (
                self.buffer_size / len(self.counter)
                if sum(self.counter) != 0
                else self.buffer_size
            )
            if self.counter[label] >= k_c:
                continue

            # ">=" (not ">"): sum(counter) == len(self), so a buffer that is exactly
            # full must evict; ">" let it grow to buffer_size + 1.
            if sum(self.counter) >= self.buffer_size:
                # Overwrite a random stored example of the largest class.
                max_label = int(torch.argmax(torch.tensor(self.counter)))
                index = torch.argwhere(self.buffer["labels"].view(-1) == max_label)
                remove_index = index[torch.randint(0, len(index), (1,)).item()]
                self.buffer["inputs"][remove_index, :] = x
                self.buffer["labels"][remove_index] = y
                self.counter[max_label] -= 1
            elif len(self) == 0:
                self.buffer["inputs"] = x
                self.buffer["labels"] = y
            else:
                self.buffer["inputs"] = torch.cat((self.buffer["inputs"], x), 0)
                self.buffer["labels"] = torch.cat((self.buffer["labels"], y), 0)

            self.counter[label] += 1

    def sample(self, batch_size: int):
        """Draw ``batch_size`` distinct stored examples, moved to ``self.device``.

        Returns ``(inputs, labels)``; labels are flattened to 1-D.
        Raises ValueError when fewer than ``batch_size`` examples are stored.
        """
        if len(self) < batch_size:
            raise ValueError(
                "Batch size is larger than the current size of the reservoir."
            )
        with torch.no_grad():
            indices = random.sample(range(len(self)), k=batch_size)
            return self.buffer["inputs"][indices].to(self.device), self.buffer[
                "labels"
            ][indices].view(-1).to(self.device)


def save_results(
    location,
    avg_acc,
    avg_loss,
    test_acc_list,
    adapt_acc,
    adapt_loss,
    name_prefix="",
):
    """Save each result object with ``torch.save`` under ``location``.

    ``location`` is created (with parents) if missing. Files are named
    ``f"{name_prefix}_{name}.pt"`` where ``name`` is one of ``avg_acc``,
    ``avg_loss``, ``test_acc_list``, ``adapt_acc``, ``adapt_loss``. With the
    default empty prefix, names therefore start with an underscore.
    """
    save_path = Path(location)
    # Create the output directory once, before writing any file.
    save_path.mkdir(parents=True, exist_ok=True)
    # Explicit name -> value mapping; the old parallel lists carried a stray
    # "fewshot_adapt" name that zip() silently dropped.
    results = {
        "avg_acc": avg_acc,
        "avg_loss": avg_loss,
        "test_acc_list": test_acc_list,
        "adapt_acc": adapt_acc,
        "adapt_loss": adapt_loss,
    }
    for name_suffix, result in results.items():
        torch.save(result, save_path / f"{name_prefix}_{name_suffix}.pt")
