import torch
from torch.nn import TripletMarginLoss
from torch.nn.functional import cross_entropy

from sklearn.neighbors import NearestNeighbors

from collections import namedtuple

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from rdt.transformers import ClusterBasedNormalizer, OneHotEncoder

# One contiguous slice of a transformed column: how many matrix columns it
# occupies (``dim``) and the name of the activation associated with it
# (``activation_fn``; this file uses 'tanh' and 'softmax').
SpanInfo = namedtuple('SpanInfo', ('dim', 'activation_fn'))

# Everything recorded about one fitted column: its name, its kind
# ('continuous' or 'discrete'), the fitted transformer object, the list of
# ``SpanInfo`` describing its encoded layout, and the total encoded width.
ColumnTransformInfo = namedtuple(
    'ColumnTransformInfo',
    ('column_name', 'column_type', 'transform', 'output_info', 'output_dimensions'),
)


class DataTransformer(object):
    """Convert tabular data to and from a purely numeric matrix.

    Continuous columns are fitted with a ``ClusterBasedNormalizer`` and encoded as one
    normalized scalar followed by a one-hot indicator of the selected component.
    Discrete columns are encoded with a ``OneHotEncoder``.
    """

    def __init__(self, max_clusters=10, weight_threshold=0.005):
        """Store the settings forwarded to every ``ClusterBasedNormalizer``.

        Args:
            max_clusters (int):
                Upper bound on the number of clusters fitted per continuous column.
            weight_threshold (float):
                Weight threshold passed to the normalizer.
        """
        self._max_clusters = max_clusters
        self._weight_threshold = weight_threshold

    def _fit_continuous(self, data):
        """Fit a ``ClusterBasedNormalizer`` on a single continuous column.

        Args:
            data (pd.DataFrame):
                A dataframe holding exactly the column to fit.

        Returns:
            namedtuple:
                A ``ColumnTransformInfo`` whose layout is one 'tanh' span of width 1
                followed by one 'softmax' span with one slot per valid component.
        """
        name = data.columns[0]
        normalizer = ClusterBasedNormalizer(
            missing_value_generation='from_column',
            # Never request more clusters than there are rows.
            max_clusters=min(len(data), self._max_clusters),
            weight_threshold=self._weight_threshold,
        )
        normalizer.fit(data, name)
        n_components = sum(normalizer.valid_component_indicator)
        spans = [SpanInfo(1, 'tanh'), SpanInfo(n_components, 'softmax')]

        return ColumnTransformInfo(
            column_name=name,
            column_type='continuous',
            transform=normalizer,
            output_info=spans,
            output_dimensions=1 + n_components,
        )

    def _fit_discrete(self, data):
        """Fit a ``OneHotEncoder`` on a single discrete column.

        Args:
            data (pd.DataFrame):
                A dataframe holding exactly the column to fit.

        Returns:
            namedtuple:
                A ``ColumnTransformInfo`` with a single 'softmax' span whose width is
                the number of entries in the encoder's ``dummies``.
        """
        name = data.columns[0]
        encoder = OneHotEncoder()
        encoder.fit(data, name)
        n_categories = len(encoder.dummies)

        return ColumnTransformInfo(
            column_name=name,
            column_type='discrete',
            transform=encoder,
            output_info=[SpanInfo(n_categories, 'softmax')],
            output_dimensions=n_categories,
        )

    def fit(self, raw_data, discrete_columns=()):
        """Fit one transformer per column and record the encoded layout.

        Columns listed in ``discrete_columns`` get a ``OneHotEncoder``; every other
        column gets a ``ClusterBasedNormalizer``. The per-column span information is
        collected in ``output_info_list`` and the total encoded width in
        ``output_dimensions``.

        Args:
            raw_data (pd.DataFrame or array-like with a 2-D ``shape``):
                Data to fit on.
            discrete_columns (iterable):
                Names (or, for non-dataframe input, positions) of discrete columns.
        """
        self.output_info_list = []
        self.output_dimensions = 0
        # Remember the input flavour so ``inverse_transform`` can hand back the same one.
        self.dataframe = isinstance(raw_data, pd.DataFrame)

        if not self.dataframe:
            # Work around RDT issue #328 (fitting with numerical column names fails):
            # positions become string column names.
            discrete_columns = [str(column) for column in discrete_columns]
            positional_names = [str(position) for position in range(raw_data.shape[1])]
            raw_data = pd.DataFrame(raw_data, columns=positional_names)

        self._column_raw_dtypes = raw_data.infer_objects().dtypes
        self._column_transform_info_list = []
        for column_name in raw_data.columns:
            single_column = raw_data[[column_name]]
            if column_name in discrete_columns:
                info = self._fit_discrete(single_column)
            else:
                info = self._fit_continuous(single_column)

            self.output_info_list.append(info.output_info)
            self.output_dimensions += info.output_dimensions
            self._column_transform_info_list.append(info)

    def _transform_continuous(self, column_transform_info, data):
        """Encode one continuous column as ``[normalized, one-hot component]``.

        Args:
            column_transform_info (namedtuple):
                The ``ColumnTransformInfo`` fitted for this column.
            data (pd.DataFrame):
                A dataframe holding exactly the column to encode.

        Returns:
            np.ndarray:
                Array with one row per input row and ``output_dimensions`` columns.
        """
        name = data.columns[0]
        data = data.assign(**{name: data[name].to_numpy().flatten()})
        encoded = column_transform_info.transform.transform(data)

        # Column 0 carries the '<name>.normalized' value unchanged; the label in
        # '<name>.component' is expanded into a one-hot vector starting at column 1.
        matrix = np.zeros((len(encoded), column_transform_info.output_dimensions))
        matrix[:, 0] = encoded[f'{name}.normalized'].to_numpy()
        components = encoded[f'{name}.component'].to_numpy().astype(int)
        matrix[np.arange(components.size), components + 1] = 1.0

        return matrix

    def _transform_discrete(self, column_transform_info, data):
        """Encode one discrete column with its fitted encoder and return a numpy array."""
        encoder = column_transform_info.transform
        encoded = encoder.transform(data)
        return encoded.to_numpy()

    def _synchronous_transform(self, raw_data, column_transform_info_list):
        """Encode the columns of ``raw_data`` one after another.

        Returns a list with one numpy array per column, in the order of
        ``column_transform_info_list``.
        """
        encoded_columns = []
        for info in column_transform_info_list:
            single_column = raw_data[[info.column_name]]
            if info.column_type == 'continuous':
                encoded = self._transform_continuous(info, single_column)
            else:
                encoded = self._transform_discrete(info, single_column)
            encoded_columns.append(encoded)

        return encoded_columns

    def _parallel_transform(self, raw_data, column_transform_info_list):
        """Encode the columns of ``raw_data`` through joblib ``Parallel``.

        Returns a list with one numpy array per column, in the order of
        ``column_transform_info_list``.
        """
        jobs = []
        for info in column_transform_info_list:
            single_column = raw_data[[info.column_name]]
            if info.column_type == 'continuous':
                worker = self._transform_continuous
            else:
                worker = self._transform_discrete
            jobs.append(delayed(worker)(info, single_column))

        return Parallel(n_jobs=-1)(jobs)

    def transform(self, raw_data):
        """Encode ``raw_data`` into a single float matrix.

        Non-dataframe input is wrapped in a dataframe whose columns are named by
        position ('0', '1', ...). Inputs shorter than 500 rows are encoded
        sequentially; longer ones go through the parallel path.
        """
        if not isinstance(raw_data, pd.DataFrame):
            positional_names = [str(position) for position in range(raw_data.shape[1])]
            raw_data = pd.DataFrame(raw_data, columns=positional_names)

        # Parallel dispatch only pays off for larger inputs; below the threshold
        # it would make the transformation slower.
        if raw_data.shape[0] < 500:
            encode = self._synchronous_transform
        else:
            encode = self._parallel_transform

        encoded_columns = encode(raw_data, self._column_transform_info_list)
        return np.concatenate(encoded_columns, axis=1).astype(float)

    def _inverse_transform_continuous(self, column_transform_info, column_data, sigmas, st):
        """Decode the matrix slice of one continuous column.

        Args:
            column_transform_info (namedtuple):
                The ``ColumnTransformInfo`` fitted for this column.
            column_data (np.ndarray):
                Slice of the encoded matrix belonging to this column.
            sigmas (indexable or None):
                When given, ``sigmas[st]`` is the scale of the Gaussian noise from
                which the normalized value is re-drawn before decoding.
            st (int):
                Offset of this column's first slot in the full matrix.
        """
        normalizer = column_transform_info.transform
        frame = pd.DataFrame(column_data[:, :2], columns=list(normalizer.get_output_sdtypes()))
        # Second column becomes the index of the strongest component indicator.
        frame[frame.columns[1]] = np.argmax(column_data[:, 1:], axis=1)
        if sigmas is not None:
            resampled = np.random.normal(frame.iloc[:, 0], sigmas[st])
            frame.iloc[:, 0] = resampled

        return normalizer.reverse_transform(frame)

    def _inverse_transform_discrete(self, column_transform_info, column_data):
        """Decode the matrix slice of one discrete column back to its original values."""
        encoder = column_transform_info.transform
        frame = pd.DataFrame(column_data, columns=list(encoder.get_output_sdtypes()))
        decoded = encoder.reverse_transform(frame)
        return decoded[column_transform_info.column_name]

    def inverse_transform(self, data, sigmas=None):
        """Decode an encoded matrix back into raw data.

        The result is a dataframe when ``fit`` received a dataframe and a numpy
        array otherwise; columns are cast to the dtypes recorded during ``fit``.
        """
        offset = 0
        decoded_columns = []
        names = []
        for info in self._column_transform_info_list:
            width = info.output_dimensions
            column_slice = data[:, offset:offset + width]
            if info.column_type == 'continuous':
                decoded = self._inverse_transform_continuous(info, column_slice, sigmas, offset)
            else:
                decoded = self._inverse_transform_discrete(info, column_slice)

            decoded_columns.append(decoded)
            names.append(info.column_name)
            offset += width

        stacked = np.column_stack(decoded_columns)
        recovered = pd.DataFrame(stacked, columns=names).astype(self._column_raw_dtypes)
        if self.dataframe:
            return recovered

        return recovered.to_numpy()

    def convert_column_name_value_to_id(self, column_name, value):
        """Locate ``value`` of ``column_name`` in the encoded layout.

        Returns:
            dict:
                ``discrete_column_id`` (number of discrete columns preceding the
                match), ``column_id`` (position of the column among all columns) and
                ``value_id`` (index of the active slot in the value's one-hot vector).

        Raises:
            ValueError:
                If the column is unknown or the value encodes to an all-zero vector.
        """
        discrete_counter = 0
        for column_id, info in enumerate(self._column_transform_info_list):
            if info.column_name == column_name:
                break
            if info.column_type == 'discrete':
                discrete_counter += 1
        else:
            raise ValueError(f"The column_name `{column_name}` doesn't exist in the data.")

        single_row = pd.DataFrame([value], columns=[info.column_name])
        one_hot = info.transform.transform(single_row).to_numpy()[0]
        if sum(one_hot) == 0:
            raise ValueError(f"The value `{value}` doesn't exist in the column `{column_name}`.")

        return {
            'discrete_column_id': discrete_counter,
            'column_id': column_id,
            'value_id': np.argmax(one_hot),
        }



