"""Classification networks to evaluate heterogeneity."""
import os
import re

import dataclasses
import numpy as np
import nltk
import torch
import torchvision.models as models
from joblib import Memory
from nltk.corpus import wordnet as wn
from sklearn.cluster import DBSCAN, SpectralClustering
from torchvision import datasets
from torchvision import transforms as T
import tensorflow as tf
import tensorflow_hub as hub

from ..._utils import pairwise_call
from .base import BaseNet

# joblib cache located at 'joblib_cache'; used by
# ImageNetBased._get_similarity_matrix to memoize the pairwise similarity
# computation between calls.
memory = Memory('joblib_cache')


class ImageNetAFolder(datasets.ImageFolder):
    """Image folder whose targets are ImageNet1K class indices.

    This wrapper is needed as the classes of datasets such as ImageNetA are
    a subset (200 classes) of the ImageNet1K classes. For their samples to be
    used with networks trained on ImageNet1K, the class index of each
    subfolder must match the one used by ImageNet1K.

    Parameters
    ----------
    root : str
        Root directory holding one subfolder per class.
    transform : callable, optional
        Forwarded to ``ImageFolder``.
    target_transform : callable, optional
        Forwarded to ``ImageFolder``.
    is_valid_file : callable, optional
        Predicate taking a file path and telling whether it is a sample.
        When None, any file makes its class folder count as non-empty.
    subfolders_are_idx : bool
        Whether the subfolders are named after class indices (True) or after
        WordNet ids (False). In the former case the default mapping of
        ``ImageFolder`` is kept.

    """

    def __init__(self, root, transform=None, target_transform=None,
                 is_valid_file=None, subfolders_are_idx=False):
        # Both attributes are read by find_classes, so they are set before
        # calling the parent constructor (which presumably triggers it).
        self.is_valid_file = is_valid_file
        self.subfolders_are_idx = subfolders_are_idx
        super().__init__(root=root, transform=transform,
                         target_transform=target_transform,
                         is_valid_file=is_valid_file)

    def _has_valid_file(self, class_dir):
        """Tell whether `class_dir` contains at least one valid sample."""
        if not os.path.isdir(class_dir):
            return False
        for root, _, fnames in os.walk(class_dir, followlinks=True):
            for fname in fnames:
                path = os.path.join(root, fname)
                if self.is_valid_file is None or self.is_valid_file(path):
                    # One valid file is enough: stop scanning right away.
                    return True
        return False

    def find_classes(self, directory):
        """Map the subfolders of `directory` to ImageNet1K class indices.

        Also stores the WordNet ids of ImageNet1K in ``self.wnids``.

        Returns
        -------
        classes : list
            Class names as found by ``ImageFolder.find_classes``.
        class_to_idx : dict
            Maps each class having at least one valid sample to its index in
            ImageNet1K (or the default mapping if `subfolders_are_idx`).

        Raises
        ------
        ValueError
            If no class folder contains a valid sample.

        """
        classes, folder_to_idx = super().find_classes(directory)

        # ImageNet1K is only loaded for its wnid <-> idx bookkeeping.
        ds = datasets.ImageNet('datasets', split='val')
        # Store wnids as in ImageNet class
        self.wnids = ds.wnids

        if self.subfolders_are_idx:
            return classes, folder_to_idx

        # Resolve every subfolder first so that a folder which is not an
        # ImageNet1K wnid fails loudly with a KeyError.
        wnid_to_idx = {wnid: ds.wnid_to_idx[wnid] for wnid in folder_to_idx}

        # Keep only the classes for which at least one sample is available.
        class_to_idx = {
            wnid: idx for wnid, idx in wnid_to_idx.items()
            if self._has_valid_file(os.path.join(directory, wnid))
        }

        if not class_to_idx:
            raise ValueError('No class having valid filepath samples found.')

        return classes, class_to_idx


