from __future__ import annotations
from typing import Tuple
from src.source_target_dataset import ImageDataset
from src.distance import W2_dist
from src.cityscapes_tools import CLASSES
import numpy as np
import ot
from geomloss import SamplesLoss
import torch
from torch.autograd import Variable
from torch.nn import MSELoss
from sklearn.cluster import KMeans
from typing import Optional
from sklearn.utils import check_random_state
from sklearn.cluster import KMeans
from sklearn.neighbors import KNeighborsClassifier


class OTModel():
    """Wrapper around POT's ``ot.da.EMDTransport`` with a cluster-wise variant.

    ``fit`` stores a copy of the target samples and fits the EMD transport
    from source to target; ``transform`` maps samples through the fitted
    transport. ``cluster_forward`` replaces the sample-wise map by one mean
    shift per cluster of (x, T(x)) pairs.
    """

    def __init__(self):
        self._model = ot.da.EMDTransport()

    def fit(self, X: np.ndarray, Xt: np.ndarray) -> OTModel:
        """Fit the transport from source samples ``X`` to target samples ``Xt``.

        A copy of ``Xt`` is kept because ``transform`` passes it back to POT.

        Returns:
            ``self``, so calls can be chained (``T.fit(X, Y).transform(X)``).
        """
        self._target = Xt.copy()
        self._model.fit(X, Xt=Xt)
        return self

    def transform(self, X: np.ndarray) -> np.ndarray:
        """Map ``X`` through the fitted transport (``fit`` must be called first)."""
        return self._model.transform(Xs=X, Xt=self._target)  # type: ignore

    def cluster_forward(self, X: np.ndarray, n_clusters: int, Y: Optional[np.ndarray] = None,
                        rng=None) -> np.ndarray:
        """Transport ``X`` using one mean shift per cluster of (x, T(x)) pairs.

        ``X`` is first pushed through the fitted transport. The concatenated
        pairs are then clustered with KMeans. Finally, every row of ``X`` is
        translated by the difference between its cluster's transported mean
        and its cluster's source mean.

        Args:
            X: source samples to transport.
            n_clusters: number of KMeans clusters.
            Y: unused; kept only for backward compatibility (the fitted
                target is what ``transform`` uses).
            rng: seed / ``RandomState`` for KMeans, forwarded to
                ``check_random_state`` (``None`` -> numpy's global RNG).

        Returns:
            Array shaped like ``X`` with each cluster mean-shifted.
        """
        unconstrained_Z = self.transform(X)
        labels = self._pair_clustering(X, unconstrained_Z, n_clusters, rng=rng)
        return self._cluster_mean_transport(X, unconstrained_Z, labels)

    def _cluster_mean_transport(self, X: np.ndarray, Z: np.ndarray, labels: np.ndarray) -> np.ndarray:
        """Shift each cluster of ``X`` so its mean matches the same cluster's mean in ``Z``.

        For every label ``c``: ``out[labels == c] = X_c - mean(X_c) + mean(Z_c)``.
        ``X`` itself is not modified.
        """
        Z_clusters = X.copy()  # the final output of the cluster mean shift transport
        for cluster_idx in np.unique(labels):
            in_cluster = labels == cluster_idx  # computed once, reused for X, Z and output
            X_cluster = X[in_cluster]
            # C_z = C_x + mean_shift  (mean_shift = C_z_mu - C_x_mu)
            Z_clusters[in_cluster] = X_cluster - X_cluster.mean(axis=0) + Z[in_cluster].mean(axis=0)
        return Z_clusters

    @staticmethod
    def _pair_clustering(X: np.ndarray, Z: np.ndarray, n_clusters: int, rng=None) -> np.ndarray:
        """Cluster the column-wise concatenation ``[X, Z]`` with k-means++ KMeans.

        Returns the cluster label of every (x, z) row pair.
        """
        rng = check_random_state(rng)
        # Pairing X and Z so that clusters group both where a sample is and where it moves.
        XZ = np.concatenate((X, Z), 1)
        XZ_km = KMeans(n_clusters, init='k-means++', random_state=rng).fit(XZ)
        return XZ_km.predict(XZ)


