import random
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.dataset import Dataset
import os
import copy
import numpy as np
import logging
from tqdm import tqdm
import operator

from sklearn import preprocessing
from sklearn.cluster import KMeans
from lang_exps.data.processors.data import output_modes, tasks_num_labels

# Module-level logger, named after this module so records can be filtered by origin.
logger = logging.getLogger(__name__)


class ReplayDataset(Dataset):
    """Map-style dataset that serves items from an in-memory ``features`` sequence."""

    def __init__(self, features):
        """Keep a reference to ``features``; the sequence is not copied."""
        self.features = features

    def __len__(self):
        """Return how many items ``features`` holds."""
        n_items = len(self.features)
        return n_items

    def __getitem__(self, i):
        """Return the item stored at position ``i`` of ``features``."""
        item = self.features[i]
        return item


class ExemplarHandler:
    """Module for storing and using exemplars.

    ``self.exemplars`` maps a task identifier to the list of exemplars kept
    for that task. Exemplars are added with :meth:`write` and sampled back
    with the ``read*`` methods.
    """

    def __init__(self):
        self.exemplars = {}

    def get_tasks(self):
        """Return the task identifiers currently held in memory."""
        return list(self.exemplars)

    def get_task_dataset(self, task):
        """Return a ``ReplayDataset`` over a copy of all exemplars of ``task``.

        Raises:
            KeyError: if ``task`` has no entry in memory.
        """
        return ReplayDataset(features=list(self.exemplars[task]))

    def read_from_file(self, cached_exemplar_dir):
        """Replace the memory with the contents of ``<dir>/exemplars.pt``."""
        cached_exemplar_file = os.path.join(cached_exemplar_dir, "exemplars.pt")

        logger.info("Loading exemplars from cached file %s", cached_exemplar_file)
        # NOTE(security): torch.load unpickles arbitrary Python objects; only
        # load exemplar files produced by a trusted run.
        # NOTE(review): recent PyTorch releases changed the default of
        # `weights_only`; confirm this call still restores custom exemplar
        # objects on the PyTorch version in use.
        self.exemplars = torch.load(cached_exemplar_file)

        logger.info("No. of exemplars restored: %s", self.get_n_exemplars())

    def get_n_exemplars(self):
        """Return the total number of exemplars stored across all tasks."""
        return sum(len(stored) for stored in self.exemplars.values())

    def save_to_file(self, cached_exemplar_dir):
        """Serialize the whole memory to ``<dir>/exemplars.pt``."""
        cached_exemplar_file = os.path.join(cached_exemplar_dir, "exemplars.pt")

        logger.info("Saving exemplars to cached file %s", cached_exemplar_file)
        torch.save(self.exemplars, cached_exemplar_file)

        logger.info("No. of exemplars saved to file: %s", self.get_n_exemplars())

    def read(self, n_examples_replay):
        """Sample exemplars from one randomly chosen stored task.

        Args:
            n_examples_replay: number of exemplars requested; clipped to the
                number stored for the chosen task.

        Returns:
            ``(ReplayDataset, task)``, or ``(None, None)`` when the memory
            is empty.
        """
        if not self.exemplars:
            return None, None

        sampled_task = random.sample(list(self.exemplars), 1)[0]
        dataset = self.exemplars[sampled_task]

        n_examples_replay = min(n_examples_replay, len(dataset))
        indices = random.sample(range(len(dataset)), n_examples_replay)

        replay_dataset = ReplayDataset(features=[dataset[idx] for idx in indices])
        return replay_dataset, sampled_task

    def _sample_task_indices(self, n_examples_replay):
        """Sample exemplar slots uniformly over all tasks; return {task: [idx]}.

        Slots are kept as ``(task, idx)`` tuples rather than being encoded
        into a ``"task_idx"`` string, so task names containing underscores
        and non-string task keys survive the round trip.
        """
        slots = [
            (task, idx)
            for task, stored in self.exemplars.items()
            for idx in range(len(stored))
        ]
        n_examples_replay = min(n_examples_replay, len(slots))

        task_indices = {}
        for position in random.sample(range(len(slots)), n_examples_replay):
            task, idx = slots[position]
            task_indices.setdefault(task, []).append(idx)
        return task_indices

    def read_multiple_tasks(self, n_examples_replay):
        """Sample exemplars across all tasks, grouped per task.

        Args:
            n_examples_replay: total number of exemplars requested; clipped
                to the total number stored.

        Returns:
            Dict mapping each sampled task to a ``ReplayDataset`` of its
            sampled exemplars, or ``None`` when the memory is empty.
        """
        if not self.exemplars:
            return None

        task_indices = self._sample_task_indices(n_examples_replay)
        return {
            task: ReplayDataset(
                features=[self.exemplars[task][idx] for idx in indices]
            )
            for task, indices in task_indices.items()
        }

    def read_multiple_tasks_batch_mode(self, n_examples_replay):
        """Sample exemplars across all tasks into one flat dataset.

        Args:
            n_examples_replay: total number of exemplars requested; clipped
                to the total number stored.

        Returns:
            ``(ReplayDataset, replay_tasks)`` where ``replay_tasks[i]`` is the
            task of the i-th feature, or ``(None, None)`` when the memory is
            empty. Features of the same task are contiguous.
        """
        if not self.exemplars:
            return None, None

        replay_features = []
        replay_tasks = []
        for task, indices in self._sample_task_indices(n_examples_replay).items():
            stored = self.exemplars[task]
            replay_tasks.extend([task] * len(indices))
            replay_features.extend(stored[idx] for idx in indices)

        return ReplayDataset(features=replay_features), replay_tasks

    @staticmethod
    def _select_exemplars(
        dataset,
        write_rate,
        output_mode,
        num_labels,
        min_examples_per_class,
        max_examples_per_class,
    ):
        """Pick exemplars from ``dataset``; return ``(selected, per_class_cap)``.

        ``per_class_cap`` is ``None`` unless the per-class (classification)
        policy was applied. ``num_labels`` is only read in that case.
        """
        n_examples = len(dataset)
        shuffled_indices = random.sample(range(n_examples), n_examples)

        if write_rate == 1.0:
            return [dataset[idx] for idx in shuffled_indices], None

        n_selected_examples = int(write_rate * n_examples)

        if output_mode != "classification":
            # Keep exactly the budgeted number of exemplars (the previous
            # loop stored one more than the budget).
            budget = max(n_selected_examples, 0)
            return [dataset[idx] for idx in shuffled_indices[:budget]], None

        per_class_cap = max(
            int(n_selected_examples / num_labels), min_examples_per_class
        )
        if max_examples_per_class != -1:
            # An explicit maximum overrides the computed cap entirely.
            per_class_cap = max_examples_per_class

        label_counter = dict.fromkeys(range(num_labels), 0)
        selected = []
        for idx in shuffled_indices:
            exemplar = dataset[idx]
            if label_counter[exemplar.label] < per_class_cap:
                selected.append(exemplar)
                label_counter[exemplar.label] += 1
        return selected, per_class_cap

    def _store_exemplars(
        self,
        dataset,
        write_rate,
        task,
        min_examples_per_class,
        max_examples_per_class,
    ):
        """Select exemplars for a new ``task`` and commit them to memory."""
        if task in self.exemplars:
            raise KeyError(
                f"Task ({task}) already exists in exemplars and is not empty."
            )

        # Resolve task metadata before touching the memory, so a failed
        # lookup or selection does not leave an empty or partial entry behind.
        output_mode = output_modes[task]
        needs_labels = write_rate != 1.0 and output_mode == "classification"
        num_labels = tasks_num_labels[task] if needs_labels else None

        selected, per_class_cap = self._select_exemplars(
            dataset,
            write_rate,
            output_mode,
            num_labels,
            min_examples_per_class,
            max_examples_per_class,
        )
        self.exemplars[task] = selected

        if per_class_cap is not None:
            logger.info("%s exemplars per class added to the memory!", per_class_cap)
        logger.info(
            "%s exemplars for task %s added to the memory!", len(selected), task
        )

    def write(
        self,
        dataset,
        write_rate,
        task,
        min_examples_per_class=25,
        max_examples_per_class=-1,
    ):
        """Store a random subset of ``dataset`` as the exemplars of ``task``.

        Selection policy:
          * ``write_rate == 1.0``: every example is stored (shuffled).
          * ``output_modes[task] == "classification"``: up to N examples per
            label, where N is ``max(int(budget / num_labels),
            min_examples_per_class)`` with ``budget = int(write_rate *
            len(dataset))``; a ``max_examples_per_class`` other than ``-1``
            replaces N outright.
          * otherwise: the first ``budget`` examples of the shuffled order
            (none when the budget rounds to 0).

        Args:
            dataset: indexable, sized collection of examples; in the
                classification case each example must expose ``.label``.
            write_rate: fraction of ``dataset`` to keep.
            task: task identifier; must be a key of ``output_modes``.
            min_examples_per_class: lower bound for the per-class cap.
            max_examples_per_class: explicit per-class cap, ``-1`` to disable.

        Raises:
            KeyError: if ``task`` already has an entry in memory.
        """
        self._store_exemplars(
            dataset, write_rate, task, min_examples_per_class, max_examples_per_class
        )

    def write_fixed_budget(
        self,
        dataset,
        write_rate,
        task,
        min_examples_per_class=25,
        max_examples_per_class=-1,
    ):
        """Store exemplars for ``task`` using the same policy as :meth:`write`.

        NOTE(review): the name suggests a distinct fixed-budget policy, but
        this method has always been identical to ``write``; confirm whether
        a different policy was intended.

        Raises:
            KeyError: if ``task`` already has an entry in memory.
        """
        self._store_exemplars(
            dataset, write_rate, task, min_examples_per_class, max_examples_per_class
        )