class ImageNetBased(BaseNet):
    """Base class for networks trained on ImageNet1K.

    Selects the evaluation dataset (ImageNet1K or one of its derived test
    sets), exposes the parameters of the last layer and groups the classes
    into meta-classes from their WordNet similarity.

    Parameters
    ----------
    split : str
        'train' or 'val' for ImageNet1K, otherwise 'test_<variant>' with an
        optional ':<sublabel>' suffix, where <variant> is one of 'a', 'r',
        'o', 'v2.1', 'v2.2', 'v2.3', 'c' or 'p'.

    """

    def __init__(self, split='val'):
        # Set before the parent constructor runs, as in the other networks
        # of this file (presumably because BaseNet.__init__ needs it).
        self.split = split
        super().__init__()

    def get_transform(self):
        """Return the preprocessing: resize 256, center crop 224, tensor
        conversion and per-channel normalization."""
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
        transform = T.Compose([T.Resize(256), T.CenterCrop(224),
                               T.ToTensor(), normalize])
        return transform

    def get_dataset(self):
        """Build the dataset designated by ``self.split``.

        Raises
        ------
        ValueError
            If the split is not recognized.

        """
        transform = self.get_transform()
        if self.split in ['train', 'val']:
            return datasets.ImageNet('datasets', split=self.split,
                                     transform=transform)

        # 'test_<variant>[:<sublabel>]'
        r = re.match(r'test_([^:]*):?(.*)', self.split)

        if r is None:
            raise ValueError(f'Unknown split {self.split}.')

        variant = r.group(1)
        is_valid_file = None
        subfolders_are_idx = False

        if variant == 'a':
            ds_path = 'datasets/imagenet-a/'

        elif variant == 'r':
            ds_path = 'datasets/imagenet-r/'
            sublabel = r.group(2)
            if sublabel:
                # Only keep the files whose name starts with '<sublabel>_'.
                def is_valid_file(s):
                    return bool(re.match(f'.*/{sublabel}_.*', s))

        elif variant == 'o':
            ds_path = 'datasets/imagenet-o/'

        elif variant == 'v2.1':
            ds_path = 'datasets/imagenetv2-matched-frequency-format-val'
            subfolders_are_idx = True

        elif variant == 'v2.2':
            ds_path = 'datasets/imagenetv2-threshold0.7-format-val'
            subfolders_are_idx = True

        elif variant == 'v2.3':
            ds_path = 'datasets/imagenetv2-top-images-format-val'
            subfolders_are_idx = True

        elif variant == 'c':
            # The sublabel is '<corruption name><level>': non-digit prefix
            # followed by digits.
            subr = re.match('([^0-9]*)([0-9]*)', r.group(2))
            sublabel = subr.group(1)
            level = subr.group(2)
            ds_path = f'datasets/imagenet-c/{sublabel}/{level}/'

        elif variant == 'p':
            sublabel = r.group(2)
            ds_path = f'datasets/imagenet-p/{sublabel}'

        else:
            raise ValueError(f'Unknown split {self.split}.')

        return ImageNetAFolder(ds_path, transform=transform,
                               is_valid_file=is_valid_file,
                               subfolders_are_idx=subfolders_are_idx,
                               )

    def get_dataset_name(self):
        """Return the dataset name, built from the split."""
        return f'ILSVRC2012_img_{self.split}'

    def get_w(self):
        """Return the detached weight of the last layer."""
        return self.last_layer.weight.detach()

    def get_intercept(self):
        """Return the detached bias of the last layer."""
        return self.last_layer.bias.detach()

    def logits_to_scores(self, y_logits):
        """Apply a softmax over dimension 1 of `y_logits`."""
        return torch.nn.functional.softmax(y_logits, dim=1)

    def get_class_names(self):
        """Return the WordNet ids stored on the dataset."""
        dataset = self.get_dataset()
        return dataset.wnids

    def _get_similarity_matrix(self, similarity):
        """Compute pairwise similarity of ImageNet classes.

        Parameters
        ----------
        similarity : str
            The similarity function to use. Choices: 'path', 'lch', 'wup'.
            'path': path distance similarity.
            'lch': Leacock Chodorow similarity
            'wup': Wu-Palmer similarity
            The information-content based similarities ('res', 'jcn',
            'lin') are not implemented.

        Returns
        -------
        The output of `pairwise_call` on the synsets of the classes
        (presumably a square matrix -- confirm against `pairwise_call`).

        Raises
        ------
        ValueError
            If `similarity` is not supported.

        """
        supported = ['path', 'lch', 'wup']
        # Fail fast, before loading the dataset and downloading corpora.
        if similarity not in supported:
            raise ValueError(f'similarity must be in {supported}. '
                             f'Given {similarity}.')

        ds = self.get_dataset()
        # The corpora must be available before `wn` is first accessed.
        nltk.download('wordnet')
        nltk.download('omw-1.4')
        # A wnid is a part-of-speech letter followed by the synset offset.
        synsets = [wn.synset_from_pos_and_offset(wnid[0], int(wnid[1:]))
                   for wnid in ds.wnids]

        # Resolves to wn.path_similarity, wn.lch_similarity or
        # wn.wup_similarity.
        f = getattr(wn, f'{similarity}_similarity')

        # n_jobs and verbose do not change the result: keep them out of the
        # cache key.
        cached_pairwise_call = memory.cache(pairwise_call,
                                            ignore=['n_jobs', 'verbose'])

        return cached_pairwise_call(synsets, f, symmetric=True, n_jobs=1,
                                    verbose=1)

    def get_meta_classes(self, similarity='path', clustering='dbscan'):
        """Cluster the classes from their pairwise WordNet similarity.

        Parameters
        ----------
        similarity : str
            Similarity passed to `_get_similarity_matrix`.
        clustering : str
            'dbscan', or 'spectral<N>' with <N> the number of clusters
            (e.g. 'spectral10').

        Returns
        -------
        labels
            Output of the estimator's `fit_predict` on the similarity
            matrix.

        Raises
        ------
        ValueError
            If the clustering method is unknown or if the number of
            clusters is missing for spectral clustering.

        """
        D = self._get_similarity_matrix(similarity)
        print(D)

        # '<name><optional number>'. [A-Za-z] rather than [A-z]: the latter
        # also matches the punctuation located between 'Z' and 'a' in ASCII.
        r = re.match('([A-Za-z]*)([0-9]*)', clustering)

        clustering_name = r.group(1)

        if clustering_name == 'dbscan':
            # NOTE(review): with metric='precomputed' DBSCAN presumably
            # expects distances whereas D holds similarities -- confirm
            # whether a conversion (e.g. 1 - D) is intended.
            estimator = DBSCAN(eps=0.1, min_samples=5, metric='precomputed')

        elif clustering_name == 'spectral':
            n_clusters = r.group(2)
            if not n_clusters:
                raise ValueError(
                    'Spectral clustering requires a number of clusters, '
                    f"e.g. 'spectral10'. Given {clustering}.")
            estimator = SpectralClustering(n_clusters=int(n_clusters),
                                           affinity='precomputed')

        else:
            raise ValueError(f'Unknown clustering method {clustering}.')

        labels = estimator.fit_predict(D)

        print(np.unique(labels, return_counts=True))
        print('Number of clusters', len(np.unique(labels)))

        return labels


