from torch import linalg as LA
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms, models

from typing import Tuple
import tqdm
import os
import sys
import time
import urllib.request
import zipfile
import shutil

# Per-channel normalization statistics (mean / standard deviation) that the
# dataset helpers in this module pass to `transforms.Normalize`.
# Single-element tuples are used for the grayscale datasets, three-element
# tuples for the RGB ones.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)
FASHION_MNIST_MEAN = (0.2860,)
FASHION_MNIST_STD = (0.3530,)
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)
CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)
CIFAR100_STD = (0.2675, 0.2565, 0.2761)
TINY_IMAGENET_MEAN = (0.4804, 0.4482, 0.3976)
TINY_IMAGENET_STD = (0.2764, 0.2689, 0.2817)

# Filesystem locations (relative to the current working directory).
# NOTE(review): CHECKPOINTS_DIR and LOG_DIR are not referenced in this part of
# the file; presumably they are consumed by other modules -- confirm.
CHECKPOINTS_DIR = "./checkpoints"
LOG_DIR = "./logs"
# root directory where the torchvision datasets and Tiny ImageNet are stored
DATASETS_DIR = "./data"
# sub-directory of DATASETS_DIR holding the ImageFolder-style Tiny ImageNet copy
DEFAULT_TINY_IMAGENET_IMAGE_FOLDER = "imagenet"
# directory where the extracted features (i.e., image embeddings) are saved
CACHE_DIR = "./features_cache"

# Shorthand for positive infinity.
inf = float('inf')
# Seed value for reproducible runs (presumably what callers pass to
# `set_deterministic_behavior` -- confirm against the call sites).
DEFAULT_SEED = 1137
# Dataset identifiers accepted by `get_dataset`.
SUPPORTED_DATASETS = ["mnist", "fashion", "cifar10", "cifar100", "tiny"]
# Model names associated with the InstaHide code base.
# NOTE(review): not referenced in this part of the file -- confirm usage.
INSTAHIDE_MODELS = ["nasnet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
                    "resnext29_2x64d", "resnext29_4x64d", "resnext29_8x64d", "resnext29_32x4d"]

######################## FROM INSTAHIDE [start] ########################
# Module-level state shared with `progress_bar`.
#
# The terminal width is queried through `shutil.get_terminal_size`, which
# falls back to 80 columns when no terminal is attached (CI jobs, notebooks,
# redirected output, Windows). Shelling out to `stty size` returns an empty
# string in those situations and made the import of this module fail.
term_width = shutil.get_terminal_size(fallback=(80, 24)).columns
# Number of character cells used to draw the bar itself.
TOTAL_BAR_LENGTH = 86.
# Timestamp of the previous `progress_bar` call (used for the per-step time).
last_time = time.time()
# Timestamp at which the current bar was started (used for the total time).
begin_time = last_time


class Tee(object):
    """File-like object that duplicates everything written to it.

    Each message is written to the file opened from ``filename`` and to the
    ``sys.stdout`` object that was active when the instance was created.

    The class does not install itself: to mirror ``print`` output, assign the
    instance to ``sys.stdout`` (and restore the previous stream before calling
    :meth:`close`, since writing after the file is closed raises
    ``ValueError``).

    The instance can be used as a context manager so the log file is closed
    deterministically.

    Parameters
    ----------
    filename : str
        Path of the log file to open.
    mode : str, optional (default="w")
        Mode passed to ``open``.
    """

    def __init__(self, filename, mode="w"):
        self.file = open(filename, mode)
        # Captured once, so later reassignments of sys.stdout (including to
        # this very object) do not cause infinite recursion in `write`.
        self.stdout = sys.stdout

    def write(self, message):
        """Write ``message`` to the log file and to the captured stdout."""
        self.file.write(message)
        self.stdout.write(message)

    def flush(self):
        """Flush both underlying streams."""
        self.file.flush()
        self.stdout.flush()

    def close(self):
        """Close the log file; the captured stdout is left open.

        Calling this method more than once is harmless.
        """
        if not self.file.closed:
            self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never swallow an exception raised inside the `with` body.
        return False


