import os.path

import torch
from argparse import ArgumentTypeError

from torch.utils.data import DataLoader, Dataset, TensorDataset
from torchvision.transforms import ToTensor
import numpy as np
import random
import time

from tqdm import tqdm


class ShuffleTransform:
    """Apply a progressively updated pixel permutation to images.

    The permutation is stored in ``pixel_indices`` as a 1-D index tensor over the
    ``channels * image_size * image_size`` flattened pixels. ``shuffle_image``
    gathers the flattened input through it and reshapes the result back to
    ``(channels, image_size, image_size)``.
    """

    def __init__(self, image_size=28, channels=1):
        self.image_size = image_size
        self.channels = channels
        self.n_pixels = image_size * image_size * channels
        # Identity ordering of the flattened pixels.
        self.grid_pixel_indices = torch.arange(self.n_pixels)
        # Current permutation. It is a separate copy so that the in-place updates in
        # update_pixels_indices do not also overwrite grid_pixel_indices.
        self.pixel_indices = self.grid_pixel_indices.clone()

    def update_pixels_indices(self, n):
        """Randomly permute the entries at ``n`` distinct positions of the permutation.

        Positions are drawn without replacement, so ``pixel_indices`` always stays
        a permutation of ``0 .. n_pixels - 1``. (Drawing with replacement and then
        assigning through duplicate indices could duplicate some entries and lose
        others.) An ``n`` larger than ``n_pixels`` shuffles every position, and
        ``n == 0`` is a no-op.

        Raises:
            ValueError: if ``n`` is negative.
        """
        if n < 0:
            raise ValueError(f"n must be non-negative, got {n}")
        positions = torch.randperm(self.n_pixels)[:n]
        reorder = torch.randperm(len(positions))
        # The right-hand side is gathered (copied) before the scatter, so the update is safe.
        self.pixel_indices[positions] = self.pixel_indices[positions][reorder]

    def shuffle_image(self, tensor):
        """Return ``tensor`` with its flattened pixels reordered by ``pixel_indices``.

        ``tensor`` must contain exactly ``n_pixels`` elements. ``reshape`` also
        accepts non-contiguous inputs, where ``view`` would fail.
        """
        flat = tensor.reshape(-1)
        return flat[self.pixel_indices].view(self.channels, self.image_size, self.image_size)