# Adapted from https://github.com/inouye-lab/explaining-distribution-shifts/blob/master/notebooks/wisconsin-cancer-experiment.ipynb
def iterative_unconstrained_feature_transport(source: np.ndarray, target: np.ndarray, T: OTModel, n_features: int, excluded=None, feature_names=None) -> np.ndarray:
    """Greedily release features of a full OT map one at a time and report W2.

    ``T`` is fit from ``source`` to ``target`` and applied to ``source`` to get
    the unconstrained transported samples ``Z_OT``. Features are ranked by the
    column-wise norm of ``Z_OT - source`` (largest first). At step ``k`` the
    top ``k`` columns of a copy of ``source`` are replaced by their ``Z_OT``
    values. The step then prints ``W2_dist(target, Z)`` and the percentage of
    the initial ``W2_dist(source, target)`` that has been removed.

    Args:
        source: source samples, one row per sample.
        target: target samples with the same columns as ``source``.
        T: transport model exposing ``fit(X, Xt)`` (returning itself) and
            ``transform(X)``.
        n_features: number of columns to release; clamped to
            ``[0, source.shape[1]]``.
        excluded: optional column index or indices. Their displacement is
            zeroed before ranking, so they are ranked last. They can still
            be released if ``n_features`` reaches them.
        feature_names: optional per-column names for the log. If ``None`` or
            empty, names are derived from ``CLASSES``. Index ``i < 19`` becomes
            ``"<class> count"``. Higher indices alternate between
            ``"<class> avg. brightness"`` and ``"<class> stdev. brightness"``.

    Returns:
        Copy of ``source`` with the selected columns taken from the OT map.
    """
    n_classes = 19  # NOTE(review): presumably len(CLASSES) for Cityscapes — confirm

    def default_name(idx):
        # Layout implied by the indexing: [19 counts, then (avg, stdev) per class].
        if idx < n_classes:
            return CLASSES[idx] + " count"
        offset = idx - n_classes
        stat = " avg. brightness" if offset % 2 == 0 else " stdev. brightness"
        return CLASSES[offset // 2] + stat

    # `is None` / len() instead of truthiness so numpy arrays of names work.
    use_default_names = feature_names is None or len(feature_names) == 0

    Z_OT = T.fit(source, target).transform(source)
    W2_X_Y = W2_dist(source, target)  # distance before any transporting
    diff = Z_OT - source
    # Explicit checks: `if excluded:` skipped index 0 and raised on numpy arrays.
    if excluded is not None and np.size(excluded) > 0:
        diff[:, excluded] = 0
    argsorted_diff = np.linalg.norm(diff, axis=0).argsort()[::-1]  # feature-wise divergence, decreasing
    n_release = min(max(n_features, 0), source.shape[1])

    Z_constrained = source.copy()
    features = []
    for n_free, newest_free in enumerate(argsorted_diff[:n_release]):
        Z_constrained[:, newest_free] = Z_OT[:, newest_free]
        d_Z_Y = W2_dist(target, Z_constrained)
        percent_change = 100*(W2_X_Y - d_Z_Y) / W2_X_Y  # given as a %
        features.append(default_name(newest_free) if use_default_names else feature_names[newest_free])
        print(f'For {n_free+1} free ({features}),\n\tW2: {d_Z_Y:.3f},\tTotal shift explained: {percent_change:.1f}%')
    return Z_constrained

def iterative_mean_shift_transport(source: np.ndarray, target: np.ndarray, n_features: int, feature_names=None) -> Tuple[np.ndarray, np.ndarray]:
    """Greedily mean-shift features one at a time and report W2 after each step.

    Features are ranked by ``|mean(target) - mean(source)|`` (largest first).
    At step ``k`` the ``k``-th ranked column of a copy of ``source`` is shifted
    by that column's mean gap. The step then prints ``W2_dist(target, Z)``,
    the percentage of the initial distance removed, and the released features
    with their shifts (sorted by signed shift, largest first).

    Args:
        source: source samples, one row per sample.
        target: target samples with the same columns as ``source``.
        n_features: number of columns to shift; clamped to
            ``[0, source.shape[1]]``.
        feature_names: optional per-column names for the log. If ``None`` or
            empty, names are derived from ``CLASSES``. Index ``i < 19`` becomes
            ``"<class> count"``. Higher indices alternate between
            ``"<class> avg. brightness"`` and ``"<class> stdev. brightness"``.

    Returns:
        ``(Z_constrained, shift)``:
        - ``Z_constrained`` is the shifted copy of ``source``.
        - ``shift`` holds the applied per-column shift, and is zero for
          columns that were not released.
    """
    n_classes = 19  # NOTE(review): presumably len(CLASSES) for Cityscapes — confirm

    def default_name(idx):
        # Layout implied by the indexing: [19 counts, then (avg, stdev) per class].
        if idx < n_classes:
            return CLASSES[idx] + " count"
        offset = idx - n_classes
        stat = " avg. brightness" if offset % 2 == 0 else " stdev. brightness"
        return CLASSES[offset // 2] + stat

    # `is None` / len() instead of truthiness so numpy arrays of names work.
    use_default_names = feature_names is None or len(feature_names) == 0

    W2_X_Y = W2_dist(source, target)  # distance before any transporting
    source_means = source.mean(axis=0)
    target_means = target.mean(axis=0)
    mean_gap = target_means - source_means
    mean_diff_sort = np.argsort(np.abs(mean_gap))[::-1]  # decreasing |gap|
    n_release = min(max(n_features, 0), source.shape[1])

    free_dim_mask = np.zeros(source.shape[1], dtype=bool)  # True = feature has been shifted
    Z_constrained = source.copy()
    shift = np.zeros(source.shape[1])
    for n_free, newest_free in enumerate(mean_diff_sort[:n_release]):
        free_dim_mask[newest_free] = True
        Z_constrained[:, newest_free] += mean_gap[newest_free]
        shift[newest_free] = mean_gap[newest_free]
        d_Z_Y = W2_dist(target, Z_constrained)
        percent_change = 100*(W2_X_Y - d_Z_Y) / W2_X_Y  # given as a %

        free_dims = np.flatnonzero(free_dim_mask)
        shifted_by = mean_gap[free_dims]
        order = shifted_by.argsort()[::-1]  # decreasing signed shift
        feat_names = [default_name(ind) if use_default_names else feature_names[ind]
                      for ind in free_dims[order]]
        with np.printoptions(precision=2, suppress=True):
            # n_free is 0-based; after this step n_free + 1 features are free.
            print(f'For {n_free + 1} free features:\n',
                  f'\tFeatures: {feat_names} have been shifted by: {shifted_by[order]}\n',
                  f'\tW2: {d_Z_Y:.3f},\tTotal shift explained: {percent_change:.1f}%')
    return Z_constrained, shift


def group_feature_transport(
    source: np.ndarray,
    target: np.ndarray,
    source_groups: np.ndarray,
    target_groups: np.ndarray,
    n_features: int,
    lr=0.02,
    iters=90) -> np.ndarray:
    """Move source samples by gradient descent on a group-weighted Sinkhorn loss.

    Groups come from ``source_groups`` / ``target_groups``: column ``f``
    selects the rows whose value is 1. A group is kept only if it has at least
    one member in both source and target.

    Each iteration does the following:
    - Compute one Sinkhorn loss per kept group and one over the full data.
    - Normalise each loss by its initial value.
    - Reweight the losses with multiplicative-weights probabilities,
      updated by ``exp(0.01 * normalised loss)``.
    - Take a gradient step on the source samples directly.

    After the loop, only the ``n_features`` columns with the largest total
    displacement are kept; all other columns stay equal to ``source``.
    Requires CUDA.

    Returns:
        Copy of ``source`` with the selected columns replaced.
    """
    labels_s = torch.from_numpy(source_groups)
    labels_t = torch.from_numpy(target_groups)

    # Keep only groups with at least one member in BOTH source and target;
    # an empty side has no Sinkhorn loss to compute.
    features_keep = [
        f for f in range(labels_s.shape[1])
        if torch.sum(labels_t[:, f] == 1) >= 1 and torch.sum(labels_s[:, f] == 1) >= 1
    ]
    print(features_keep)
    labels_s = labels_s[:, features_keep]
    labels_t = labels_t[:, features_keep]
    n_groups = labels_s.shape[1]

    x_s = torch.from_numpy(source.copy()).cuda()
    x_t = torch.from_numpy(target.copy()).cuda()
    x_s.requires_grad = True

    loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=0.1, scaling=0.99)

    def group_losses():
        # One loss per kept group, plus the loss over the whole distribution (last entry).
        return torch.stack(
            [loss_fn(x_s[labels_s[:, f] == 1], x_t[labels_t[:, f] == 1]) for f in range(n_groups)]
            + [loss_fn(x_s, x_t)])

    initial_losses = group_losses().detach()

    # Uniform weights over the n_groups groups + the full distribution. Dividing by
    # n_groups (not n_groups + 1) gave inf -> NaN when no group survived the filter.
    adv_probs = torch.ones(n_groups + 1).cuda() / (n_groups + 1)
    adj = torch.zeros(n_groups + 1).float().cuda()
    n_source = labels_s.shape[0]
    for i in range(iters):
        losses = group_losses() / initial_losses
        # The weight update uses a detached copy, so the adjustment never mutates
        # the autograd-tracked `losses` (the old code aliased it and used `+=`).
        adjusted_loss = losses.detach()
        if torch.all(adj > 0):
            adjusted_loss = adjusted_loss + adj / torch.sqrt(torch.tensor(n_groups))
        adjusted_loss = adjusted_loss / adjusted_loss.sum()
        adv_probs = adv_probs * torch.exp(0.01 * adjusted_loss)
        adv_probs = adv_probs / adv_probs.sum()
        loss = losses @ adv_probs
        print(f"Iter {i} loss: {loss.item()}")
        [g] = torch.autograd.grad(loss, [x_s])
        # Step size scaled by the number of source samples.
        # NOTE(review): presumably offsets per-sample 1/N weights in the loss — confirm.
        with torch.no_grad():
            x_s -= lr * n_source * g

    x_s = x_s.detach().cpu().numpy()
    diff = x_s - source
    argsorted_diff = np.linalg.norm(diff, axis=0).argsort()[::-1]  # feature-wise divergence, decreasing
    free_dim_mask = np.zeros(source.shape[1], dtype=bool)  # True = feature may be transported
    free_dim_mask[argsorted_diff[:n_features]] = True

    source_t = source.copy()
    source_t[:, free_dim_mask] = x_s[:, free_dim_mask]
    return source_t


def group_feature_transport2(
    source: np.ndarray,
    target: np.ndarray,
    source_groups: np.ndarray,
    target_groups: np.ndarray,
    n_features: int,
    lr=0.02,
    iters=90, init_x_s=None, real_source_groups=None, real_target_groups=None) -> Tuple[np.ndarray, torch.Tensor]:
    """Move source samples by gradient descent on the worst-group Sinkhorn loss.

    Groups come from ``source_groups`` / ``target_groups``: column ``f``
    selects the rows whose value is 1. A group is kept only if it has at least
    one member in both source and target.

    Each iteration:
    - Computes one Sinkhorn loss per kept group.
    - Takes a gradient step on the source samples using the maximum of
      those losses.
    - Logs the loss, the gradient norm and the full-distribution loss.

    If both ``real_*_groups`` are given, the per-group losses for those
    (unfiltered) groupings are printed as a diagnostic.

    Optimisation starts from ``init_x_s`` if given, otherwise from ``source``.
    Afterwards, only the ``n_features`` columns with the largest displacement
    relative to ``source`` are kept. Requires CUDA.

    Returns:
        ``(source_t, worst)``:
        - ``source_t`` is a copy of ``source`` with the selected columns
          replaced.
        - ``worst`` is a 0-d tensor: the argmax of the per-group losses from
          the last iteration (evaluated before its update step). It indexes
          the *kept* groups, in ``features_keep`` order, not the original
          columns.
    """
    labels_s = torch.from_numpy(source_groups)
    labels_t = torch.from_numpy(target_groups)

    # Keep only groups with at least one member in BOTH source and target.
    features_keep = [
        f for f in range(labels_s.shape[1])
        if torch.sum(labels_t[:, f] == 1) >= 1 and torch.sum(labels_s[:, f] == 1) >= 1
    ]
    print(features_keep)
    labels_s = labels_s[:, features_keep]
    labels_t = labels_t[:, features_keep]

    start = source if init_x_s is None else init_x_s
    x_s = torch.from_numpy(start.copy()).cuda()
    x_t = torch.from_numpy(target.copy()).cuda()
    x_s.requires_grad = True

    loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=0.1, scaling=0.99)

    def per_group_losses(ls, lt):
        # One Sinkhorn loss per column of the label matrices.
        return torch.stack(
            [loss_fn(x_s[ls[:, f] == 1], x_t[lt[:, f] == 1]) for f in range(ls.shape[1])])

    # Diagnostic labels are loop-invariant: convert them once.
    report_real = real_source_groups is not None and real_target_groups is not None
    if report_real:
        real_labels_s = torch.from_numpy(real_source_groups)
        real_labels_t = torch.from_numpy(real_target_groups)

    n_source = labels_s.shape[0]
    losses = None
    for i in range(iters):
        losses = per_group_losses(labels_s, labels_t)
        loss = torch.max(losses)  # worst-group objective
        [g] = torch.autograd.grad(loss, [x_s])
        # Logging-only losses: no autograd graph needed.
        with torch.no_grad():
            full_loss = loss_fn(x_s, x_t).cpu().item()
        print(f"Iter {i} loss and grad: {loss.item()}, {torch.norm(g).item()}, {full_loss}")
        if report_real:
            with torch.no_grad():
                group_losses = per_group_losses(real_labels_s, real_labels_t)
            print(group_losses)
        with torch.no_grad():
            x_s -= lr * n_source * g

    if losses is None:
        # iters <= 0: nothing ran; report the worst group at the starting point
        # instead of failing on an unbound name.
        with torch.no_grad():
            losses = per_group_losses(labels_s, labels_t)

    x_s = x_s.detach().cpu().numpy()
    diff = x_s - source
    argsorted_diff = np.linalg.norm(diff, axis=0).argsort()[::-1]  # feature-wise divergence, decreasing
    free_dim_mask = np.zeros(source.shape[1], dtype=bool)  # True = feature may be transported
    free_dim_mask[argsorted_diff[:n_features]] = True

    source_t = source.copy()
    source_t[:, free_dim_mask] = x_s[:, free_dim_mask]

    return source_t, torch.argmax(losses)