class ImageNet21KBased(BaseNet):
    """Base class for networks evaluated on the 'winter21_whole' folder."""

    def get_dataset(self):
        """Return the 'datasets/winter21_whole' image folder with the
        resize / center crop / normalization preprocessing."""
        pipeline = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        return datasets.ImageFolder('datasets/winter21_whole',
                                    transform=pipeline)

    def get_dataset_name(self):
        """Return the name of the dataset."""
        return 'winter21_whole'

    def get_w(self):
        """Return the detached weight of the last layer."""
        weight = self.last_layer.weight
        return weight.detach()

    def get_intercept(self):
        """Return the detached bias of the last layer."""
        bias = self.last_layer.bias
        return bias.detach()

    def logits_to_scores(self, y_logits):
        """Apply a softmax over dimension 1 of `y_logits`."""
        return torch.nn.functional.softmax(y_logits, dim=1)

    def get_class_names(self):
        """Name the classes '0', '1', ... up to the output size of the
        last layer."""
        n_classes = self.last_layer.out_features
        return list(map(str, range(n_classes)))


class VGG(ImageNetBased):
    """Pretrained torchvision VGG.

    `type` is one of '11', '13', '16', '19', optionally suffixed by '_bn'.
    """

    def __init__(self, type='11', split='val'):
        self.type = type
        super().__init__(split)

    def create_model(self):
        """Instantiate the pretrained VGG variant selected by `type`.

        Raises ValueError if the variant is unknown.
        """
        depths = ('11', '13', '16', '19')
        versions = depths + tuple(f'{depth}_bn' for depth in depths)
        if self.type not in versions:
            raise ValueError(f'Unknown version {self.type} for '
                             f'{self.__class__.__name__.lower()}.')
        # The torchvision builders are named 'vgg<type>'.
        builder = getattr(models, f'vgg{self.type}')
        return builder(pretrained=True)

    def create_truncated_model(self):
        """Return the model without its last classifier layer, along with
        that layer."""
        net = self.create_model()
        head = net.classifier[6]
        del net.classifier[6]
        return net, head