class SequentialDataset(Dataset):
    """A list of datasets (one per task) exposed as a single interleaved dataset.

    Flat index ``i`` maps to sample ``i // len(self.datasets)`` of dataset
    ``i % len(self.datasets)``, so consecutive indices cycle through the stored
    datasets. If a ``preprocessor`` is given, every appended dataset is first run
    through it (see ``_preprocess``) and stored as a ``TensorDataset``.
    """

    def __init__(self, preprocessor: torch.nn.Module = None, device='cpu'):
        self.preprocessor = preprocessor
        self.device = device
        self.datasets = []

    def __len__(self):
        # Interleaving only works while every dataset still has a sample at
        # position i // n, so the usable length is bounded by the shortest dataset.
        # This equals n * len(datasets[0]) when all datasets have the same size.
        if not self.datasets:
            return 0
        return len(self.datasets) * min(len(ds) for ds in self.datasets)

    def __getitem__(self, index):
        n_datasets = len(self.datasets)
        return self.datasets[index % n_datasets][index // n_datasets]

    def get_dataset(self, index):
        """Return dataset ``index`` materialized as a ``TensorDataset``."""
        x, y = self.get_arrays(index)
        return TensorDataset(x, y)

    def get_arrays(self, index):
        """Return all inputs and labels of dataset ``index`` as a tuple of two batched tensors."""
        dataset = self.datasets[index]
        # One batch that spans the whole dataset, in its original order.
        x, y = next(iter(DataLoader(dataset, batch_size=len(dataset), shuffle=False)))
        return x, y

    def append(self, dataset, filename=None):
        """Add ``dataset`` as the next task, preprocessing it first when a preprocessor is set."""
        if self.preprocessor is not None:
            dataset = self._preprocess(dataset, filename)
        self.datasets.append(dataset)

    @staticmethod
    def _dataset_root(dataset):
        """Return the cache directory: ``dataset.root``, else ``dataset.dataset.root``, else ``"."``."""
        root = getattr(dataset, "root", None)
        if root is None:
            root = getattr(getattr(dataset, "dataset", None), "root", None)
        return "." if root is None else root

    def _preprocess(self, dataset, filename=None):
        """Run ``self.preprocessor`` over ``dataset`` and return a ``TensorDataset``.

        When ``filename`` is given, the result is cached at
        ``<root>/dataset_<filename>_<position>.pt``. ``<position>`` is the index
        the dataset will occupy in ``self.datasets`` and ``<root>`` comes from
        ``_dataset_root``. An existing cache file is loaded instead of recomputing.
        Features stay on ``self.device``; labels stay wherever the loader yields them.
        """
        cache_path = None
        if filename is not None:
            cache_path = os.path.join(self._dataset_root(dataset),
                                      f"dataset_{filename}_{len(self.datasets)}.pt")
            try:
                # NOTE(review): on torch>=2.6, torch.load defaults to weights_only=True,
                # which cannot unpickle a TensorDataset. Pass weights_only=False there
                # (only for trusted, self-written cache files).
                return torch.load(cache_path)
            except FileNotFoundError:
                pass

        self.preprocessor.to(self.device)
        self.preprocessor.eval()

        features, labels = [], []
        with torch.no_grad():
            for x, y in tqdm(DataLoader(dataset, batch_size=32, shuffle=False), desc="Extracting dataset"):
                features.append(self.preprocessor(x.to(self.device)))
                labels.append(y)

        dataset = TensorDataset(torch.cat(features), torch.cat(labels))
        if cache_path is not None:
            torch.save(dataset, cache_path)
        return dataset


def compute_replay_dataset(model: torch.nn.Module, last_trainset: Dataset, order: int, tasks: int = 1,
                           old_replay_dataset: TensorDataset = None) -> TensorDataset:
    """Build a replay buffer whose size is tied to the parameter count of ``model``.

    The total budget is ``replay_size = model_size * order // sample_size``
    samples, where ``model_size`` is the number of model parameters and
    ``sample_size`` is the number of elements in one input of ``last_trainset``.
    The budget is split evenly over ``tasks``. Each call keeps every sample of
    ``old_replay_dataset`` (if any) and appends ``replay_size // tasks`` randomly
    chosen samples from ``last_trainset``.

    Args:
        model: network whose parameter count sets the budget.
        last_trainset: dataset of ``(x, y)`` pairs for the most recent task.
        order: multiplier applied to the parameter count.
        tasks: number of tasks the budget is shared among (default 1).
        old_replay_dataset: replay buffer returned by the previous call, or ``None``.

    Returns:
        A ``TensorDataset`` with the old replay samples followed by the new ones.

    Raises:
        ValueError: if ``last_trainset`` is empty or the per-task budget is
            smaller than one sample.
    """
    model_size = sum(p.numel() for p in model.parameters())

    # Number of elements in one input, read from a batch of size 1.
    first_batch = next(iter(DataLoader(last_trainset, batch_size=1, shuffle=False)), None)
    if first_batch is None:
        raise ValueError("last_trainset is empty")
    sample_size = first_batch[0].numel()

    replay_size = model_size * order // sample_size
    replay_size_per_ds = replay_size // tasks
    if replay_size_per_ds < 1:
        raise ValueError(f"Replay budget per task is {replay_size_per_ds} samples; "
                         f"increase `order` or reduce `tasks`")

    replay_x, replay_y = [], []
    # Keep every sample already stored in the previous replay buffer.
    if old_replay_dataset is not None:
        for x, y in DataLoader(old_replay_dataset, batch_size=replay_size):
            replay_x.append(x)
            replay_y.append(y)
    # The first shuffled batch is a random subset (without replacement) of the new task.
    x, y = next(iter(DataLoader(last_trainset, batch_size=replay_size_per_ds, shuffle=True)))
    replay_x.append(x)
    replay_y.append(y)
    return TensorDataset(torch.cat(replay_x), torch.cat(replay_y))


class ElasticWeightConsolidationLoss:
    """Elastic Weight Consolidation (EWC) penalty with a diagonal Fisher approximation.

    Intended usage, inferred from the method bodies:
    1. Call ``update(net)`` after backward passes on the current task to record a
       parameter snapshot and the squared gradients.
    2. Call ``reset()`` at the end of the task to consolidate those records.
    3. Add ``self(net)`` to the loss while training later tasks.
    """

    def __init__(self):
        # Per-instance state. Class-level lists would be shared (and mutated via
        # append) by every EWC object in the process.
        self.weights = []                     # anchor parameters, one tensor per net parameter
        self.fisher_diagonals = []            # Fisher-diagonal estimate, one tensor per net parameter
        self.temporary_weights = []           # per parameter: snapshots recorded since the last reset
        self.temporary_fisher_diagonals = []  # per parameter: squared gradients since the last reset

    def update(self, net):
        """Record the current parameters of ``net`` and their squared gradients.

        Parameters whose ``.grad`` is ``None`` (e.g. unused in the last backward
        pass) contribute a zero squared gradient instead of raising.
        """
        for i, p in enumerate(net.parameters()):
            if len(self.temporary_weights) <= i:
                self.temporary_weights.append([])
                self.temporary_fisher_diagonals.append([])
            self.temporary_weights[i].append(p.detach().clone())
            grad = p.grad.detach() if p.grad is not None else torch.zeros_like(p)
            # The squared gradient serves as the Fisher-diagonal sample. reset()
            # averages these, giving the expectation of the squared gradient.
            self.temporary_fisher_diagonals[i].append(grad ** 2)

    def __call__(self, net):
        """Return sum_i F_i * (theta_i - theta*_i)^2 over all parameters of ``net``.

        Returns ``0`` while nothing has been consolidated yet (before the first
        ``reset()`` that followed an ``update()``), since there is no previous task
        to protect.
        """
        if not self.fisher_diagonals:
            return 0
        loss = 0
        for i, weight in enumerate(net.parameters()):
            loss += torch.sum(self.fisher_diagonals[i] * (weight - self.weights[i]) ** 2)
        return loss

    def reset(self):
        """Consolidate the recorded statistics and clear the temporary buffers.

        The anchor weights become the most recent snapshot. Each Fisher diagonal
        becomes the mean of the squared gradients recorded since the previous reset.
        """
        self.weights = [snapshots[-1] for snapshots in self.temporary_weights]
        self.fisher_diagonals = [torch.mean(torch.stack(squares), dim=0)
                                 for squares in self.temporary_fisher_diagonals]
        self.temporary_weights = []
        self.temporary_fisher_diagonals = []


def compute_acc(pred, target):
    """Return, as a Python float, the fraction of rows whose argmax matches ``target``.

    ``pred`` holds the batch on dimension 0 and the class scores on dimension 1.
    """
    predicted = pred.argmax(dim=1)
    correct = (target - predicted).eq(0)
    return correct.float().mean().item()

class ArgNumber:
    """argparse ``type=`` callable that parses a number and checks optional bounds.

    Args:
        number_type: ``int`` or ``float``; used to convert the raw argument.
        min_val: inclusive lower bound, or ``None`` for no lower bound.
        max_val: inclusive upper bound, or ``None`` for no upper bound.

    Raises:
        ArgumentTypeError: for an unsupported ``number_type``, or when both
            bounds are given and ``min_val`` is not strictly below ``max_val``.
    """

    def __init__(self, number_type, min_val=None, max_val=None):
        self.__number_type = number_type
        self.__min = min_val
        self.__max = max_val
        if number_type not in [int, float]:
            raise ArgumentTypeError("Invalid number type (it must be int or float)")
        # A range can only be inconsistent when both ends are given. A single
        # bound (lower-only or upper-only) is always valid.
        if self.__min is not None and self.__max is not None and not self.__min < self.__max:
            raise ArgumentTypeError("Invalid range")

    def __call__(self, value):
        """Convert ``value``, check it against the bounds, and return it.

        Raises:
            ArgumentTypeError: if conversion fails or the value is out of range.
        """
        try:
            val = self.__number_type(value)
        except (ValueError, TypeError) as err:
            raise ArgumentTypeError(f"Invalid value specified, conversion issues! Provided: {value}") from err
        if self.__min is not None and val < self.__min:
            raise ArgumentTypeError(f"Invalid value specified, it must be >= {self.__min}")
        if self.__max is not None and val > self.__max:
            raise ArgumentTypeError(f"Invalid value specified, it must be <= {self.__max}")
        return val


class ArgBoolean:
    """argparse ``type=`` callable that parses a boolean flag value.

    Accepts these inputs:
    * the strings "true"/"yes"/"false"/"no" (case-insensitive, surrounding whitespace ignored);
    * the integers 1/0;
    * ``bool`` values.
    Anything else raises ``ArgumentTypeError``.
    """

    def __call__(self, value):
        # bool must be tested before int: bool is a subclass of int, so with the
        # int check first, this branch would never be reached.
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            val = value.lower().strip()
            if val not in ("true", "yes", "false", "no"):
                raise ArgumentTypeError(f"Invalid value specified: {value}")
            return val in ("true", "yes")
        if isinstance(value, int):
            if value not in (0, 1):
                raise ArgumentTypeError(f"Invalid value specified: {value}")
            return value == 1
        raise ArgumentTypeError(f"Invalid value specified (expected boolean): {value}")


class ArgInit:
    """argparse ``type=`` callable for an initialization spec.

    Formats are tried in this order of precedence:
    * a comma-separated numeric string ("0.1,0.2") gives a list of floats;
    * a single numeric string gives a float;
    * "zeros", "rand", or any string starting with "value" is returned unchanged.
    Anything else raises ``ArgumentTypeError``.
    """

    def __call__(self, value):
        if ',' in value:
            try:
                return [float(token) for token in value.split(',')]
            except ValueError:
                pass  # not a numeric list: try the other formats below
        try:
            return float(value)
        except ValueError:
            if value in ('zeros', 'rand') or value.startswith('value'):
                return value
            raise ArgumentTypeError(f"Invalid initialization specified: {value}")

def none_or_str(value):
    """Map the literal string 'None' to ``None``; return any other value unchanged."""
    return None if value == 'None' else value
def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and CUDA) random number generators.

    Args:
        seed: the seed to use (converted with ``int``). A negative value selects
            a time-based seed: the current Unix time in whole seconds.

    Returns:
        The integer seed actually applied, so that a time-based seed can be
        logged and the run reproduced later.
    """
    seed = int(time.time()) if seed < 0 else int(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    return seed

def elapsed_time(from_seconds, to_seconds):
    """Format the duration between two timestamps, given in seconds, as text.

    Minutes and seconds are the remainders after removing whole hours and whole
    minutes, e.g. ``elapsed_time(0, 3725.5)`` gives
    ``"1 hours, 2 minutes, 5.500 seconds"``. Assumes ``to_seconds >= from_seconds``.
    """
    elapsed = to_seconds - from_seconds
    hours, remainder = divmod(elapsed, 3600.)
    minutes, seconds = divmod(remainder, 60.)
    return f"{int(hours)} hours, {int(minutes)} minutes, {seconds:.3f} seconds"


# Mapping from 20 group names to 5 class names each (100 class names in total).
# Judging by the variable name and contents, this is the CIFAR-100
# superclass -> class grouping.
# NOTE(review): some keys appear to differ from the official CIFAR-100 coarse-label
# spelling (e.g. "fish_aquarium" vs "fish", and underscores instead of hyphens in
# "large_man_made_outdoor_things" / "non_insect_invertebrates"). Confirm before
# matching these keys against dataset metadata.
cifar100classes_dict = {
    "aquatic_mammals": ["beaver", "dolphin", "otter", "seal", "whale"],
    "fish_aquarium": ["aquarium_fish", "flatfish", "ray", "shark", "trout"],
    "flowers": ["orchid", "poppy", "rose", "sunflower", "tulip"],
    "food_containers": ["bottle", "bowl", "can", "cup", "plate"],
    "fruit_and_vegetables": ["apple", "mushroom", "orange", "pear", "sweet_pepper"],
    "household_electrical_devices": ["clock", "keyboard", "lamp", "telephone", "television"],
    "household_furniture": ["bed", "chair", "couch", "table", "wardrobe"],
    "insects": ["bee", "beetle", "butterfly", "caterpillar", "cockroach"],
    "large_carnivores": ["bear", "leopard", "lion", "tiger", "wolf"],
    "large_man_made_outdoor_things": ["bridge", "castle", "house", "road", "skyscraper"],
    "large_natural_outdoor_scenes": ["cloud", "forest", "mountain", "plain", "sea"],
    "large_omnivores_and_herbivores": ["camel", "cattle", "chimpanzee", "elephant", "kangaroo"],
    "medium_mammals": ["fox", "porcupine", "possum", "raccoon", "skunk"],
    "non_insect_invertebrates": ["crab", "lobster", "snail", "spider", "worm"],
    "people": ["baby", "boy", "girl", "man", "woman"],
    "reptiles": ["crocodile", "dinosaur", "lizard", "snake", "turtle"],
    "small_mammals": ["hamster", "mouse", "rabbit", "shrew", "squirrel"],
    "trees": ["maple_tree", "oak_tree", "palm_tree", "pine_tree", "willow_tree"],
    "vehicles_1": ["bicycle", "bus", "motorcycle", "pickup_truck", "train"],
    "vehicles_2": ["lawn_mower", "rocket", "streetcar", "tank", "tractor"]
}


if __name__ == "__main__":
    def test_func():
        """Smoke test: preprocess CUB-200 with a ResNet-18 and size a replay buffer."""
        from networks import generate_net
        print("This is a library, not a main program! The following snippet is just for testing purposes.")

        from data.CUB200 import Cub200
        from torchvision import transforms
        ds = Cub200(root="./data", train=True, download=True, transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ]))
        network = "resnet18"
        preprocessor = generate_net(network, num_classes=200)
        model = generate_net("small", input_size=preprocessor.fc.in_features)
        print(model)

        sequential_dataset = SequentialDataset(preprocessor=preprocessor)
        class_sequence = torch.randperm(len(ds.classes))
        # Split the shuffled class ids into consecutive tasks of 5 classes each.
        class_sequence = [class_sequence[i:i + 5] for i in range(0, len(class_sequence), 5)]
        # filename =f"dataset_cub200_{network}_train_{class_sequence[0]}"
        # Hard-coded cache key, presumably matching a feature cache produced by an earlier run.
        filename = "cub200_resnet18_test_[tensor(0), tensor(24), tensor(62), tensor(79), tensor(86)]"
        sequential_dataset.append(ds, filename=filename)

        # compute_replay_dataset requires `tasks` and `old_replay_dataset`. Here the
        # whole budget goes to this single dataset, and there is no previous buffer.
        replay_dataset = compute_replay_dataset(model, sequential_dataset, order=10,
                                                tasks=1, old_replay_dataset=None)
        print(f"Size of the replay dataset {len(replay_dataset)}")

    test_func()