def group_mean_shift_transport(
    source: np.ndarray,
    target: np.ndarray,
    source_groups: np.ndarray,
    target_groups: np.ndarray,
    n_features: int,
    lr=0.02,
    iters=90) -> Tuple[np.ndarray, np.ndarray]:
    """Learn one global shift vector that reduces group-wise Sinkhorn losses.

    Groups come from ``source_groups`` / ``target_groups``: column ``f``
    selects the rows whose value is 1. A group is kept only if it has at least
    one member in both source and target.

    A vector ``m`` (one entry per feature) is added to every source sample.
    It is optimised by gradient descent on the sum of the per-group and
    full-distribution Sinkhorn losses, each normalised by its initial value.
    Only the ``n_features`` columns with the largest displacement are
    kept. Requires CUDA.

    Returns:
        ``(source_t, shift)``:
        - ``source_t`` is ``source`` with the selected columns shifted by
          the final ``m``.
        - ``shift`` is the final ``m``, zeroed outside the selected columns.
    """
    labels_s = torch.from_numpy(source_groups)
    labels_t = torch.from_numpy(target_groups)
    # Keep only groups with at least one member in BOTH source and target.
    features_keep = [
        f for f in range(labels_s.shape[1])
        if torch.sum(labels_t[:, f] == 1) >= 1 and torch.sum(labels_s[:, f] == 1) >= 1
    ]
    labels_s = labels_s[:, features_keep]
    labels_t = labels_t[:, features_keep]
    n_groups = labels_s.shape[1]

    # Created directly on the GPU so it is a true leaf
    # (Variable(..., requires_grad=True).cuda() yields a non-leaf copy).
    m = torch.zeros(source.shape[1], device="cuda", requires_grad=True)
    loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=0.01, scaling=0.99)

    orig_source = torch.from_numpy(source.copy()).cuda()
    x_t = torch.from_numpy(target.copy()).cuda()

    def group_losses(xs):
        # One loss per kept group, plus the loss over the whole distribution (last entry).
        return torch.stack(
            [loss_fn(xs[labels_s[:, f] == 1], x_t[labels_t[:, f] == 1]) for f in range(n_groups)]
            + [loss_fn(xs, x_t)])

    initial_losses = group_losses(orig_source).detach()
    for t in range(iters):
        x_s = orig_source + m
        losses = group_losses(x_s) / initial_losses
        loss = torch.sum(losses)
        print(f"Iter {t} loss: {loss.item()}")
        [g] = torch.autograd.grad(loss, [m])
        with torch.no_grad():
            m -= lr * g

    # Rebuild x_s from the FINAL m. The loop's last x_s predates the last
    # update, which made source_t and shift disagree by one step.
    with torch.no_grad():
        x_s = (orig_source + m).cpu().numpy()
    diff = x_s - source
    argsorted_diff = np.linalg.norm(diff, axis=0).argsort()[::-1]  # feature-wise divergence, decreasing
    free_dim_mask = np.zeros(source.shape[1], dtype=bool)  # True = feature may be transported
    free_dim_mask[argsorted_diff[:n_features]] = True

    source_t = source.copy()
    source_t[:, free_dim_mask] = x_s[:, free_dim_mask]
    shift = m.detach().cpu().numpy()
    shift[~free_dim_mask] = 0
    return source_t, shift


