import copy
import logging
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.datasets import make_circles
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

import utils
from data.misc import MultimodalDataset
from hyperparams.load import get_config

# Module-wide logger. NOTE(review): not referenced anywhere in the visible code.
logger = logging.getLogger('custom')
# Project configuration, loaded once at import time;
# `config.dirs['synthetic_data']` is read by load_synthetic_data.
config = get_config()


class SyntheticDataset(MultimodalDataset):
    """ Multimodal dataset for synthetic data with an extra per-sample
    tensor ``m``.

    Besides what ``MultimodalDataset`` stores, every sample carries ``m``
    (kept as a float tensor in ``self.s['m']``). The generators in this
    module use it to mark which cluster a sample came from when a class is
    made up of several clusters.
    """

    def __init__(self, x1, x2, y, m):
        """
        :param x1: first modality, forwarded to ``MultimodalDataset``
        :param x2: second modality, forwarded to ``MultimodalDataset``
        :param y: labels, forwarded to ``MultimodalDataset``
        :param m: per-sample cluster indicator, one entry per sample
        """
        super().__init__(x1, x2, y)
        self.s['m'] = torch.tensor(m, dtype=torch.float)
        # `self.s` and `self.len` are presumably set up by
        # MultimodalDataset.__init__ — confirm there.
        assert len(self.s['m']) == self.len, \
            f"m has {len(self.s['m'])} entries, expected {self.len}"

        self.color_dict = {}
        self.define_color_dict()

    def define_color_dict(self, version='colorbrewer'):
        """ Allows visualizing data with consistent colormap.

        Sets ``self.color_dict`` to ``{'all': palette}``, where ``palette``
        maps integer labels to hex colour strings and ``None`` to a fallback
        colour.

        :param version: 'trending', 'colorbrewer' (default) or 'custom'
        :raises ValueError: if ``version`` is not one of the above
        """
        # A fresh dict is built on every call, so callers may mutate the
        # returned palette without affecting later calls.
        palettes = {
            # https://coolors.co/palettes/trending
            # (keys were previously 0, 1, 3, 4, leaving label 2 without a
            # colour)
            'trending': {0: '#264653',
                         1: '#2A9D8F',
                         2: '#E9C46A',
                         3: '#F4A261',
                         None: '#E76F51'},
            # https://colorbrewer2.org/#type=qualitative&scheme=Dark2&n=8
            'colorbrewer': {0: '#1b9e77',
                            1: '#d95f02',
                            2: '#7570b3',
                            3: '#e7298a',
                            4: '#66a61e',
                            5: '#e6ab02',
                            6: '#a6761d',
                            None: '#666666'},
            # custom
            'custom': {0: '#FFA630',
                       1: '#588B8B',
                       None: '#3089ff'},
        }
        if version not in palettes:
            raise ValueError('Please reference existing color_dict')

        self.color_dict = {'all': palettes[version]}


def load_synthetic_data(mode, batch_size=64):
    """
    :param mode: split
    """
    # ----------- GET TENSORS -----------
    synthetic_dir = config.dirs['synthetic_data']
    data_dir = os.path.join(synthetic_dir, f'synthetic_data.pkl')

    if os.path.isfile(data_dir):
        print(f'\nUsing existing synthetic data from:\n{data_dir}\n')
        tensors = torch.load(data_dir)
    else:
        print('\nHave not found existing synthetic data, creating new one at:\n'
              f'{data_dir}\n')
        tensors = _create_tensors()
        os.makedirs(synthetic_dir, exist_ok=True)
        torch.save(tensors, data_dir)

    # ----------- CREATE DATASETS AND LOADERS -----------
    datasets = {}
    for k, v in tensors.items():
        x, y, m = v
        x1, x2 = x
        x1 = utils.to_torch(x1, dtype=torch.float)
        dtype = torch.int64 if len(x2.shape) == 1 else torch.float
        x2 = utils.to_torch(x2, dtype=dtype)
        datasets[k] = SyntheticDataset(x1=x1, x2=x2, y=y, m=m)

    msg = f'Mode must be "train" or "test", but is {mode}. ' \
          f'Note that there is no validation set for synthetic data.'
    assert mode in ['train', 'test'], msg

    dataset = datasets[mode]
    loader = DataLoader(dataset, batch_size,
                        shuffle=mode == 'train',
                        pin_memory=True)

    return dataset, loader


def _create_tensors():
    """ Build the train/test arrays that populate a synthetic dataset.

    Generates the 'four_wheels' point cloud and rescales it to [-5, 5]. The
    points are paired with their own labels, which act as the second
    modality. The result is then split into train and test (which also
    plots the training split).

    :return: dict mapping 'train'/'test' to ``([x1, x2], y, m)``
    """
    points, labels, clusters = _create_modality('four_wheels', relabel=True)
    # the labels double as the second modality
    modalities = [_scale(points), labels]
    return _split_tensors((modalities, labels, clusters))


def _split_tensors(tensors):
    """ Randomly split ``(x, y, m)`` 80/20 into 'train' and 'test'.

    ``x`` is a list with one array per modality. Every modality, ``y`` and
    ``m`` are indexed with the same row selection, so samples stay aligned.
    The training split is plotted before returning.

    :param tensors: tuple ``(x, y, m)``
    :return: ``{'train': (x_train, y_train, m_train),
                'test': (x_test, y_test, m_test)}``
    """
    modalities, labels, clusters = tensors

    row_ids = np.arange(labels.shape[0])
    train_rows, test_rows, labels_train, labels_test = train_test_split(
        row_ids, labels, test_size=0.2)

    def take(rows, split_labels):
        # pick the same rows from every modality and from the cluster ids
        return [mod[rows] for mod in modalities], split_labels, clusters[rows]

    splits = {'train': take(train_rows, labels_train),
              'test': take(test_rows, labels_test)}

    _plot(splits['train'], title='train')

    return splits