def reparameterize(mu, log_var):
    """Draw a sample ``mu + eps * exp(0.5 * log_var)`` with ``eps`` standard normal.

    Args:
        mu (torch.Tensor):
            Location of the sample.
        log_var (torch.Tensor):
            Log-variance; it is halved and exponentiated to obtain the scale.

    Returns:
        torch.Tensor:
            The sample; noise is drawn with the shape and dtype of ``log_var``.
    """
    sigma = (log_var * 0.5).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma


def _loss_function_MMD(recon_x, x, sigmas, mean, std, output_info, factor, kernel_choice='rbf'):
    """Return the scaled reconstruction loss plus an MMD penalty on the latent codes.

    The reconstruction part walks over ``output_info`` span by span:

    * non-softmax spans contribute a Gaussian negative log-likelihood between
      ``x[:, st]`` and ``tanh(recon_x[:, st])`` with scale ``sigmas[st]``;
    * softmax spans contribute a summed cross-entropy between the logits in
      ``recon_x`` and the arg-max class of ``x``.

    The MMD part samples ``z = mean + eps * std`` and compares it with standard
    normal draws of the same shape using the chosen kernel.

    Args:
        recon_x (torch.Tensor):
            Decoder output, indexed as ``[row, column]``.
        x (torch.Tensor):
            Encoded target data with the same column layout as ``recon_x``.
        sigmas (torch.Tensor):
            Per-column scales, indexed by the column offset of each non-softmax span.
        mean (torch.Tensor):
            Latent means, one row per sample.
        std (torch.Tensor):
            Latent standard deviations (used directly as the scale of the noise).
        output_info (list):
            For every column, a list of objects exposing ``dim`` and ``activation_fn``.
        factor (float):
            Multiplier applied to the reconstruction loss.
        kernel_choice (str):
            ``"rbf"`` selects ``rbf_kernel``; any other value selects ``imq_kernel``.

    Returns:
        torch.Tensor:
            ``sum(reconstruction terms) * factor / batch_size + mmd``.
    """
    batch_size = x.size()[0]
    st = 0
    loss = []
    for column_info in output_info:
        for span_info in column_info:
            ed = st + span_info.dim
            if span_info.activation_fn != 'softmax':
                # Deliberately NOT named ``std``: rebinding the ``std`` parameter here
                # would make the latent sampling below use the last decoder sigma
                # instead of the latent standard deviation.
                sigma = sigmas[st]
                eq = x[:, st] - torch.tanh(recon_x[:, st])
                loss.append((eq ** 2 / 2 / (sigma ** 2)).sum())
                loss.append(torch.log(sigma) * batch_size)
            else:
                loss.append(cross_entropy(
                    recon_x[:, st:ed], torch.argmax(x[:, st:ed], dim=-1), reduction='sum'))
            st = ed

    # Every column of ``recon_x`` must have been consumed by exactly one span.
    assert st == recon_x.size()[1]

    eps = torch.randn_like(std)
    z = eps * std + mean
    n_samples = z.shape[0]

    if n_samples > 1:
        z_prior = torch.randn_like(z)
        kernel = rbf_kernel if kernel_choice == "rbf" else imq_kernel

        k_z = kernel(z, z)
        k_z_prior = kernel(z_prior, z_prior)
        k_cross = kernel(z, z_prior)

        # Within-sample terms exclude the diagonal (self-similarity), hence N * (N - 1).
        pair_count = (n_samples - 1) * n_samples
        mmd_z = (k_z - k_z.diag().diag()).sum() / pair_count
        mmd_z_prior = (k_z_prior - k_z_prior.diag().diag()).sum() / pair_count
        mmd_cross = k_cross.sum() / (n_samples ** 2)

        mmd_loss = mmd_z + mmd_z_prior - 2 * mmd_cross
    else:
        # With fewer than two samples the within-sample terms divide by zero and
        # would turn the whole loss into NaN; skip the penalty for such batches.
        mmd_loss = z.new_zeros(())

    return sum(loss) * factor / batch_size + mmd_loss