def progress_bar(current, total, msg=None):
    """Draw a one-line textual progress bar on ``sys.stdout``.

    The line consists of the bar, the time elapsed since the previous call
    ("Step"), the time elapsed since the bar was started ("Total"), the
    optional ``msg`` and a ``current+1/total`` counter written over the
    middle of the bar.

    Parameters
    ----------
    current : int
        Zero-based index of the step that just finished. A value of 0 resets
        the "Total" timer.
    total : int
        Total number of steps; the line ends with a newline once
        ``current`` reaches ``total - 1`` and with a carriage return before.
    msg : str, optional
        Extra text appended after the timing information.
    """
    global last_time, begin_time
    if current == 0:
        # A new bar is starting: restart the total-time clock.
        begin_time = time.time()

    filled = int(TOTAL_BAR_LENGTH*current/total)
    pending = int(TOTAL_BAR_LENGTH - filled) - 1

    # The bar itself: "=" for finished steps, ">" as head, "." for the rest.
    sys.stdout.write(' [')
    sys.stdout.write('=' * filled)
    sys.stdout.write('>')
    sys.stdout.write('.' * pending)
    sys.stdout.write(']')

    now = time.time()
    step_elapsed = now - last_time
    last_time = now
    total_elapsed = now - begin_time

    pieces = [
        '  Step: %s' % format_time(step_elapsed),
        ' | Total: %s' % format_time(total_elapsed),
    ]
    if msg:
        pieces.append(' | ' + msg)
    status = ''.join(pieces)
    sys.stdout.write(status)

    # Pad up to the terminal width (a non-positive count writes nothing).
    sys.stdout.write(' ' * (term_width-int(TOTAL_BAR_LENGTH)-len(status)-3))

    # Go back to the center of the bar.
    sys.stdout.write('\b' * (term_width-int(TOTAL_BAR_LENGTH/2)))
    sys.stdout.write(' %d/%d ' % (current+1, total))

    # Stay on the same line until the final step has been drawn.
    sys.stdout.write('\r' if current < total-1 else '\n')
    sys.stdout.flush()


def format_time(seconds):
    """Format a duration given in seconds as a short human-readable string.

    The duration is split into days (``D``), hours (``h``), minutes (``m``),
    seconds (``s``) and milliseconds (``ms``). Only the first two non-zero
    units, from largest to smallest, are rendered (e.g. ``'1h1m'``); when
    every unit is zero the result is ``'0ms'``.
    """
    remaining = seconds
    days = int(remaining / 3600 / 24)
    remaining = remaining - days * 3600 * 24
    hours = int(remaining / 3600)
    remaining = remaining - hours * 3600
    minutes = int(remaining / 60)
    remaining = remaining - minutes * 60
    whole_seconds = int(remaining)
    remaining = remaining - whole_seconds
    millis = int(remaining * 1000)

    units = (
        (days, 'D'),
        (hours, 'h'),
        (minutes, 'm'),
        (whole_seconds, 's'),
        (millis, 'ms'),
    )
    # Keep the two most significant units that are actually present.
    shown = [str(amount) + suffix for amount, suffix in units if amount > 0][:2]
    return ''.join(shown) if shown else '0ms'
######################## FROM INSTAHIDE [end] ##########################


def set_deterministic_behavior(seed: int, deterministic_algorithms: bool = True):
    """
    Seed every random number generator used by the project and configure
    PyTorch for reproducible runs.

    - Seeds Python's `random` module, NumPy, and PyTorch (CPU and all CUDA devices).
    - When CUDA is available, forces deterministic cuDNN kernels, disables the
      cuDNN auto-tuner and sets `CUBLAS_WORKSPACE_CONFIG`.
    - Enables or disables deterministic algorithms globally via the
      `deterministic_algorithms` flag, which is `True` by default.
    - Please check the documentation for `torch.use_deterministic_algorithms()`
      for a full list of affected operations.

    Parameters
    ----------
    seed : int
        Seed applied to all generators. It has no default value; this module
        defines `DEFAULT_SEED` as a conventional choice.
    deterministic_algorithms : bool, optional (default=True)
        Value forwarded to `torch.use_deterministic_algorithms()`.

    Sources:
    1. https://docs.pytorch.org/docs/stable/notes/randomness.html#reproducibility
    2. https://medium.com/@adhikareen/why-and-when-to-use-cudnn-benchmark-true-in-pytorch-training-f2700bf34289
    3. https://docs.pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms
    4. https://stackoverflow.com/questions/66130547/what-is-the-difference-between-torch-backends-cudnn-deterministic-true-and-to
    """
    # Imported locally so the function does not depend on a module-level import.
    import random

    # Python's own generator is a separate source of randomness (see source 1).
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # The CUDA seeding calls are safe to issue even when no GPU is present.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # Required by cuBLAS for deterministic results (see source 1).
        # NOTE(review): presumably this must be set before the first cuBLAS
        # call of the process to take effect -- confirm.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
    torch.use_deterministic_algorithms(deterministic_algorithms)