def _create_modality(name, **kwargs):
    """ Creates two-dimensional datapoints that can be used as a modality.
    :param name: name of element
    """
    # supplementary info about clusters (if there are multiple clusters per
    # class)
    m = None

    if name == 'four_wheels':
        x, y, m = _create_pinwheel_data(4, 2000, **kwargs)
    elif name == 'circles':
        x, y = make_circles(n_samples=8000, shuffle=True, noise=0.05,
                            random_state=None,
                            factor=0.4)
    elif name == 'gauss_in_rec':
        x, y = _create_gaussian_in_rec(num_samples=4000)
    else:
        raise NotImplementedError

    return x, y, m


def _create_pinwheel_data(num_classes, num_per_class, radial_std=0.3,
                          tangential_std=0.05, rate=0.25, relabel=True):
    """ Creates four "pinwheels" around origin, which are alternately
    labeled. All elements are slightly shifted away from the origin.
    """

    def relabel_function(y, num_classes):
        """ Alternately label classes. """
        new_y = 0
        for idx, cur_y in enumerate(np.unique(y)):
            if idx == num_classes:
                new_y = 0
            idx = (y == cur_y)
            y[idx] = new_y
            new_y += 1
        return y

    # the following is inspired by
    # https://github.com/mattjj/svae/blob/master/experiments/gmm_svae_synth.py#L11
    rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False)

    features = np.random.randn(num_classes * num_per_class, 2) * np.array(
        [radial_std, tangential_std])
    features[:, 0] += 1.
    labels = np.repeat(np.arange(num_classes), num_per_class)

    angles = rads[labels] + rate * np.exp(features[:, 0])
    rotations = np.stack(
        [np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)])
    rotations = np.reshape(rotations.T, (-1, 2, 2))

    idx = np.random.permutation(np.arange(num_classes * num_per_class))
    x = 10 * np.einsum('ti,tij->tj', features, rotations)[idx]
    y = labels[idx]

    # move clusters further apart
    x[y == 0] += [0, -10]  # south
    x[y == 1] += [-10, 0]  # west
    x[y == 2] += [0, 10]  # north
    x[y == 3] += [10, 0]  # east

    m = copy.deepcopy(y)  # indicates clusters within classes
    if relabel:
        y = relabel_function(y, num_classes=2)

    return x, y, m


def _create_gaussian_in_rec(num_samples):
    """ Creates datapoints that resemble a gaussian distribution surrounded
    by a rectangle in two-dimensional space. """
    x = []
    shape = (int(num_samples / 5), 2)

    # ----------- OUTER RECTANGLE -----------
    # left
    cur_x = np.random.normal(0, 0.5, shape)
    cur_x[:, 0] *= 0.1
    cur_x[:, 0] -= 1
    x.append(cur_x)

    # right
    cur_x = np.random.normal(0, 0.5, shape)
    cur_x[:, 0] *= 0.1
    cur_x[:, 0] += 1
    x.append(cur_x)

    # bottom
    cur_x = np.random.normal(0, 0.5, shape)
    cur_x[:, 1] *= 0.1
    cur_x[:, 1] -= 1
    x.append(cur_x)

    # top
    cur_x = np.random.normal(0, 0.5, shape)
    cur_x[:, 1] *= 0.1
    cur_x[:, 1] += 1
    x.append(cur_x)

    # ----------- CENTER -----------
    cur_x = np.random.normal(0, 0.1, shape)
    x.append(cur_x)

    x = np.concatenate(x)
    y = np.zeros(x.shape[0])
    return x, y


def _scale(x, t=5):
    """ Scale dataset to [-t, t], i.e. avoid large positive or negative
    numbers. """
    if x is None:
        return
    else:
        max_value = max(np.abs(x.max()), np.abs(x.min()))
        return (x / max_value) * t


def _plot(tensors, title=None):
    """ Scatter-plot a split with one colour per label and show the figures.

    The first modality is always plotted. The second is plotted only if it
    is point-like (more than one dimension) rather than a plain label vector.

    :param tensors: tuple ``(x, y, m)`` as produced by ``_split_tensors``;
        ``x`` is a list of one or two modalities, ``m`` is ignored
    :param title: optional prefix for the figure titles
    :raises ValueError: if ``x`` does not hold one or two modalities
    """
    colors = ['b', 'g', 'r', 'm']
    x, y, _ = tensors
    if len(x) == 1:
        x1, x2 = x[0], None
    elif len(x) == 2:
        x1, x2 = x
    else:
        raise ValueError('currently only two modalities supported')

    # Cycle through the palette so labels beyond the fourth are still drawn;
    # zip(labels, colors) would silently drop them.
    label_colors = [(cur_y, colors[i % len(colors)])
                    for i, cur_y in enumerate(np.unique(y))]

    # modality x1
    if title:
        plt.title(f'{title}_modality_x1')
    for cur_y, c in label_colors:
        idx = (y == cur_y)
        plt.scatter(x1[idx][:, 0], x1[idx][:, 1], color=c)
    plt.show()

    # modality x2
    if x2 is not None and len(x2.shape) > 1:
        if title:
            plt.title(f'{title}_modality_x2')
        for cur_y, c in label_colors:
            idx = (y == cur_y)
            plt.scatter(x2[idx][:, 0], x2[idx][:, 1], s=200, marker='X',
                        color=c)
        # only show when something was actually drawn for x2
        plt.show()


if __name__ == '__main__':
    # Smoke run: load (or generate and cache) the data and build the
    # training split and its loader.
    train_set, train_loader = load_synthetic_data(mode='train',
                                                  batch_size=64)
