#!/usr/bin/env python
# -*- coding: utf-8 -*-

from os.path import join

import numpy as np
import torch
import torch.utils.data as data

from utils.shift_simulate import *

def normalize(data):
    """Min-max scale every slice along the first dimension of ``data`` into [0, 1].

    Each ``data[i]`` is flattened, shifted by its own minimum and divided by its
    own range; the output has the same shape as the input. Slices whose values
    are all equal (zero range) map to all zeros instead of 0/0 = NaN.

    Args:
        data: tensor with at least one dimension; normalization statistics are
            computed independently for each index of the first dimension.

    Returns:
        Tensor with the same shape as ``data``.
    """
    shape = data.shape

    # reshape (unlike view) also accepts non-contiguous inputs such as transposes.
    flat = data.reshape(shape[0], -1)
    min_value = flat.min(1, keepdim=True)[0]
    max_value = flat.max(1, keepdim=True)[0]
    value_range = max_value - min_value
    # Constant slices have zero range; divide by 1 there so they become 0, not NaN.
    value_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)

    return ((flat - min_value) / value_range).reshape(shape)


class ToyTrainSet(data.Dataset):
    """Offline (training) toy dataset backed by a dict payload.

    ``cfgs`` is either the payload itself (recognised by having a ``'data'`` key)
    or a config whose ``'path'`` points to a payload previously written by
    :meth:`save`. The payload holds ``'data'`` (array-like features, e.g. a numpy
    array) and ``'label'``, an optional ``'is_cls'`` flag (default True), and the
    metadata keys in ``_CLS_KEYS`` or ``_REG_KEYS`` depending on that flag.
    Items are ``(features_as_float32_tensor, label, index)``.
    """

    # Metadata copied into ``self.info`` (and written back by ``save``) per task type.
    _CLS_KEYS = ('cls_num', 'dim', 'radius', 'centers', 'covs', 'priors', 'normalize')
    _REG_KEYS = ('cls_num', 'dim', 'amp', 'omega', 'phase', 'offset', 'label_noise')

    def __init__(self, cfgs, rng=None):

        if rng is None:
            rng = np.random.default_rng()
        self.rng = rng

        payload = cfgs if 'data' in cfgs else self._load(cfgs['path'])

        self.X_bak, self.label = payload['data'], payload['label']
        # as_tensor shares memory with numpy input (like from_numpy) but also
        # accepts tensors and other array-likes.
        self.data = torch.as_tensor(self.X_bak).float()
        # Honour the stored task type; hard-coding True made regression payloads
        # (including ones written by ``save``) fail with KeyError.
        self.is_cls = bool(payload.get('is_cls', True))

        keys = self._CLS_KEYS if self.is_cls else self._REG_KEYS
        self.info = {'is_cls': self.is_cls}
        self.info.update((key, payload[key]) for key in keys)

    @staticmethod
    def _load(path):
        """Load a payload written by ``save``.

        The payload contains numpy arrays, which ``torch.load`` refuses under the
        ``weights_only=True`` default of recent PyTorch releases, so full
        unpickling is requested explicitly. Releases without the argument raise
        TypeError and are retried without it.
        NOTE(review): full unpickling can execute arbitrary code -- only load trusted files.
        """
        try:
            return torch.load(path, weights_only=False)
        except TypeError:
            return torch.load(path)

    def __getitem__(self, index):
        return self.data[index], self.label[index], index

    def __len__(self):
        return len(self.data)

    def save(self, output):
        """Write raw features, labels and all ``self.info`` metadata to ``output`` with ``torch.save``."""
        torch.save({'data': self.X_bak, 'label': self.label, **self.info}, output)


