import numpy as np
import torch


def pairwise_distances(X):
    """
    Compute the matrix of squared Euclidean distances between all rows of X.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, which is fast
    but subject to floating-point cancellation; the resulting round-off is
    cleaned up so the output is a valid squared-distance matrix.

    Args:
        X: array of shape [n, d].

    Returns:
        Array of shape [n, n] with dists[i, j] = ||X[i] - X[j]||^2. All
        entries are non-negative and the diagonal is exactly zero.
    """
    # Squared norm of each row, kept 2-D (shape [n, 1]) so it broadcasts
    # against its own transpose.
    sq_norms = np.sum(X * X, axis=1, keepdims=True)
    dists = sq_norms + sq_norms.T - 2 * X @ X.T
    # Cancellation can leave tiny negative entries (NaN under a later sqrt)
    # and a non-zero diagonal; both are pure round-off, so correct them.
    np.maximum(dists, 0, out=dists)
    np.fill_diagonal(dists, 0)
    return dists


class DatasetEmbDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset pairing per-dataset inputs with visual embeddings.

    Each item corresponds to one entry of ``data_names``. It holds:

    - the ``pe`` input and ``y`` loaded from ``{pe_path}/{dataset}_pe_3.pth``;
    - the selected embedding and its hyper-parameters loaded from
      ``{visual_path}/visual-method-{visual_method}_dataset-{dataset}_selected_emb.tar``.

    Args:
        data_names: dataset identifiers; the dataset has one item per name.
        pe_path: directory containing the ``*_pe_3.pth`` files.
        visual_path: directory containing the ``*_selected_emb.tar`` files.
        normalize: if True, standardize embeddings with ``y_mu`` / ``y_std``.
        y_mu, y_std: per-column embedding statistics. Any that is None is
            computed from the calibrated embeddings of all datasets.
        z_cali: embedding calibration method, see ``calibrate_z``.
        z_anchor: reference embedding for the 'orth' / 'max_trace' methods.
        require_zdist: if True, precompute each embedding's squared pairwise
            distance matrix and return it from ``__getitem__``.
        visual_method: method tag used in the embedding file names.
    """

    def __init__(self, data_names, pe_path, visual_path, normalize=False, y_mu=None, y_std=None,
                 z_cali='none', z_anchor=None, require_zdist=True, visual_method='TSNE'):
        super().__init__()
        self.inp_data = []
        self.out_data = []
        self.y = []
        self.block_size = []
        self.hps = []
        self.zdist = []

        self.require_zdist = require_zdist
        self.z_cali_method = z_cali

        for dataset in data_names:
            # NOTE(review): torch.load unpickles arbitrary objects, so only
            # point these paths at trusted files. Newer torch releases default
            # to weights_only=True, which may reject the stored numpy arrays.
            pe, y = torch.load(f'{pe_path}/{dataset}_pe_3.pth')
            self.inp_data.append(pe)
            self.y.append(y)
            selected_emb, hps = torch.load(
                f'{visual_path}/visual-method-{visual_method}_dataset-{dataset}_selected_emb.tar')
            self.out_data.append(self.calibrate_z(selected_emb, z_anchor=z_anchor))
            self.hps.append(hps)
            self.block_size.append(pe.shape[0])

        # Compute any statistic the caller did not supply (previously, passing
        # only one of them left the other as None and normalization crashed).
        if y_mu is None or y_std is None:
            all_emb = np.concatenate(self.out_data)
            if y_mu is None:
                y_mu = np.mean(all_emb, axis=0)
            if y_std is None:
                y_std = np.std(all_emb, axis=0)
        self.y_mu = y_mu
        self.y_std = y_std

        if normalize:
            # Standardize all embeddings jointly, then split back per dataset.
            # Assumes each embedding has as many rows as its pe array
            # (block_size comes from pe.shape[0]) -- TODO confirm.
            stacked = np.concatenate(self.out_data)
            stacked = (stacked - self.y_mu) / (self.y_std + 1e-9)
            split_indices = np.cumsum(self.block_size)[:-1]
            self.out_data = np.split(stacked, split_indices, axis=0)

        if require_zdist:
            self.zdist = [pairwise_distances(emb) for emb in self.out_data]

    def __len__(self):
        """Number of datasets (one item per entry of ``data_names``)."""
        return len(self.inp_data)

    def __getitem__(self, idx):
        """Return ``(pe, emb, zdist, idx)`` (or ``(pe, emb, idx)`` when
        ``require_zdist`` is False), with arrays converted to torch tensors."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = torch.from_numpy(self.inp_data[idx])
        sample_y = torch.from_numpy(self.out_data[idx])
        if self.require_zdist:
            return sample, sample_y, torch.from_numpy(self.zdist[idx]), idx
        return sample, sample_y, idx

    def calibrate_z(self, z, z_anchor, n_dim=2):
        """Calibrate embedding ``z`` according to ``self.z_cali_method``.

        Methods:
            'none': return ``z`` unchanged.
            'svd': project ``z`` onto its top ``n_dim`` singular directions,
                i.e. ``U[:, :n_dim] * S[:n_dim]``. Singular-vector signs are
                arbitrary.
            'orth': multiply ``z`` by the orthogonal matrix ``u @ vt`` from the
                SVD of the outer product of the column means of ``z`` and
                ``z_anchor``.
            'max_trace': orthogonal Procrustes. Multiply ``z`` by the
                orthogonal ``R`` minimizing ``||z @ R - z_anchor||_F``
                (reflections are allowed).

        Raises:
            ValueError: 'orth' or 'max_trace' requested without ``z_anchor``.
            NotImplementedError: unknown calibration method.
        """
        method = self.z_cali_method
        if method == 'none':
            return z
        if method == 'svd':
            # Thin SVD: only leading columns of U are used, so don't
            # materialize the full [n, n] U matrix.
            u, s, _ = np.linalg.svd(z, full_matrices=False)
            return u[:, :n_dim] * s[:n_dim]
        if method in ('orth', 'max_trace'):
            if z_anchor is None:
                raise ValueError(f'z_cali={method!r} requires z_anchor')
            if method == 'orth':
                m = np.outer(np.mean(z, axis=0), np.mean(z_anchor, axis=0))
            else:
                m = np.matmul(z.T, z_anchor)
            u, _, vt = np.linalg.svd(m)
            # argmax over orthogonal R of tr(R^T m) is u @ vt; for 'max_trace'
            # this is the Procrustes solution. (The old code used its
            # transpose, vt.T @ u.T, which rotates in the wrong direction.)
            return z @ (u @ vt)
        raise NotImplementedError(f'unknown z_cali method: {method!r}')


if __name__ == "__main__":
    # Smoke run: build the dataset for a single group from the default
    # local data directories.
    pe_dir = './prepare_data/pe/pe_data'
    visual_dir = './prepare_data/bo/res-2'
    ds = DatasetEmbDataset(
        data_names=['mnist_group1'],
        pe_path=pe_dir,
        visual_path=visual_dir,
    )