class ResNet(ImageNetBased):
    """Pretrained torchvision ResNet.

    `type` is the depth: '18', '34', '50', '101' or '152'.
    """

    def __init__(self, type='18', split='val'):
        self.type = type
        super().__init__(split)

    def create_model(self):
        """Instantiate the pretrained ResNet selected by `type`.

        Raises ValueError if the depth is unknown.
        """
        if self.type not in ('18', '34', '50', '101', '152'):
            raise ValueError(f'Unknown version {self.type} for '
                             f'{self.__class__.__name__.lower()}.')
        # The torchvision builders are named 'resnet<depth>'.
        return getattr(models, f'resnet{self.type}')(pretrained=True)

    def create_truncated_model(self):
        """Return the model with `fc` replaced by an identity, along with
        the original `fc` layer."""
        net = self.create_model()
        head, net.fc = net.fc, torch.nn.Identity()
        return net, head


class AlexNet(ImageNetBased):
    """Pretrained torchvision AlexNet."""

    def create_model(self):
        """Instantiate the pretrained AlexNet."""
        net = models.alexnet(pretrained=True)
        return net

    def create_truncated_model(self):
        """Return the model without its last classifier layer, along with
        that layer."""
        net = self.create_model()
        head = net.classifier[6]
        del net.classifier[6]
        return net, head


class SqueezeNet(ImageNetBased):
    """Pretrained torchvision SqueezeNet.

    `type` is '0' (squeezenet1_0) or '1' (squeezenet1_1).
    """

    def __init__(self, type='0', split='val'):
        self.type = type
        super().__init__(split)

    def create_model(self):
        """Instantiate the pretrained SqueezeNet selected by `type`.

        Raises ValueError if the version is unknown.
        """
        if self.type not in ('0', '1'):
            raise ValueError(f'Unknown version {self.type} for '
                             f'{self.__class__.__name__.lower()}.')
        # The torchvision builders are named 'squeezenet1_<type>'.
        return getattr(models, f'squeezenet1_{self.type}')(pretrained=True)

    def create_truncated_model(self):
        """Return the model along with its `fc` attribute.

        Unlike the other networks of this file, the returned model is left
        untouched (the last layer is not replaced nor removed).
        """
        # NOTE(review): torchvision's SqueezeNet appears to expose a
        # `classifier` module and no `fc` attribute -- confirm that this
        # method can actually run.
        net = self.create_model()
        return net, net.fc


class DenseNet(ImageNetBased):
    """Pretrained torchvision DenseNet.

    `type` is '121', '169', '161' or '201'.
    """

    def __init__(self, type='121', split='val'):
        self.type = type
        super().__init__(split)

    def create_model(self):
        """Instantiate the pretrained DenseNet selected by `type`.

        Raises ValueError if the version is unknown.
        """
        for version in ('121', '169', '161', '201'):
            if self.type == version:
                # The torchvision builders are named 'densenet<version>'.
                builder = getattr(models, f'densenet{version}')
                return builder(pretrained=True)
        raise ValueError(f'Unknown version {self.type} for '
                         f'{self.__class__.__name__.lower()}.')

    def create_truncated_model(self):
        """Return the model with `classifier` replaced by an identity,
        along with the original `classifier` layer."""
        net = self.create_model()
        head, net.classifier = net.classifier, torch.nn.Identity()
        return net, head


class Inception(ImageNetBased):
    """Pretrained torchvision Inception v3."""

    def create_model(self):
        """Instantiate the pretrained Inception v3."""
        net = models.inception_v3(pretrained=True)
        return net

    def create_truncated_model(self):
        """Return the model with `fc` replaced by an identity, along with
        the original `fc` layer."""
        net = self.create_model()
        head, net.fc = net.fc, torch.nn.Identity()
        return net, head