def transform_samples(new_source, orig_source, transformed_source):
    """Apply to each new sample the displacement of its nearest original sample.

    The nearest neighbour is found among the rows of ``orig_source``. The
    displacement of original row ``j`` is
    ``transformed_source[j] - orig_source[j]``.

    Returns:
        ``new_source`` plus the borrowed displacement, row by row.
    """
    lookup = KNeighborsClassifier(n_neighbors=1)
    lookup.fit(orig_source, np.arange(orig_source.shape[0]))
    # kneighbors gives an (n_new, 1) matrix of row indices into orig_source.
    nearest_rows = lookup.kneighbors(new_source, return_distance=False)[:, 0]
    displacement = transformed_source - orig_source
    return new_source + displacement[nearest_rows, :]


def transform_samples_kmeans(new_source, centroids, shifts):
    """Shift each new sample by the shift of its nearest centroid.

    Row ``i`` of ``shifts`` is applied to every sample whose nearest row in
    ``centroids`` is row ``i``.
    """
    centroid_ids = np.arange(centroids.shape[0])
    lookup = KNeighborsClassifier(n_neighbors=1).fit(centroids, centroid_ids)
    nearest_centroid = lookup.kneighbors(new_source, return_distance=False)[:, 0]
    per_sample_shift = shifts[nearest_centroid, :]
    return new_source + per_sample_shift