class ToyTestSet(data.Dataset):
    """Online test stream whose distribution is set per time step ``t``.

    ``cfgs`` needs ``'shift'`` (``{'type': <expression naming a callable>,
    'kwargs': {...}}``) and ``'per_round_num'``. ``info`` is the metadata dict
    of a ``ToyTrainSet``; ``info['is_cls']`` selects the task. Both
    ``self[t]`` and ``get_eval(t)`` return ``(data, label, None)``.
    """

    def __init__(self, cfgs, info, rng=None):
        if rng is None:
            rng = np.random.default_rng()
        self.rng = rng

        self.info = info
        self.is_cls = info['is_cls']
        # NOTE(review): eval() resolves the shift-generator class named in the
        # config but executes any expression -- configs must be trusted.
        self.shift_gen = eval(cfgs['shift']['type'])(rng=self.rng,
                                                     **cfgs['shift']['kwargs'])
        self.batch_size = cfgs['per_round_num']
        self.eval_size = 10000
        self.cfgs = cfgs
        # Regression info carries no 'radius'; indexing it raised KeyError.
        self.radius = info.get('radius')

    def __getitem__(self, t):
        """Draw ``per_round_num`` samples for round ``t``."""
        return self._sample(t, self.batch_size)

    def get_eval(self, t):
        """Draw ``eval_size`` samples for round ``t``."""
        return self._sample(t, self.eval_size)

    def _sample(self, t, size):
        """Query the shift generator for round ``t`` and dispatch on the task type."""
        shift = self.shift_gen(t)
        if self.is_cls:
            data, label = self._sample_cls(shift, size)
        else:
            data, label = self._sample_reg(shift)
        return data, label, None

    def _sample_cls(self, priors, size):
        """Sample ``size`` points from the four-Gaussian mixture under ``priors``.

        ``round(size * priors[0])`` points go to group 0 (centers 0/1, ``covs[0]``)
        and the rest to group 1 (centers 2/3, ``covs[1]``). Within a group the
        first half (rounded down) comes from the even-indexed center and is
        labelled True; the remainder comes from the odd-indexed center and is
        labelled False. All draws use ``self.rng`` so a seeded rng is reproducible.
        """
        centers, covs = self.info['centers'], self.info['covs']
        n0 = int(round(size * priors[0]))
        groups = ((n0, centers[0], centers[1], covs[0]),
                  (size - n0, centers[2], centers[3], covs[1]))

        feats, labels = [], []
        for n, pos_center, neg_center, cov in groups:
            n_pos = n // 2
            feats.append(self.rng.multivariate_normal(pos_center, cov, size=n_pos))
            feats.append(self.rng.multivariate_normal(neg_center, cov, size=n - n_pos))
            labels.append(np.ones(n_pos, dtype=bool))
            labels.append(np.zeros(n - n_pos, dtype=bool))
        return np.concatenate(feats), np.concatenate(labels)

    def _sample_reg(self, window):
        """Noisy sinusoid on a fixed 10000-point grid over [-5, 5], sliced by ``window``.

        ``window[0]:window[1]`` are the slice bounds returned by the shift
        generator. Noise is drawn for the whole grid before slicing, so the amount
        of randomness consumed does not depend on the window.
        """
        info = self.info
        grid = np.linspace(-5., 5., 10000)
        target = info['amp'] * np.sin(info['omega'] * grid + info['phase']) + info['offset']
        target = target + info['label_noise'] * (self.rng.random(len(grid)) - 0.5)
        lo, hi = window[0], window[1]
        return grid[lo:hi], target[lo:hi]

def generate_offline_data(rng, train_num, dim, centers, covs, source_priors, radius, cls_num=2, output=None,
                          fname='offline_toy_data.pt'):
    """Sample an offline two-group Gaussian classification set and optionally save it.

    ``round(train_num * source_priors[0])`` points are drawn for group 0 (centers
    0/1 with ``covs[0]``) and the remainder for group 1 (centers 2/3 with
    ``covs[1]``). Within a group the first half (rounded down) comes from the
    even-indexed center and is labelled True; the rest comes from the
    odd-indexed center and is labelled False.

    Args:
        rng: random source exposing ``multivariate_normal`` (e.g. a
            ``np.random.Generator``) used for all sampling; ``None`` creates an
            unseeded ``np.random.default_rng()``.
        train_num: total number of samples.
        dim, cls_num, radius: metadata stored alongside the samples.
        centers: indexable holding four mean vectors.
        covs: indexable holding two covariance matrices.
        source_priors: ``source_priors[0]`` is the fraction of samples in group 0.
        output: directory to save ``fname`` into; nothing is written when ``None``.
        fname: file name used inside ``output``.

    Returns:
        The constructed ``ToyTrainSet``.
    """
    if rng is None:
        rng = np.random.default_rng()

    n0 = int(round(train_num * source_priors[0]))
    groups = ((n0, centers[0], centers[1], covs[0]),
              (train_num - n0, centers[2], centers[3], covs[1]))

    feats, labels = [], []
    for n, pos_center, neg_center, cov in groups:
        n_pos = n // 2
        # Sample from the provided rng (previously ignored in favour of global np.random).
        feats.append(rng.multivariate_normal(pos_center, cov, size=n_pos))
        feats.append(rng.multivariate_normal(neg_center, cov, size=n - n_pos))
        labels.append(np.ones(n_pos, dtype=bool))
        labels.append(np.zeros(n - n_pos, dtype=bool))
    features = np.concatenate(feats)
    label = np.concatenate(labels)

    # Number of positive samples, printed as a quick sanity check.
    print(int(label.sum()))

    train_cfgs = {
        'is_cls': True,
        'data': features,
        'label': label,
        'radius': radius,
        'centers': centers,
        'covs': covs,
        'priors': source_priors,
        'normalize': True,
        'dim': dim,
        'cls_num': cls_num,
    }

    train_set = ToyTrainSet(train_cfgs, rng=rng)

    if output is not None:
        train_set.save(join(output, fname))

    return train_set