class GoogLeNet(ImageNetBased):
    """Pretrained torchvision GoogLeNet."""

    def create_model(self):
        """Instantiate the pretrained GoogLeNet."""
        net = models.googlenet(pretrained=True)
        return net

    def create_truncated_model(self):
        """Return the model with `fc` replaced by an identity, along with
        the original `fc` layer."""
        net = self.create_model()
        head = net.fc
        net.fc = torch.nn.Identity()
        return net, head


class ShuffleNet(ImageNetBased):
    """Pretrained torchvision ShuffleNet v2.

    `type` is '0_5' or '1_0'. The '1_5' and '2_0' variants are deliberately
    left out: they were marked as not available when this was written.
    """

    def __init__(self, type='0_5', split='val'):
        self.type = type
        super().__init__(split)

    def create_model(self):
        """Instantiate the pretrained ShuffleNet v2 selected by `type`.

        Raises ValueError if the version is unknown.
        """
        if self.type not in ('0_5', '1_0'):
            raise ValueError(f'Unknown version {self.type} for '
                             f'{self.__class__.__name__.lower()}.')
        # The torchvision builders are named 'shufflenet_v2_x<type>'.
        builder = getattr(models, f'shufflenet_v2_x{self.type}')
        return builder(pretrained=True)

    def create_truncated_model(self):
        """Return the model with `fc` replaced by an identity, along with
        the original `fc` layer."""
        net = self.create_model()
        head, net.fc = net.fc, torch.nn.Identity()
        return net, head


class MobileNet(ImageNetBased):
    """Pretrained torchvision MobileNet.

    `type` is 'v2', 'v3L' (v3 large) or 'v3S' (v3 small).
    """

    def __init__(self, type='v2', split='val'):
        self.type = type
        super().__init__(split)

    def create_model(self):
        """Instantiate the pretrained MobileNet selected by `type`.

        Raises ValueError if the version is unknown.
        """
        builder_names = (
            ('v2', 'mobilenet_v2'),
            ('v3L', 'mobilenet_v3_large'),
            ('v3S', 'mobilenet_v3_small'),
        )
        for version, builder_name in builder_names:
            if self.type == version:
                return getattr(models, builder_name)(pretrained=True)
        raise ValueError(f'Unknown version {self.type} for '
                         f'{self.__class__.__name__.lower()}.')

    def create_truncated_model(self):
        """Return the model without its last classifier layer, along with
        that layer."""
        net = self.create_model()
        # The last layer sits at a different position of the classifier
        # depending on the version.
        if self.type == 'v2':
            head_idx = 1
        elif self.type in ('v3L', 'v3S'):
            head_idx = 3
        head = net.classifier[head_idx]
        del net.classifier[head_idx]

        return net, head


class ResNext(ImageNetBased):
    """Pretrained torchvision ResNeXt.

    `type` is '50' (resnext50_32x4d) or '101' (resnext101_32x8d).
    """

    def __init__(self, type='50', split='val'):
        self.type = type
        super().__init__(split)

    def create_model(self):
        """Instantiate the pretrained ResNeXt selected by `type`.

        Raises ValueError if the version is unknown.
        """
        builder_names = (
            ('50', 'resnext50_32x4d'),
            ('101', 'resnext101_32x8d'),
        )
        for version, builder_name in builder_names:
            if self.type == version:
                return getattr(models, builder_name)(pretrained=True)
        raise ValueError(f'Unknown version {self.type} for '
                         f'{self.__class__.__name__.lower()}.')

    def create_truncated_model(self):
        """Return the model with `fc` replaced by an identity, along with
        the original `fc` layer."""
        net = self.create_model()
        head = net.fc
        net.fc = torch.nn.Identity()
        return net, head


class WideResNet(ImageNetBased):
    """Pretrained torchvision Wide ResNet.

    `type` is the depth: '50' or '101'.
    """

    def __init__(self, type='50', split='val'):
        self.type = type
        super().__init__(split)

    def create_model(self):
        """Instantiate the pretrained Wide ResNet selected by `type`.

        Raises ValueError if the depth is unknown.
        """
        if self.type not in ('50', '101'):
            raise ValueError(f'Unknown version {self.type} for '
                             f'{self.__class__.__name__.lower()}.')
        # The torchvision builders are named 'wide_resnet<depth>_2'.
        builder = getattr(models, f'wide_resnet{self.type}_2')
        return builder(pretrained=True)

    def create_truncated_model(self):
        """Return the model with `fc` replaced by an identity, along with
        the original `fc` layer."""
        net = self.create_model()
        head, net.fc = net.fc, torch.nn.Identity()
        return net, head