def group_kmeans_shift_transport(
    source: np.ndarray,
    target: np.ndarray,
    source_groups: np.ndarray,
    target_groups: np.ndarray,
    n_features: int,
    clusters=4,
    lr=0.02,
    iters=90) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Learn one shift vector per KMeans cluster of ``source`` under a group-weighted loss.

    ``source`` is clustered with ``KMeans(n_clusters=clusters, random_state=0)``.
    Every sample receives its cluster's shift row ``m[c]``.

    Groups come from ``source_groups`` / ``target_groups``: column ``f``
    selects the rows whose value is 1. A group is kept only if it has at least
    one member in both source and target.

    The shifts are optimised by gradient descent on per-group plus
    full-distribution Sinkhorn losses, each normalised by its initial value.
    The losses are weighted with multiplicative-weights probabilities updated
    by ``exp(0.01 * normalised loss)``. Only the ``n_features`` columns with
    the largest displacement are kept. Requires CUDA.

    Returns:
        ``(source_t, centroids, shift)``:
        - ``source_t`` is ``source`` with the selected columns moved by the
          final shifts.
        - ``centroids`` are the KMeans cluster centres.
        - ``shift`` is the final ``(clusters, n_dims)`` shift matrix, zeroed
          outside the selected columns.
    """
    labels_s = torch.from_numpy(source_groups)
    labels_t = torch.from_numpy(target_groups)
    # Keep only groups with at least one member in BOTH source and target.
    features_keep = [
        f for f in range(labels_s.shape[1])
        if torch.sum(labels_t[:, f] == 1) >= 1 and torch.sum(labels_s[:, f] == 1) >= 1
    ]
    labels_s = labels_s[:, features_keep]
    labels_t = labels_t[:, features_keep]
    n_groups = labels_s.shape[1]

    kmeans = KMeans(n_clusters=clusters, random_state=0).fit(source)
    cluster_members = kmeans.labels_
    centroids = kmeans.cluster_centers_
    for c in range(clusters):
        print(f"Samples in cluster {c}: {np.sum(cluster_members == c)}")

    # One row for each cluster's shift (a leaf tensor optimised directly).
    m = torch.zeros((clusters, source.shape[1])).cuda()
    m.requires_grad = True
    loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=0.01, scaling=0.99)

    orig_source = torch.from_numpy(source.copy()).cuda()
    x_t = torch.from_numpy(target.copy()).cuda()
    # Per-sample cluster index: m[member_idx] gathers each sample's shift in one op.
    member_idx = torch.from_numpy(np.asarray(cluster_members, dtype=np.int64)).cuda()

    def group_losses(xs):
        # One loss per kept group, plus the loss over the whole distribution (last entry).
        return torch.stack(
            [loss_fn(xs[labels_s[:, f] == 1], x_t[labels_t[:, f] == 1]) for f in range(n_groups)]
            + [loss_fn(xs, x_t)])

    initial_losses = group_losses(orig_source).detach()

    # Uniform weights over the n_groups groups + the full distribution. Dividing by
    # n_groups (not n_groups + 1) gave inf -> NaN when no group survived the filter.
    adv_probs = torch.ones(n_groups + 1).cuda() / (n_groups + 1)
    adj = torch.zeros(n_groups + 1).float().cuda()
    for t in range(iters):
        x_s = orig_source + m[member_idx]
        losses = group_losses(x_s) / initial_losses
        # Detached copy so the adjustment never mutates the autograd-tracked `losses`.
        adjusted_loss = losses.detach()
        if torch.all(adj > 0):
            adjusted_loss = adjusted_loss + adj / torch.sqrt(torch.tensor(n_groups))
        adjusted_loss = adjusted_loss / adjusted_loss.sum()
        adv_probs = adv_probs * torch.exp(0.01 * adjusted_loss)
        adv_probs = adv_probs / adv_probs.sum()
        loss = losses @ adv_probs
        print(f"Iter {t} loss: {loss.item()}")
        [g] = torch.autograd.grad(loss, [m])
        with torch.no_grad():
            m -= lr * g

    # Rebuild x_s from the FINAL m. The loop's last x_s predates the last
    # update, which made source_t and shift disagree by one step.
    with torch.no_grad():
        x_s = (orig_source + m[member_idx]).cpu().numpy()
    diff = x_s - source
    argsorted_diff = np.linalg.norm(diff, axis=0).argsort()[::-1]  # feature-wise divergence, decreasing
    free_dim_mask = np.zeros(source.shape[1], dtype=bool)  # True = feature may be transported
    free_dim_mask[argsorted_diff[:n_features]] = True

    source_t = source.copy()
    source_t[:, free_dim_mask] = x_s[:, free_dim_mask]
    shift = m.detach().cpu().numpy()
    shift[:, ~free_dim_mask] = 0

    return source_t, centroids, shift