def cka(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-12, center: bool = True):
    """Centered Kernel Alignment (linear version).

    Computes ``<K, L>_F / (||K||_F * ||L||_F)`` for the linear Gram matrices
    ``K = x x^T`` and ``L = y y^T``.

    With ``center=True`` (default) every feature column is mean-centered
    first, which for a linear kernel is equivalent to centering the Gram
    matrices and is what makes the score invariant to feature offsets.
    Centering inputs that are already centered does not change them.

    Parameters
    ----------
    x, y : torch.Tensor
        Floating-point representations of the same samples.
        Assumes 2-D tensors of shape (n_samples, n_features) with the same
        number of rows -- TODO confirm against callers.
    eps : float, optional (default=1e-12)
        Lower bound applied to the denominator to avoid division by zero
        (e.g. when one input is constant).
    center : bool, optional (default=True)
        Set to ``False`` to skip centering and obtain the raw (uncentered)
        Gram-matrix alignment.

    Returns
    -------
    torch.Tensor
        Scalar tensor holding the similarity score.
    """
    if center:
        x = x - x.mean(dim=0, keepdim=True)
        y = y - y.mean(dim=0, keepdim=True)

    gram_x = x @ x.T
    gram_y = y @ y.T
    # Unnormalised alignment (Frobenius inner product of the Gram matrices).
    hsic = (gram_x * gram_y).sum()
    denom = (LA.norm(gram_x) * LA.norm(gram_y)).clamp(min=eps)
    return hsic / denom


def summarize(name: str, values: torch.Tensor):
    """Print the mean, standard deviation, minimum and maximum of a metric.

    The output is a header line with ``name`` followed by one indented line
    per statistic (six decimals each) and a trailing blank line.
    """
    print(f"{name}:")
    # Each statistic is computed right before it is printed, one at a time.
    for stat_name in ("mean", "std", "min", "max"):
        stat_value = getattr(values, stat_name)().item()
        print(f"  {stat_name:<4} = {stat_value:.6f}")
    print()