class MNASNet(ImageNetBased):
    """Pretrained torchvision MNASNet.

    `type` is '0_5' or '1_0'. The '0_75' and '1_3' variants are deliberately
    left out: they were marked as not available when this was written.
    """

    def __init__(self, type='0_5', split='val'):
        self.type = type
        super().__init__(split)

    def create_model(self):
        """Instantiate the pretrained MNASNet selected by `type`.

        Raises ValueError if the version is unknown.
        """
        for version in ('0_5', '1_0'):
            if self.type == version:
                # The torchvision builders are named 'mnasnet<version>'.
                return getattr(models, f'mnasnet{version}')(pretrained=True)
        raise ValueError(f'Unknown version {self.type} for '
                         f'{self.__class__.__name__.lower()}.')

    def create_truncated_model(self):
        """Return the model without its last classifier layer, along with
        that layer."""
        net = self.create_model()
        head = net.classifier[1]
        del net.classifier[1]
        return net, head


class EfficientNet(ImageNetBased):
    """Pretrained torchvision EfficientNet.

    `type` ranges from 'b0' to 'b7'.
    """

    def __init__(self, type='b0', split='val'):
        self.type = type
        super().__init__(split)

    def create_model(self):
        """Instantiate the pretrained EfficientNet selected by `type`.

        Raises ValueError if the version is unknown.
        """
        versions = tuple(f'b{i}' for i in range(8))  # 'b0' ... 'b7'
        if self.type not in versions:
            raise ValueError(f'Unknown version {self.type} for '
                             f'{self.__class__.__name__.lower()}.')
        # The torchvision builders are named 'efficientnet_<type>'.
        builder = getattr(models, f'efficientnet_{self.type}')
        return builder(pretrained=True)

    def create_truncated_model(self):
        """Return the model without its last classifier layer, along with
        that layer."""
        net = self.create_model()
        head = net.classifier[1]
        del net.classifier[1]
        return net, head


class RegNet(ImageNetBased):
    """Pretrained torchvision RegNet.

    `type` is '<family>_<size>' with family 'x' or 'y' and size '400mf',
    '800mf', '1_6gf', '3_2gf', '8gf', '16gf' or '32gf'.
    """

    def __init__(self, type='y_400mf', split='val'):
        self.type = type
        super().__init__(split)

    def create_model(self):
        """Instantiate the pretrained RegNet selected by `type`.

        Raises
        ------
        ValueError
            If the version is unknown.

        """
        # Map to the builders themselves, NOT to built models: only the
        # requested network must be instantiated (and its weights loaded).
        builders = {
            'y_400mf': models.regnet_y_400mf,
            'y_800mf': models.regnet_y_800mf,
            'y_1_6gf': models.regnet_y_1_6gf,
            'y_3_2gf': models.regnet_y_3_2gf,
            'y_8gf': models.regnet_y_8gf,
            'y_16gf': models.regnet_y_16gf,
            'y_32gf': models.regnet_y_32gf,
            'x_400mf': models.regnet_x_400mf,
            'x_800mf': models.regnet_x_800mf,
            'x_1_6gf': models.regnet_x_1_6gf,
            'x_3_2gf': models.regnet_x_3_2gf,
            'x_8gf': models.regnet_x_8gf,
            'x_16gf': models.regnet_x_16gf,
            'x_32gf': models.regnet_x_32gf,
        }
        if self.type in builders:
            return builders[self.type](pretrained=True)
        raise ValueError(f'Unknown version {self.type} for '
                         f'{self.__class__.__name__.lower()}.')

    def create_truncated_model(self):
        """Return the model with `fc` replaced by an identity, along with
        the original `fc` layer."""
        model = self.create_model()
        last_layer = model.fc
        model.fc = torch.nn.Identity()
        return model, last_layer