def imq_kernel(z1, z2, kernel_bandwidth=1.0, scales=(1.0,)):
    """Return the pairwise inverse multiquadric kernel between the rows of ``z1`` and ``z2``.

    For every scale ``s`` the term ``C / (C + ||a - b||^2)`` with
    ``C = s * 2 * latent_dim * kernel_bandwidth ** 2`` is computed and the terms are
    summed over all scales.

    Args:
        z1 (torch.Tensor):
            First batch, indexed as ``[row, latent_dim]``.
        z2 (torch.Tensor):
            Second batch with the same latent dimension.
        kernel_bandwidth (float):
            Bandwidth entering the base constant. Defaults to ``1.0``.
        scales (float or iterable of float):
            Multipliers of the base constant; a bare number is treated as a single
            scale. Defaults to ``(1.0,)``.

    Returns:
        torch.Tensor:
            Matrix of shape ``[len(z1), len(z2)]`` (the integer ``0`` if ``scales`` is
            empty).
    """
    # A bare number is not iterable; wrap it so the loop below always works.
    if isinstance(scales, (int, float)):
        scales = (scales,)

    latent_dim = z1.shape[1]
    c_base = 2.0 * latent_dim * kernel_bandwidth ** 2

    # The squared distances do not depend on the scale, so compute them once.
    squared_distances = torch.norm(z1.unsqueeze(1) - z2.unsqueeze(0), dim=-1) ** 2

    k = 0
    for scale in scales:
        c = scale * c_base
        k = k + c / (c + squared_distances)

    return k