def get_device():
    """Return the name of the preferred compute device as a string.

    Preference order: the first CUDA GPU (``"cuda:0"``), then Apple's MPS
    backend (``"mps"``), and finally ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda:0"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def select_optimal_device() -> torch.device:
    """Return the device chosen by `get_device` wrapped in a `torch.device`."""
    return torch.device(get_device())


def _copy_tiny_imagenet_images(extract_path, dest_dir):
    """Copy the train and val images of an extracted archive into ``dest_dir/<class_name>/``."""
    # Train images live in train/<class_name>/images/.
    train_dir = os.path.join(extract_path, "train")
    for class_name in os.listdir(train_dir):
        class_dir = os.path.join(train_dir, class_name, "images")
        if not os.path.isdir(class_dir):
            # Ignore stray entries that are not class folders.
            continue
        target_class_dir = os.path.join(dest_dir, class_name)
        os.makedirs(target_class_dir, exist_ok=True)
        for img in os.listdir(class_dir):
            shutil.copy(os.path.join(class_dir, img),
                        os.path.join(target_class_dir, img))

    # Val images share one folder; their class is listed in the
    # tab-separated annotation file (first column: file, second: class).
    val_dir = os.path.join(extract_path, "val")
    val_annotations = os.path.join(val_dir, "val_annotations.txt")
    val_img_dir = os.path.join(val_dir, "images")
    with open(val_annotations, 'r') as f:
        for line in f:
            fields = line.strip().split('\t')
            if len(fields) < 2:
                # Blank or malformed line: nothing to copy.
                continue
            img, class_name = fields[:2]
            target_class_dir = os.path.join(dest_dir, class_name)
            os.makedirs(target_class_dir, exist_ok=True)
            shutil.copy(os.path.join(val_img_dir, img),
                        os.path.join(target_class_dir, img))


def download_and_prepare_tiny_imagenet():
    """Download Tiny ImageNet and arrange it as a single ImageFolder directory.

    The archive is downloaded into ``DATASETS_DIR`` (an existing valid zip is
    reused), extracted there, and the train and val images are copied into
    ``DATASETS_DIR/DEFAULT_TINY_IMAGENET_IMAGE_FOLDER/<class_name>/``.

    When the target folder does not exist yet it is assembled in a
    ``.partial`` sibling directory and renamed only once every image has been
    copied, so an interrupted run never leaves behind a folder that looks
    complete. An already existing target folder is updated in place.

    Returns
    -------
    str
        Path of the prepared ImageFolder directory.

    Raises
    ------
    OSError
        Download, extraction or copy failures are propagated (``urllib``
        errors are ``OSError`` subclasses).
    """
    base_name = "tiny-imagenet-200"
    url = f"http://cs231n.stanford.edu/{base_name}.zip"
    zip_path = os.path.join(DATASETS_DIR, f"{base_name}.zip")
    extract_path = os.path.join(DATASETS_DIR, base_name)
    tiny_imagenet_dir = os.path.join(
        DATASETS_DIR, DEFAULT_TINY_IMAGENET_IMAGE_FOLDER)

    os.makedirs(DATASETS_DIR, exist_ok=True)

    # `is_zipfile` is False for a missing or truncated file, so a previously
    # interrupted download is fetched again.
    if zipfile.is_zipfile(zip_path):
        print(f"Reusing existing Tiny ImageNet archive at {zip_path} ...")
    else:
        print(f"Downloading Tiny ImageNet to {zip_path} ...")
        urllib.request.urlretrieve(url, zip_path)

    print("Extracting...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(DATASETS_DIR)

    if os.path.isdir(tiny_imagenet_dir):
        # The folder already exists: refresh its content in place.
        _copy_tiny_imagenet_images(extract_path, tiny_imagenet_dir)
    else:
        staging_dir = tiny_imagenet_dir + ".partial"
        # Drop leftovers of an earlier interrupted run.
        shutil.rmtree(staging_dir, ignore_errors=True)
        _copy_tiny_imagenet_images(extract_path, staging_dir)
        # Publish the finished folder in one step.
        os.rename(staging_dir, tiny_imagenet_dir)

    print(f"Tiny ImageNet prepared at {tiny_imagenet_dir}")

    return tiny_imagenet_dir


def _repeat_to_three_channels(x):
    """Repeat an image tensor three times along its first (channel) dimension.

    Defined at module level (instead of as a lambda) so that transforms using
    it can be pickled, which DataLoader worker processes started with the
    "spawn" method require.
    """
    return x.repeat(3, 1, 1)


class TrainingData:
    """Factories for the datasets used to train and evaluate models.

    Every ``prepare_*`` method returns ``(train_dataset, test_dataset,
    num_classes)``. The torchvision datasets are downloaded into
    ``DATASETS_DIR`` when missing.
    """

    @staticmethod
    def _grayscale_transforms(mean, std, make_3_channels: bool):
        """Build the (train, test) transforms shared by MNIST and FashionMNIST."""
        def base_steps():
            steps = [
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]
            if make_3_channels:
                steps.append(transforms.Lambda(_repeat_to_three_channels))
            return steps

        # Only the training split is augmented (random rotation).
        train_tf = transforms.Compose(
            [transforms.RandomRotation(10)] + base_steps())
        test_tf = transforms.Compose(base_steps())
        return train_tf, test_tf

    @staticmethod
    def _crop_size(match_for_mnist: bool) -> int:
        """Return the spatial size of the samples: 28 to match MNIST, else 32."""
        if match_for_mnist:
            print('Applying RandomCrop=28 on the samples to match MNIST...')
            return 28
        return 32

    @staticmethod
    def _rgb_transforms(mean, std, match_for_mnist: bool):
        """Build the (train, test) transforms shared by CIFAR-10 and CIFAR-100."""
        crop = TrainingData._crop_size(match_for_mnist)
        train_tf = transforms.Compose([
            transforms.RandomCrop(crop, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)])

        test_steps = [
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)]
        if match_for_mnist:
            # The test samples must have the same spatial size as the training
            # samples; a center crop keeps evaluation deterministic.
            test_steps.insert(0, transforms.CenterCrop(crop))
        test_tf = transforms.Compose(test_steps)
        return train_tf, test_tf

    @staticmethod
    def _load_splits(dataset_cls, train_tf, test_tf):
        """Instantiate the train and test split of a torchvision dataset class."""
        train_dataset = dataset_cls(
            DATASETS_DIR, train=True, download=True, transform=train_tf)
        test_dataset = dataset_cls(
            DATASETS_DIR, train=False, download=True, transform=test_tf)
        return train_dataset, test_dataset

    @staticmethod
    def prepare_mnist(make_3_channels: bool = False):
        """Return MNIST as ``(train_dataset, test_dataset, 10)``.

        With ``make_3_channels=True`` the normalized single-channel tensor is
        repeated three times along the channel dimension.
        """
        num_classes: int = 10
        train_tf, test_tf = TrainingData._grayscale_transforms(
            MNIST_MEAN, MNIST_STD, make_3_channels)
        train_dataset, test_dataset = TrainingData._load_splits(
            datasets.MNIST, train_tf, test_tf)

        return train_dataset, test_dataset, num_classes

    @staticmethod
    def prepare_fashion_mnist(make_3_channels: bool = False):
        """
        Return FashionMNIST as ``(train_dataset, test_dataset, 10)``.

        - With ``make_3_channels=True`` the normalized single-channel tensor is
          repeated three times along the channel dimension.
        - The FashionMNIST is used sometimes within ICLR paper submissions. E.g.: https://openreview.net/pdf?id=rJg851rYwH
        """
        num_classes: int = 10
        train_tf, test_tf = TrainingData._grayscale_transforms(
            FASHION_MNIST_MEAN, FASHION_MNIST_STD, make_3_channels)
        train_dataset, test_dataset = TrainingData._load_splits(
            datasets.FashionMNIST, train_tf, test_tf)

        return train_dataset, test_dataset, num_classes

    @staticmethod
    def prepare_cifar10(match_for_mnist: bool = False):
        """Return CIFAR-10 as ``(train_dataset, test_dataset, 10)``.

        With ``match_for_mnist=True`` the samples are cropped to 28x28
        (random crop for training, center crop for testing).
        """
        num_classes: int = 10
        train_tf, test_tf = TrainingData._rgb_transforms(
            CIFAR10_MEAN, CIFAR10_STD, match_for_mnist)
        train_dataset, test_dataset = TrainingData._load_splits(
            datasets.CIFAR10, train_tf, test_tf)

        return train_dataset, test_dataset, num_classes

    @staticmethod
    def prepare_cifar100(match_for_mnist: bool = False):
        """Return CIFAR-100 as ``(train_dataset, test_dataset, 100)``.

        With ``match_for_mnist=True`` the samples are cropped to 28x28
        (random crop for training, center crop for testing).
        """
        num_classes: int = 100
        train_tf, test_tf = TrainingData._rgb_transforms(
            CIFAR100_MEAN, CIFAR100_STD, match_for_mnist)
        train_dataset, test_dataset = TrainingData._load_splits(
            datasets.CIFAR100, train_tf, test_tf)

        return train_dataset, test_dataset, num_classes

    @staticmethod
    def prepare_tiny_imagenet(match_for_mnist: bool = False):
        """Return Tiny ImageNet as ``(dataset, dataset, 200)``.

        The ImageFolder directory is downloaded and prepared on first use.
        Images are resized to 40 pixels (shorter side) and randomly cropped
        to 32x32, or 28x28 when ``match_for_mnist=True``.

        NOTE(review): the same dataset object is returned as both the train
        and the test split, so there is no held-out data -- confirm this is
        intended.
        """
        image_folder = os.path.join(
            DATASETS_DIR, DEFAULT_TINY_IMAGENET_IMAGE_FOLDER)
        if os.path.isdir(image_folder):
            print(
                f"ImageFolder found at: {image_folder}")
        else:
            print("ImageFolder does not exist. Downloading and preparing ImageFolder...")
            download_and_prepare_tiny_imagenet()

        num_classes: int = 200
        random_crop = TrainingData._crop_size(match_for_mnist)

        tf = transforms.Compose([
            transforms.Resize(40),
            transforms.RandomCrop(random_crop),
            transforms.ToTensor(),
            transforms.Normalize(mean=TINY_IMAGENET_MEAN,
                                 std=TINY_IMAGENET_STD),
        ])
        train_dataset = datasets.ImageFolder(image_folder, transform=tf)

        return train_dataset, train_dataset, num_classes


def get_dataset(dataset_type: str, make_mnist_3_channels: bool = False, match_for_mnist: bool = False) -> Tuple[Dataset, Dataset, int]:
    """
    Build the train/validation datasets for one of the supported benchmarks.

    The work is delegated to the matching ``TrainingData.prepare_*`` helper:

        - ``"mnist"``: MNIST
        - ``"fashion"``: FashionMNIST
        - ``"cifar10"``: CIFAR-10
        - ``"cifar100"``: CIFAR-100
        - ``"tiny"``: TinyImageNet, downloaded and arranged as a
          :class:`torchvision.datasets.ImageFolder` when it is not available
          locally (see the
          `PyTorch docs <https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html>`_).

    Parameters
    ----------
    dataset_type : str
        Identifier of the dataset; one of
        {"mnist", "fashion", "cifar10", "cifar100", "tiny"}.
    make_mnist_3_channels : bool, optional (default=False)
        Forwarded to the MNIST and FashionMNIST helpers; ignored for the
        other datasets.
    match_for_mnist : bool, optional (default=False)
        Forwarded to the CIFAR-10, CIFAR-100 and TinyImageNet helpers;
        ignored for the other datasets.

    Returns
    -------
    train_dataset : torch.utils.data.Dataset
        The training dataset.
    val_dataset : torch.utils.data.Dataset
        The validation dataset.
    num_classes : int
        The number of classes in the dataset.

    Raises
    ------
    ValueError
        If ``dataset_type`` is not one of the supported identifiers.
    """
    # One guard clause per dataset; each helper already returns the triple.
    if dataset_type == "mnist":
        return TrainingData.prepare_mnist(make_mnist_3_channels)
    if dataset_type == "fashion":
        return TrainingData.prepare_fashion_mnist(make_mnist_3_channels)
    if dataset_type == "cifar10":
        return TrainingData.prepare_cifar10(match_for_mnist=match_for_mnist)
    if dataset_type == "cifar100":
        return TrainingData.prepare_cifar100(match_for_mnist=match_for_mnist)
    if dataset_type == "tiny":
        return TrainingData.prepare_tiny_imagenet(
            match_for_mnist=match_for_mnist)

    raise ValueError(
        f"Dataset type not supported. Currently supported datasets: {SUPPORTED_DATASETS}")


class StandardDatasets:
    """Train/test splits of a benchmark dataset, resized for a feature extractor.

    Instantiating the class with a dataset identifier fills
    ``train_dataset``, ``test_dataset`` and ``num_classes``. Both splits use
    the same transform: resize (224 by default), conversion to three channels
    for the grayscale datasets, tensor conversion and normalization with the
    dataset's own statistics.
    """

    # Per-channel normalization statistics.
    mnist_mean = (0.1307,)
    mnist_std = (0.3081,)
    fashion_mnist_mean = (0.2860,)
    fashion_mnist_std = (0.3530,)
    cifar10_mean = (0.4914, 0.4822, 0.4465)
    cifar10_std = (0.2470, 0.2435, 0.2616)
    cifar100_mean = (0.5071, 0.4867, 0.4408)
    cifar100_std = (0.2675, 0.2565, 0.2761)
    tiny_imagenet_mean = (0.4804, 0.4482, 0.3976)
    tiny_imagenet_std = (0.2764, 0.2689, 0.2817)
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)

    def __init__(self, dataset_type: str):
        """Load the dataset named ``dataset_type``.

        Raises ``ValueError`` for anything other than "mnist", "fashion",
        "cifar10" or "cifar100".
        """
        self.train_dataset = None
        self.test_dataset = None
        self.num_classes = None

        factories = (
            ("mnist", self.mnist),
            ("fashion", self.fashion_mnist),
            ("cifar10", self.cifar10),
            ("cifar100", self.cifar100),
        )
        for identifier, factory in factories:
            if dataset_type == identifier:
                self.train_dataset, self.test_dataset, self.num_classes = factory()
                break
        else:
            raise ValueError("Unsupported dataset type.")

    @staticmethod
    def _build_splits(dataset_cls, mean, std, resize_to, num_classes, grayscale_to_rgb):
        """Create both splits of ``dataset_cls`` sharing a single transform."""
        steps = [transforms.Resize(resize_to)]
        if grayscale_to_rgb:
            steps.append(transforms.Grayscale(num_output_channels=3))
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize(mean=mean, std=std))
        tf = transforms.Compose(steps)

        trainset = dataset_cls(
            DATASETS_DIR, train=True, transform=tf, download=True)
        testset = dataset_cls(
            DATASETS_DIR, train=False, transform=tf, download=True)
        return trainset, testset, num_classes

    @staticmethod
    def mnist(resize_to: int = 224):
        """Return ``(trainset, testset, 10)`` for MNIST as 3-channel images."""
        return StandardDatasets._build_splits(
            datasets.MNIST,
            mean=StandardDatasets.mnist_mean,
            std=StandardDatasets.mnist_std,
            resize_to=resize_to,
            num_classes=10,
            grayscale_to_rgb=True,
        )

    @staticmethod
    def fashion_mnist(resize_to: int = 224):
        """Return ``(trainset, testset, 10)`` for FashionMNIST as 3-channel images."""
        return StandardDatasets._build_splits(
            datasets.FashionMNIST,
            mean=StandardDatasets.fashion_mnist_mean,
            std=StandardDatasets.fashion_mnist_std,
            resize_to=resize_to,
            num_classes=10,
            grayscale_to_rgb=True,
        )

    @staticmethod
    def cifar10(resize_to: int = 224):
        """Return ``(trainset, testset, 10)`` for CIFAR-10."""
        return StandardDatasets._build_splits(
            datasets.CIFAR10,
            mean=StandardDatasets.cifar10_mean,
            std=StandardDatasets.cifar10_std,
            resize_to=resize_to,
            num_classes=10,
            grayscale_to_rgb=False,
        )

    @staticmethod
    def cifar100(resize_to: int = 224):
        """Return ``(trainset, testset, 100)`` for CIFAR-100."""
        return StandardDatasets._build_splits(
            datasets.CIFAR100,
            mean=StandardDatasets.cifar100_mean,
            std=StandardDatasets.cifar100_std,
            resize_to=resize_to,
            num_classes=100,
            grayscale_to_rgb=False,
        )


class FeatureGenerator:
    """Extract image features with a pretrained torchvision ResNet and cache them.

    On construction the features of the train and test split of
    ``dataset_type`` are computed and saved under ``CACHE_DIR`` unless all four
    cache files (train/test features and labels) already exist and
    ``use_cache`` is true. The file names encode the extractor type, the
    number of removed layers and the dataset.

    Parameters
    ----------
    feature_extractor_type : str
        One of "resnet18", "resnet34" or "resnet50".
    remove_last_k_layers : int
        Number of trailing top-level child modules dropped from the network
        (0 keeps the whole network).
    dataset_type : str
        One of "mnist", "fashion", "cifar10" or "cifar100".
    device : torch.device
        Device on which the extractor runs.
    use_cache : bool
        Reuse existing cache files instead of regenerating them.
    batch_size : int, optional (default=256)
        Batch size used during extraction.
    """

    def __init__(self, feature_extractor_type: str, remove_last_k_layers: int, dataset_type: str, device: torch.device, use_cache: bool, batch_size: int = 256) -> None:
        os.makedirs(CACHE_DIR, exist_ok=True)
        prefix = f'{CACHE_DIR}/{feature_extractor_type}-k{remove_last_k_layers}-{dataset_type}'
        self.train_features_path = f'{prefix}_train_features.pth'
        self.train_labels_path = f'{prefix}_train_labels.pth'
        self.test_features_path = f'{prefix}_test_features.pth'
        self.test_labels_path = f'{prefix}_test_labels.pth'
        self.device = device
        self.batch_size = batch_size
        self.validate_num_classes(dataset_type=dataset_type)

        features_cached = self.check_features_exist()
        if features_cached and use_cache:
            print('Using cached features...')
            return

        if features_cached:
            # The files exist but the caller asked for a fresh extraction.
            print('Cache disabled (use_cache=False). Regenerating features...')
        else:
            print('Local features not found. Generating new ones...')
        self._data = StandardDatasets(dataset_type)
        self.instantiate_model(
            feature_extractor_type=feature_extractor_type, remove_last_k_layers=remove_last_k_layers)
        self.generate_features(train=True)
        self.generate_features(train=False)

    def validate_num_classes(self, dataset_type: str):
        """Set ``self.num_classes`` for ``dataset_type``; raise ``ValueError`` if unsupported."""
        if dataset_type in ("mnist", "fashion", "cifar10"):
            self.num_classes: int = 10
        elif dataset_type == "cifar100":
            self.num_classes: int = 100
        else:
            raise ValueError(
                f"Unsupported dataset type for `num_classes` validation: {dataset_type}")

    def validate_features(self):
        """Return True when all four cache files exist (same check as `check_features_exist`)."""
        return self.check_features_exist()

    @staticmethod
    def _keep_leading_layers(layers, remove_last_k_layers: int):
        """Return ``layers`` without its last ``remove_last_k_layers`` entries.

        Unlike ``layers[:-k]``, a value of 0 keeps every layer. Values larger
        than ``len(layers)`` yield an empty list; negative values are rejected.
        """
        if remove_last_k_layers < 0:
            raise ValueError(
                f"`remove_last_k_layers` must be >= 0, got {remove_last_k_layers}")
        return layers[:max(len(layers) - remove_last_k_layers, 0)]

    def instantiate_model(self, feature_extractor_type: str, remove_last_k_layers: int):
        """Create the truncated, pretrained extractor in eval mode on ``self.device``.

        Raises ``ValueError`` for an unknown extractor type or a negative
        ``remove_last_k_layers``.
        """
        weights_enum_names = {
            "resnet18": "ResNet18_Weights",
            "resnet34": "ResNet34_Weights",
            "resnet50": "ResNet50_Weights",
        }
        enum_name = weights_enum_names.get(feature_extractor_type)
        if enum_name is None:
            raise ValueError(
                f"Invalid type for the feature extractor: {feature_extractor_type!r}")
        _weights = getattr(models, enum_name).DEFAULT

        backbone = getattr(models, feature_extractor_type)(weights=_weights)
        kept_layers = self._keep_leading_layers(
            list(backbone.children()), remove_last_k_layers)
        self.feature_extractor = nn.Sequential(*kept_layers)
        self.feature_extractor.to(self.device)
        self.feature_extractor.eval()

    def check_features_exist(self):
        """Return True when the cache files of both splits exist."""
        return self.check_train_features_exist() and self.check_test_features_exist()

    def check_train_features_exist(self):
        """Return True when the train features and labels files exist."""
        return os.path.isfile(self.train_features_path) and os.path.isfile(self.train_labels_path)

    def check_test_features_exist(self):
        """Return True when the test features and labels files exist."""
        return os.path.isfile(self.test_features_path) and os.path.isfile(self.test_labels_path)

    def generate_features(self, train: bool):
        """Run the extractor over one split and save its features and labels.

        The train split is iterated in shuffled order; features and labels are
        collected from the same batches, so they stay aligned row by row.
        """
        dataset = self._data.train_dataset if train else self._data.test_dataset
        dataloader = DataLoader(
            dataset, batch_size=self.batch_size, shuffle=train)

        features = []
        labels = []
        with torch.no_grad():
            for batch_images, batch_labels in tqdm.tqdm(dataloader, desc=f'Generating {"train" if train else "test"} features'):
                batch_images = batch_images.to(self.device)
                extracted_features = self.feature_extractor(batch_images)
                # Move every batch to host memory right away so the whole
                # dataset never has to fit in accelerator memory.
                features.append(extracted_features.detach().cpu())
                labels.append(batch_labels)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Same availability check as `get_device`; `torch.mps.is_available`
        # is not present in every torch release.
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()

        features = torch.cat(features, dim=0)
        labels = torch.cat(labels, dim=0).detach().cpu()

        if train:
            torch.save(features, self.train_features_path)
            torch.save(labels, self.train_labels_path)
        else:
            torch.save(features, self.test_features_path)
            torch.save(labels, self.test_labels_path)