class VisionTransformer(ImageNetBased):
    """Pretrained torchvision Vision Transformer.

    `type` is 'b_16', 'b_32', 'l_16' or 'l_32'.
    """

    def __init__(self, type='b_16', split='val'):
        self.type = type
        super().__init__(split)

    def create_model(self):
        """Instantiate the pretrained ViT selected by `type`.

        Raises
        ------
        ValueError
            If the version is unknown.

        """
        # Map to the builders themselves, NOT to built models: only the
        # requested network must be instantiated (and its weights loaded).
        builders = {
            'b_16': models.vit_b_16,
            'b_32': models.vit_b_32,
            'l_16': models.vit_l_16,
            'l_32': models.vit_l_32,
        }
        if self.type in builders:
            return builders[self.type](pretrained=True)
        raise ValueError(f'Unknown version {self.type} for '
                         f'{self.__class__.__name__.lower()}.')

    def create_truncated_model(self):
        """Return the model with `heads.head` replaced by an identity,
        along with the original `heads.head` layer."""
        model = self.create_model()
        last_layer = model.heads.head
        model.heads.head = torch.nn.Identity()
        return model, last_layer


class ConvNeXt(ImageNetBased):
    """Pretrained torchvision ConvNeXt.

    `type` is 'tiny', 'small', 'base' or 'large'.
    """

    def __init__(self, type='tiny', split='val'):
        self.type = type
        super().__init__(split)

    def create_model(self):
        """Instantiate the pretrained ConvNeXt selected by `type`.

        Raises
        ------
        ValueError
            If the version is unknown.

        """
        # Map to the builders themselves, NOT to built models: only the
        # requested network must be instantiated (and its weights loaded).
        builders = {
            'tiny': models.convnext_tiny,
            'small': models.convnext_small,
            'base': models.convnext_base,
            'large': models.convnext_large,
        }
        if self.type in builders:
            return builders[self.type](pretrained=True)
        raise ValueError(f'Unknown version {self.type} for '
                         f'{self.__class__.__name__.lower()}.')

    def create_truncated_model(self):
        """Return the model with `classifier[2]` replaced by an identity,
        along with the original layer."""
        model = self.create_model()
        last_layer = model.classifier[2]
        model.classifier[2] = torch.nn.Identity()
        return model, last_layer


@dataclasses.dataclass(frozen=True)
class PreprocessImages:
    """Code taken from https://github.com/google-research/vision_transformer

    Resizes images and sets value range to [-1, 1].

    This class can be used to preprocess a sequence of images into a numpy
    array (by calling `__call__()`), or as part of a TensorFlow
    preprocessing graph (via the method `preprocess_tf()`).

    Attributes:
      size: Target size of images.
      crop: If set to true, then the image will first be resized maintaining
        the original aspect ratio, and then a central crop of that resized
        image will be returned.
    """
    size: int
    crop: bool = False

    def _resize_small(self, image):
        """Resize so that the smaller side equals `size`, keeping the
        aspect ratio."""
        height, width = tf.shape(image)[0], tf.shape(image)[1]

        # Scaling factor bringing the smaller side to the target size.
        smaller_side = tf.cast(tf.minimum(height, width), tf.float32)
        ratio = tf.cast(self.size, tf.float32) / smaller_side

        new_height = tf.cast(
            tf.round(tf.cast(height, tf.float32) * ratio), tf.int32)
        new_width = tf.cast(
            tf.round(tf.cast(width, tf.float32) * ratio), tf.int32)

        return tf.image.resize(image, (new_height, new_width),
                               method='bilinear')

    def _crop(self, image):
        """Take the central `size` x `size` crop of the image."""
        target_h, target_w = self.size, self.size
        offset_y = (tf.shape(image)[0] - target_h) // 2
        offset_x = (tf.shape(image)[1] - target_w) // 2
        return tf.image.crop_to_bounding_box(image, offset_y, offset_x,
                                             target_h, target_w)

    def _resize(self, image):
        """Resize to `size` x `size` regardless of the aspect ratio."""
        target = [self.size, self.size]
        return tf.image.resize(image, size=target, method='bilinear')

    def _value_range(self, image):
        """Map values from [0, 255] to [-1, 1]."""
        unit = tf.cast(image, tf.float32) / 255
        return -1 + unit * 2

    def preprocess_tf(self, image):
        """Resizes a single image as part of a TensorFlow graph."""
        assert image.dtype == tf.uint8
        if not self.crop:
            resized = self._resize(image)
        else:
            resized = self._crop(self._resize_small(image))
        # Resizing yields floats: go back to uint8 before rescaling.
        resized = tf.cast(resized, tf.uint8)
        return self._value_range(resized)

    def __call__(self, images):
        """Resizes a sequence of images, returns a numpy array."""
        processed = []
        for image in images:
            processed.append(self.preprocess_tf(tf.constant(image)))
        return np.stack(processed)