def rbf_kernel(z1, z2):
    """Return the pairwise RBF kernel between the rows of ``z1`` and ``z2``.

    Each entry is ``exp(-||a - b||^2 / C)`` with ``C = 2 * latent_dim * bandwidth ** 2``
    and a fixed bandwidth of ``1.0``; the result has shape ``[len(z1), len(z2)]``.
    """
    bandwidth = 1.0
    scale = 2.0 * z1.shape[1] * bandwidth ** 2

    # Broadcasting (n, 1, d) against (1, m, d) yields every row-to-row difference.
    differences = z1.unsqueeze(1) - z2.unsqueeze(0)
    squared_distances = torch.norm(differences, dim=-1) ** 2

    return torch.exp(-squared_distances / scale)


def z_gen(embeddings, n_to_sample, metric='minkowski', interpolation_method='SMOTE',
          k_neighbors=5):
    """Generate synthetic latent points from existing embeddings.

    Args:
        embeddings (np.ndarray):
            Existing latent points, indexed as ``[point, latent_dim]``.
        n_to_sample (int):
            Number of synthetic points to generate.
        metric (str):
            Distance metric forwarded to ``NearestNeighbors``.
        interpolation_method (str):
            * ``'SMOTE'``: a random point on the segment between a random base point
              and one of its nearest neighbours.
            * ``'rectangle'``: ``a * neighbour + (1 - b) * base`` with two scalar
              uniform draws ``a`` and ``b`` shared by all samples.
            * ``'triangle'``: base point plus a rank-weighted sum of random steps
              towards each of its nearest neighbours (closer ranks weigh more; the
              weights sum to one).
            * anything else: standard normal draws that ignore ``embeddings`` except
              for the latent dimension.
        k_neighbors (int):
            Number of neighbours considered per base point. Defaults to ``5`` and is
            reduced automatically when fewer other points are available.

    Returns:
        np.ndarray or torch.Tensor:
            ``n_to_sample`` synthetic points. The neighbour-based methods return the
            result of numpy arithmetic on ``embeddings``; the fallback returns a
            torch tensor.

    Raises:
        ValueError:
            If a neighbour-based method is requested but no neighbour is available
            (fewer than two embeddings, or ``k_neighbors < 1``).
    """
    if interpolation_method not in ('SMOTE', 'rectangle', 'triangle'):
        # Prior sampling needs no neighbour search, so return before fitting one.
        # Draw directly from N(0, 1): routing a tensor of ones through
        # ``reparameterize`` would treat it as a log-variance and inflate the
        # standard deviation to exp(0.5).
        return torch.randn(n_to_sample, embeddings.shape[1])

    n_points = len(embeddings)
    usable_neighbors = min(k_neighbors, n_points - 1)
    if usable_neighbors < 1:
        raise ValueError(
            'Neighbour-based interpolation needs at least two embeddings and '
            f'k_neighbors >= 1 (got {n_points} embeddings, k_neighbors={k_neighbors}).'
        )

    # +1 because each query point is part of the fitted data, so one of the
    # returned slots is expected to be the point itself; slot 0 is skipped below.
    n_neighbors = usable_neighbors + 1
    knn = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=1, metric=metric)
    knn.fit(embeddings)
    _, ind = knn.kneighbors(embeddings)

    # Pick, for every sample, a base point and which of its neighbours to use.
    base_indices = np.random.choice(n_points, n_to_sample)
    neighbor_indices = np.random.choice(np.arange(1, n_neighbors), n_to_sample)

    embeddings_base = embeddings[base_indices]

    if interpolation_method == 'SMOTE':
        embeddings_neighbor = embeddings[ind[base_indices, neighbor_indices]]
        deviations = np.random.rand(n_to_sample, 1) * (embeddings_neighbor - embeddings_base)
        embeddings_samples = embeddings_base + deviations

    elif interpolation_method == 'rectangle':
        embeddings_neighbor = embeddings[ind[base_indices, neighbor_indices]]
        # NOTE(review): the two coefficients come from independent scalar draws that
        # are shared by every sample, so they do not sum to one - confirm intended.
        neighbor_weight = np.random.uniform()
        base_weight = 1 - np.random.uniform()
        embeddings_samples = neighbor_weight * embeddings_neighbor + base_weight * embeddings_base

    else:  # 'triangle'
        deviations = 0
        weight_total = n_neighbors * (n_neighbors - 1) / 2
        for rank in range(1, n_neighbors):
            embeddings_neighbor = embeddings[ind[base_indices, rank]]
            weight = (n_neighbors - rank) / weight_total
            step = np.random.rand(n_to_sample, 1) * (embeddings_neighbor - embeddings_base)
            deviations = deviations + weight * step
        embeddings_samples = embeddings_base + deviations

    return embeddings_samples