class MLPMixerTransform(object):
    """Prepare an image tensor for the MLP-Mixer models.

    Rescales the values to [-1, 1] (assuming the input lies in [0, 1], as
    produced by `ToTensor`) and moves the channels to the last dimension.
    """

    def __call__(self, img):
        """Return the rescaled image in channels-last layout.

        Parameters
        ----------
        img : torch.Tensor of shape (C, H, W)

        Returns
        -------
        torch.Tensor of shape (H, W, C)

        """
        img = 2*img - 1
        # (C, H, W) -> (H, W, C). Swapping dims 0 and 2 with transpose(0, 2)
        # would give (W, H, C), i.e. a spatially transposed image.
        img = img.permute(1, 2, 0)
        return img

    def __repr__(self):
        return 'mlpmixer: normalize to [-1, 1]'

class MLPMixer(ImageNetBased):
    """MLP-Mixer models loaded from TensorFlow Hub.

    `type` is 'b16', 'l16', 'sam_b16' or 'sam_b32'. The networks are Keras
    models: inputs and outputs are converted between PyTorch and
    TensorFlow tensors around each forward pass.
    """

    def __init__(self, type='b16', split='val'):
        self.type = type
        super().__init__(split)

    def get_transform(self):
        """Return the preprocessing: resize 256, center crop 224, tensor
        conversion and `MLPMixerTransform`."""
        steps = [T.Resize(256), T.CenterCrop(224), T.ToTensor(),
                 MLPMixerTransform()]
        return T.Compose(steps)

    def _pick_url(self, urls):
        """Return the url paired with `self.type` in `urls`.

        Raises ValueError if the version is unknown.
        """
        for version, url in urls:
            if self.type == version:
                return url
        raise ValueError(f'Unknown version {self.type} for '
                         f'{self.__class__.__name__.lower()}.')

    def create_model(self):
        """Load the classification model from TensorFlow Hub."""
        url = self._pick_url((
            ('b16',
             'https://tfhub.dev/sayakpaul/mixer_b16_i1k_classification/1'),
            ('l16',
             'https://tfhub.dev/sayakpaul/mixer_l16_i1k_classification/1'),
            ('sam_b16',
             'https://tfhub.dev/sayakpaul/mixer_b16_sam_classification/1'),
            ('sam_b32',
             'https://tfhub.dev/sayakpaul/mixer_b32_sam_classification/1'),
        ))
        return tf.keras.Sequential([hub.KerasLayer(url)])

    def create_truncated_model(self):
        """Load the feature-extractor ('fe') model from TensorFlow Hub.

        No separate last layer is available: None is returned in its place.
        """
        url = self._pick_url((
            ('b16', 'https://tfhub.dev/sayakpaul/mixer_b16_i1k_fe/1'),
            ('l16', 'https://tfhub.dev/sayakpaul/mixer_l16_i1k_fe/1'),
            ('sam_b16', 'https://tfhub.dev/sayakpaul/mixer_b16_sam_fe/1'),
            ('sam_b32', 'https://tfhub.dev/sayakpaul/mixer_b32_sam_fe/1'),
        ))
        return tf.keras.Sequential([hub.KerasLayer(url)]), None

    def _forward_tf(self, net_attr, input):
        """Run the Keras model stored in attribute `net_attr` on a PyTorch
        tensor and return element 0 of its output as a PyTorch tensor."""
        # PyTorch -> TensorFlow
        tf_input = tf.convert_to_tensor(input.numpy())
        tf_output = getattr(self, net_attr)(tf_input)[0]
        # TensorFlow -> PyTorch
        return torch.tensor(tf_output.numpy())

    def forward_truncated(self, input):
        """Forward `input` through the truncated model."""
        return self._forward_tf('truncated_model', input)

    def forward_whole(self, input):
        """Forward `input` through the whole model."""
        return self._forward_tf('model', input)