def triplet_loss_margin(mean, batch_labels, factor, margin):
    """Mine one triplet per eligible anchor and return the scaled triplet margin loss.

    Every row of ``mean`` acts as an anchor. An anchor is skipped when the batch holds
    no other sample of its class or no sample of a different class. For the remaining
    anchors:

    * the positive is the same-class sample farthest from the anchor;
    * the negative is drawn at random from the other-class samples whose distance lies
      strictly between the distance to the *closest* positive and that distance plus
      ``margin``; if there is none, the closest other-class sample is used.

    Args:
        mean (torch.Tensor):
            Latent means, one row per sample; used as the embeddings.
        batch_labels (torch.Tensor):
            Class label of every row of ``mean``.
        factor (float):
            Multiplier applied to the loss.
        margin (float):
            Margin of the triplet loss and width of the semi-hard band.

    Returns:
        torch.Tensor:
            ``factor`` times the ``TripletMarginLoss`` over the mined triplets, or a
            zero scalar on ``mean.device`` when no triplet could be formed.
    """
    criterion = TripletMarginLoss(margin=margin, p=2)

    # Distances only steer the mining, so they are kept out of the autograd graph.
    with torch.no_grad():
        pairwise = torch.cdist(mean, mean, p=2)

    mined = []
    for anchor_idx in range(len(batch_labels)):
        anchor_label = batch_labels[anchor_idx].item()

        same_class = (batch_labels == anchor_label).nonzero(as_tuple=True)[0]
        same_class = same_class[same_class != anchor_idx]  # the anchor is not its own positive
        other_class = (batch_labels != anchor_label).nonzero(as_tuple=True)[0]
        if len(same_class) == 0 or len(other_class) == 0:
            continue

        anchor_row = pairwise[anchor_idx]
        positive_dists = anchor_row[same_class]
        negative_dists = anchor_row[other_class]

        # Semi-hard band: farther than the closest positive, but by less than ``margin``.
        closest_positive = positive_dists.min()
        in_band = (negative_dists > closest_positive) & (negative_dists < closest_positive + margin)
        semi_hard = other_class[in_band]

        hardest_positive = same_class[torch.argmax(positive_dists)]

        if len(semi_hard) > 0:
            pick = torch.randint(len(semi_hard), (1,)).item()
            chosen_negative = semi_hard[pick]
        else:
            chosen_negative = other_class[torch.argmin(negative_dists)]

        mined.append((mean[anchor_idx], mean[hardest_positive], mean[chosen_negative]))

    if not mined:
        return torch.tensor(0.0, device=mean.device)

    anchors, positives, negatives = (torch.stack(group) for group in zip(*mined))
    return factor * criterion(anchors, positives, negatives)