import torch
import torch.nn as nn
import pandas as pd
import re


def vectorize(samples, references):
    """Average the rows of ``samples`` that share the same reference value.

    :param samples: 2-D tensor with one row per sample
    :param references: tensor with one reference value per row of ``samples``
    :return: tuple ``(means, values, counts)`` where ``values`` are the sorted
        distinct references, ``counts`` how often each one occurs, and
        ``means[i]`` is the mean of the samples whose reference is ``values[i]``
    """
    ref_values, ref_counts = torch.unique(references, return_counts=True)

    # membership[i, j] is True when sample j carries the i-th distinct reference
    membership = references == ref_values.view(-1, 1)

    # (features x samples) @ (samples x groups) gives the per-group feature sums
    group_sums = samples.t() @ membership.t().float()
    group_means = (group_sums / ref_counts.float()).t()

    return group_means, ref_values, ref_counts


class SOMUnsupervised(nn.Module):
    '''
    Growing self-organizing map with per-dimension relevance, trained without labels.

    The map owns ``n_max`` node slots. Every piece of state (prototypes, relevance,
    moving averages, life counters, slot occupancy) is an ``nn.Parameter`` created
    with ``requires_grad=False`` and is updated in place by the methods below.
    '''

    def __init__(self, input_dim, n_max=20, lr0=0.1, lr=0.1, lr_push=1., at=0.985,
                 ds_beta=0.5, eps_ds=1., ld=0.1, gamma=3.0, device='cpu'):
        '''
        :param input_dim: number of features of each sample / prototype
        :param n_max: number of node slots in the map
        :param lr0: stored as a tensor in ``self.lr0``; not read by this class
        :param lr: learning rate of the winner; neighbors get ``lr`` scaled by the
                   neighborhood function
        :param lr_push: multiplier of the push term in ``update_node`` (``self.push``);
                        ``self.lr_push`` stores ``-lr * lr_push``
        :param at: activation threshold; winners at or above it are updated, samples
                   below it are candidates for new nodes
        :param ds_beta: factor multiplied with the learning rate in the moving-average update
        :param eps_ds: scales ``max - min`` in the denominator of the relevance sigmoid
        :param ld: amount subtracted from every node's life on each ``forward`` call
        :param gamma: divisor inside the exponential of the neighborhood function
        :param device: anything accepted by ``torch.device``; all state is created there
        '''

        super(SOMUnsupervised, self).__init__()
        self.input_size = input_dim

        self.n_max = n_max
        self.lr0 = torch.tensor(lr0)
        self.lr = lr
        self.push = lr_push
        self.lr_push = -lr * lr_push
        self.at = at
        self.ds_beta = ds_beta
        self.eps_ds = eps_ds
        self.ld = ld

        self.gamma = gamma

        self.device = torch.device(device)

        # node_control[i] is 1 while slot i holds a live node and 0 while it is free
        self.node_control = nn.Parameter(torch.zeros(n_max, device=self.device), requires_grad=False)
        self.life = nn.Parameter(torch.ones(n_max, device=self.device), requires_grad=False)

        self.weights = nn.Parameter(torch.zeros(n_max, input_dim, device=self.device), requires_grad=False)
        self.moving_avg = nn.Parameter(torch.zeros(n_max, input_dim, device=self.device), requires_grad=False)
        self.relevance = nn.Parameter(torch.ones(n_max, input_dim, device=self.device), requires_grad=False)

        self.map_center = nn.Parameter(torch.zeros(input_dim, device=self.device), requires_grad=False)

    def activation(self, w):
        '''
        Activation of every node slot for every sample.

        Computed as ``sum(relevance) / (sum(relevance) + weighted_distance + 1e-3)``.
        Free slots are NOT masked here; callers multiply by ``node_control``.

        :param w: samples; assumed to be (batch, input_dim)
        :return: one row per sample, one column per node slot
        '''
        dists = self.weighted_distance(w)
        relevance_sum = torch.sum(self.relevance, 1)

        return relevance_sum / (relevance_sum + dists + 1e-3)

    def weighted_distance(self, w, prototypes=None, relevance=None):
        '''
        Relevance-weighted squared distance between samples and prototypes.

        :param w: samples; assumed to be (batch, input_dim)
        :param prototypes: prototype rows; defaults to ``self.weights``
        :param relevance: per-prototype feature weights; defaults to ``self.relevance``
        :return: ``sum_d relevance[n, d] * (w[b, d] - prototypes[n, d]) ** 2`` for every
                 sample ``b`` and prototype ``n`` (no square root is taken)
        '''
        if prototypes is None:
            prototypes = self.weights

        if relevance is None:
            relevance = self.relevance

        squared_diff = torch.pow(w.unsqueeze(1) - prototypes.unsqueeze(0), 2)

        return torch.sum(relevance.unsqueeze(0) * squared_diff, -1)

    def neighborhood_function(self, neighbors):
        '''
        :param neighbors: tensor of neighbor positions (callers pass ``arange`` ranks)
        :return: ``exp(-neighbors / gamma)``
        '''
        return torch.exp(-neighbors / self.gamma)

    def add_node(self, new_samples):
        '''
        Insert rows of ``new_samples`` as new nodes into free slots of the map.

        When there are fewer free slots than rows, only the LAST rows of
        ``new_samples`` are inserted. New nodes start with life 1, relevance 1
        and a zero moving average.

        :param new_samples: candidate prototypes, one per row
        :return: ``(new_nodes_idx, create_idx)``: the slot indexes that were filled and
                 the row indexes of ``new_samples`` that were used
        '''
        # free slots of the map, as a column of indexes
        available_idx = (self.node_control == 0).nonzero()
        n_available = available_idx.size(0)

        # how many rows were requested and how many actually fit
        n_new = new_samples.size(0)
        n_create = min(n_available, n_new)

        # rows of new_samples that are inserted: the trailing n_create ones
        create_idx = torch.arange(n_new - n_create, n_new)
        new_nodes = new_samples[create_idx]

        n_new_nodes = new_nodes.size(0)
        new_nodes_idx = available_idx[:n_new_nodes].t()

        self.node_control[new_nodes_idx] = 1.
        self.life[new_nodes_idx] = 1.
        # detach so the stored prototypes never carry autograd history
        self.weights[new_nodes_idx] = new_nodes.detach()
        self.relevance[new_nodes_idx] = 1.
        self.moving_avg[new_nodes_idx] = 0.

        return new_nodes_idx.squeeze(-1), create_idx

    def update_node(self, w, index, lr):
        '''
        Move the nodes at ``index`` towards ``w`` and refresh their relevance.

        Steps, all in place:
          1. moving_avg <- lr * ds_beta * |w - weights| + (1 - lr * ds_beta) * moving_avg
          2. relevance  <- 1 / (1 + exp((moving_avg - mean) / (eps_ds * (max - min)))),
             with mean/max/min taken per node over its features
          3. weights    <- weights + lr * (w - weights)
          4. weights    <- weights - self.lr * self.push * clamp(1 - 2 * (map_center - weights), min=0)

        :param w: target samples, one row per updated node
        :param index: indexes of the nodes to update
        :param lr: learning rate, broadcastable against the selected rows
        '''
        distance = torch.abs(torch.sub(w, self.weights[index]))
        self.moving_avg[index] = (torch.mul(lr * self.ds_beta, distance) +
                                  torch.mul(1 - (lr * self.ds_beta), self.moving_avg[index]))

        maximum = torch.max(self.moving_avg[index], dim=1, keepdim=True)[0]
        minimum = torch.min(self.moving_avg[index], dim=1, keepdim=True)[0]
        avg = torch.mean(self.moving_avg[index], dim=1, keepdim=True)

        one_tensor = torch.tensor(1, dtype=torch.float, device=self.device)

        self.relevance[index] = torch.div(one_tensor,
                                          one_tensor + torch.exp(torch.div(torch.sub(self.moving_avg[index], avg),
                                                                           torch.mul(self.eps_ds,
                                                                                     torch.sub(maximum, minimum)))))
        # 0 / 0 (max == min and moving_avg == avg) yields NaN; such entries fall back to 1
        self.relevance[torch.isnan(self.relevance)] = 1.

        delta = torch.mul(lr, torch.sub(w, self.weights[index]))
        self.weights[index] = torch.add(self.weights[index], delta)

        # push term relative to the map center computed in forward()
        dist_center = torch.sub(self.map_center, self.weights[index])
        respulsive_fn = 1. - 2. * dist_center
        inv_relu = torch.clamp(respulsive_fn, min=0)
        self.weights[index] = torch.sub(self.weights[index], self.lr * self.push * inv_relu)

    def get_map_dists(self, input):
        '''
        Weighted distances from each sample to the live nodes, sorted ascending per
        sample and scaled by the neighborhood function of the rank.

        :param input: samples; assumed to be (batch, input_dim)
        :return: one row per sample, one column per live node
        '''
        nodes_mask = self.node_control.bool()
        nodes = self.weights[nodes_mask]
        relevance = self.relevance[nodes_mask]

        dists = self.weighted_distance(input, prototypes=nodes, relevance=relevance)

        sorted_dists, _ = torch.sort(dists)
        nn_function = self.neighborhood_function(torch.arange(0, len(nodes), device=self.device))
        sorted_dists *= nn_function

        return sorted_dists

    def get_winners(self, input):
        '''
        Best matching node per sample; free slots are masked to activation 0.

        :param input: samples; assumed to be (batch, input_dim)
        :return: ``(act_max, indexes_max)``: winner activation and winner slot per sample
        '''
        activations = self.activation(input) * self.node_control
        act_max, indexes_max = torch.max(activations, dim=1)

        return act_max, indexes_max

    def get_prototypes(self):
        '''
        :return: ``(weights, relevance, moving_avg)`` restricted to the live nodes
        '''
        mask = self.node_control != 0
        return self.weights[mask], self.relevance[mask], self.moving_avg[mask]

    def forward(self, input):
        '''
        Run one training step on a batch.

        Samples whose winner activation reaches ``at`` are averaged per winner and
        used to update that winner and its ranked neighbors; the remaining samples
        are averaged per winner and inserted as new nodes while free slots exist.
        Every node loses ``ld`` life per call and nodes whose life reaches 0 are removed.

        :param input: batch of samples; assumed to be (batch, input_dim)
        :return: tuple of (per-winner mean of the highly activated samples, weights of
                 those winners, relevance of those winners, an empty list, winner index
                 of every input sample after the update). The first element is an empty
                 list when no sample reached the threshold.
        '''

        act_max, indexes_max = self.get_winners(input)
        node_mask = self.node_control.bool()
        self.map_center = nn.Parameter(torch.mean(self.weights[node_mask], 0), requires_grad=False)

        # every node ages on each call; updated winners are restored to full life below
        self.life -= self.ld

        bool_high_at = act_max >= self.at
        samples_high_at = input[bool_high_at]
        nodes_high_at = indexes_max[bool_high_at]

        updatable_samples_hight_at = []
        unique_nodes_high_at = []
        if len(nodes_high_at) > 0:
            self.node_control[nodes_high_at] = 1.
            updatable_samples_hight_at, unique_nodes_high_at, _ = vectorize(samples_high_at,
                                                                            references=nodes_high_at)

            activations = self.activation(updatable_samples_hight_at) * self.node_control
            _, ranked_nodes = torch.topk(activations, self.n_max, dim=1)

            # the number of live nodes at the start of the step and the per-rank
            # learning rates do not change inside the loop
            n_active = int(node_mask.sum())
            neighborhood_fn = self.neighborhood_function(torch.arange(0, n_active, device=self.device))
            local_nn_lr = self.lr * neighborhood_fn

            for i, node in enumerate(unique_nodes_high_at):
                self.life[node] = 1.

                sample = updatable_samples_hight_at[i]
                # the n_active best ranked nodes for this mean sample, winner first
                local_nn = ranked_nodes[i][:n_active]

                with torch.no_grad():
                    self.update_node(sample.repeat(n_active, 1),
                                     local_nn,
                                     local_nn_lr.repeat(sample.shape[0], 1).t())

        bool_low_at = act_max < self.at
        samples_low_at = input[bool_low_at]
        nodes_low_at = indexes_max[bool_low_at]

        # if there is nodes to be inserted and positions available in the map
        if len(nodes_low_at) > 0 and self.node_control[self.node_control == 0].size(0) > 0:
            create_samples_low_at, _, _ = vectorize(samples_low_at, references=nodes_low_at)

            with torch.no_grad():
                self.add_node(create_samples_low_at)

        self.remove_nodes()

        final_dists = []
        _, final_winners = self.get_winners(input)
        return (updatable_samples_hight_at, self.weights[unique_nodes_high_at], self.relevance[unique_nodes_high_at],
                final_dists, final_winners)

    def restore_life(self, node, samples_size, batch_size):
        '''
        Reset the life of ``node`` to 1. ``samples_size`` and ``batch_size`` are unused.
        '''
        self.life[node] = 1.

    def remove_nodes(self):
        '''
        Free the slot of every node whose life is 0 or less.
        '''
        dead_nodes = self.life <= 0.
        self.node_control[dead_nodes] = 0.

    def remove_nodes_by_idx(self, indexes):
        '''
        Free the slots at ``indexes``.
        '''
        self.node_control[indexes] = 0.

    def cluster(self, dataloader):
        '''
        Assign every sample yielded by ``dataloader`` to its best matching node.

        :param dataloader: iterable of ``(inputs, targets)`` batches
        :return: ``(predicted_clusters, true_labels, clustering)`` where ``clustering``
                 is a DataFrame with columns ``sample_ind`` and ``cluster``
        '''
        # rows are collected in a list and turned into a frame once:
        # DataFrame.append was removed in pandas 2.0 and copied the frame per row
        rows = []
        predicted_clusters = []
        true_labels = []

        for batch_idx, (inputs, targets) in enumerate(dataloader):
            _, bmu_indexes = self.get_winners(inputs.to(self.device))

            for index, bmu_index in enumerate(bmu_indexes):
                winner = bmu_index.item()

                # NOTE(review): 'sample_ind' records the batch index, which equals the
                # sample index only for batch size 1 -- confirm this is intended
                rows.append({'sample_ind': batch_idx, 'cluster': winner})
                predicted_clusters.append(winner)
                true_labels.append(targets[index].item())

        clustering = pd.DataFrame(rows, columns=['sample_ind', 'cluster'])

        return predicted_clusters, true_labels, clustering

    def write_output(self, output_path, cluster_result, elapsed_time=None):
        '''
        Write the map summary and a clustering result to ``output_path``.

        File layout (tab separated): a header line with the number of live nodes, the
        input size and optionally ``elapsed_time``; one line per live node with its
        slot index and relevance values; then the rows of ``cluster_result``.

        :param output_path: destination file, overwritten if it exists
        :param cluster_result: DataFrame as returned by ``cluster``
        :param elapsed_time: optional value appended to the header line
        '''
        print(output_path)

        n_clusters = self.node_control[self.node_control == 1].size(0)

        header = [str(n_clusters), str(self.input_size)]
        if elapsed_time is not None:
            header.append(str(elapsed_time))

        lines = ["\t".join(header)]

        with torch.no_grad():
            for i, relevance in enumerate(self.relevance):
                if self.node_control[i] == 1:
                    lines.append(str(i) + "\t" + "\t".join(map(str, relevance.cpu().numpy())))

        # to_string pads columns with spaces; collapse the padding into single tabs
        result_text = cluster_result.to_string(header=False, index=False).strip()
        result_text = re.sub('\n +', '\n', result_text)
        result_text = re.sub(' +', '\t', result_text)

        content = "\n".join(lines) + "\n" + result_text

        # the context manager closes the file even if the write fails
        with open(output_path, 'w+') as output_file:
            output_file.write(content)

class SOMSemiSupervised(SOMUnsupervised):
    '''
    Semi-supervised variant of the map: every node additionally carries a class
    label, with ``no_class`` marking unlabelled nodes and unlabelled samples.
    '''

    def __init__(self, input_dim, n_max=20, lr0=0.1, lr=0.1, lr_push=0.4, at=0.985, ds_beta=0.5,
                 eps_ds=1., ld=0.1, gamma=3.0, no_class=999, device='cpu'):
        '''
        All parameters except ``no_class`` are forwarded unchanged to the base class.

        :param input_dim: number of features of each sample / prototype
        :param n_max: number of node slots in the map
        :param lr0: forwarded to the base class
        :param lr: learning rate
        :param lr_push: multiplier of the push term; ``self.lr_push`` (``-lr * lr_push``)
                        is the rate used on a winner of a different class
        :param at: activation threshold
        :param ds_beta: forwarded to the base class
        :param eps_ds: forwarded to the base class
        :param ld: life decay per ``forward`` call
        :param gamma: divisor inside the neighborhood function
        :param no_class: label value meaning "unlabelled"
        :param device: anything accepted by ``torch.device``
        '''

        super(SOMSemiSupervised, self).__init__(input_dim, n_max, lr0, lr, lr_push, at,
                                                ds_beta, eps_ds, ld, gamma, device)

        self.no_class = no_class

        # class label of every node slot; all slots start unlabelled
        self.classes = nn.Parameter(torch.full((n_max,), self.no_class, device=self.device, dtype=torch.int64),
                                    requires_grad=False)

        # counters of how often each branch of rules_and_update ran
        self.diff_update = 0
        self.positive_update = 0
        self.dup_update = 0
        self.create = 0
        self.create_low_at = 0

    def _resolve_targets(self, samples, targets):
        '''
        Return ``targets`` unchanged, or one ``no_class`` label per row of ``samples``
        when ``targets`` is None, so unlabelled batches can omit the argument.
        '''
        if targets is None:
            return torch.full((samples.size(0),), self.no_class, dtype=torch.int64, device=self.device)

        return targets

    def add_node(self, new_samples, new_targets=None):
        '''
        Insert ``new_samples`` as new nodes (see the base class) and label them.

        :param new_samples: candidate prototypes, one per row
        :param new_targets: one label per row of ``new_samples``; when None the new
                            nodes are labelled ``no_class``
        :return: ``(new_nodes_idx, create_idx)`` as returned by the base class
        '''
        new_nodes_idx, create_idx = super(SOMSemiSupervised, self).add_node(new_samples)

        if new_targets is None:
            self.classes[new_nodes_idx] = self.no_class
        else:
            # create_idx selects the rows that were actually inserted
            self.classes[new_nodes_idx] = new_targets[create_idx]

        return new_nodes_idx, create_idx

    def get_map_dists(self, input, targets=None):
        '''
        Sum, per sample, of the weighted distances to the nodes selected by the
        class / occupancy mask of ``get_1st_2nd_all_winners``.

        :param input: samples; assumed to be (batch, input_dim)
        :param targets: one label per sample, or None for unlabelled samples
        :return: ``(sum_dists, final_wnrs)``: summed distance and best matching node per sample
        '''
        _, _, _, _, topk_acts, topk_idx, all_wnrs_mask = self.get_1st_2nd_all_winners(input, targets)

        final_wnrs = topk_idx.t()[0]
        dists = self.weighted_distance(input) * all_wnrs_mask
        sum_dists = torch.sum(dists, 1)

        return sum_dists, final_wnrs

    def get_winners(self, input, targets=None):
        '''
        Best matching node per sample.

        :param input: samples; assumed to be (batch, input_dim)
        :param targets: one label per sample, or None for unlabelled samples
        :return: ``(first_wnrs_act, final_wnrs)``: winner activation and winner slot per sample
        '''
        _, _, _, _, topk_acts, topk_idx, all_wnrs_mask = self.get_1st_2nd_all_winners(input, targets)

        final_wnrs = topk_idx.t()[0]
        first_wnrs_act = topk_acts.t()[0]

        return first_wnrs_act, final_wnrs

    def get_1st_2nd_winners(self, samples, targets, k=2):
        '''
        Rank the nodes per sample and evaluate the second winner against the labels.

        :param samples: samples; assumed to be (batch, input_dim)
        :param targets: one label per sample
        :param k: number of ranked nodes to return; must be at least 2
        :return: ``(first_winners_act, first_winners, second_winners, second_winners_mask,
                 topk_act, topk_idx)``. ``second_winners_mask`` is True where the second
                 winner is a live node that is unlabelled or has the sample's class.
        '''
        activations = self.activation(samples) * self.node_control

        topk_act, topk_idx = torch.topk(activations, k, dim=1)
        topk_act_t, topk_idx_t = topk_act.t(), topk_idx.t()

        first_winners_act, first_winners = topk_act_t[0], topk_idx_t[0]
        second_winners_act, second_winners = topk_act_t[1], topk_idx_t[1]

        eq_cls = (self.classes[second_winners] == targets)
        no_cls = (self.classes[second_winners] == self.no_class)

        second_winners_mask = (eq_cls | no_cls) & self.node_control[second_winners].bool()

        return first_winners_act, first_winners, second_winners, second_winners_mask, topk_act, topk_idx

    def get_1st_2nd_all_winners(self, samples, targets):
        '''
        ``get_1st_2nd_winners`` over all ``n_max`` ranks, plus a per-sample node mask.

        :param samples: samples; assumed to be (batch, input_dim)
        :param targets: one label per sample, or None for unlabelled samples
        :return: the six values of ``get_1st_2nd_winners`` followed by ``all_wnrs_mask``,
                 where ``all_wnrs_mask[b, n]`` is True when slot ``n`` is live and is
                 unlabelled or carries the class of sample ``b``
        '''
        targets = self._resolve_targets(samples, targets)

        (first_wnrs_act, first_wnrs,
         second_wnrs, second_wnrs_mask,
         topk_act, topk_idx) = self.get_1st_2nd_winners(samples, targets, k=self.n_max)

        all_eq_cls = (self.classes.repeat(len(targets), 1).t() == targets).t()
        all_no_cls = (self.classes.repeat(len(targets), 1).t() == self.no_class).t()
        all_wnrs_criteria_mask = (all_eq_cls | all_no_cls)

        # exp(...) is positive, so after multiplying by node_control this marks live slots
        all_wnrs_nn_mask = torch.arange(0, self.n_max, device=self.device).repeat(len(targets), 1)
        all_wnrs_nn_mask = (self.neighborhood_function(all_wnrs_nn_mask) * self.node_control) > 0
        all_wnrs_mask = all_wnrs_criteria_mask & all_wnrs_nn_mask

        return first_wnrs_act, first_wnrs, second_wnrs, second_wnrs_mask, topk_act, topk_idx, all_wnrs_mask

    def forward(self, input, targets=None):
        '''
        Run one semi-supervised training step on a batch.

        Samples whose winner activation reaches ``at`` go through ``rules_and_update``;
        the remaining samples are averaged per label and inserted as new labelled nodes
        while free slots exist. Every node loses ``ld`` life per call and nodes whose
        life reaches 0 are removed.

        :param input: batch of samples; assumed to be (batch, input_dim)
        :param targets: one label per sample (``no_class`` for unlabelled ones); None
                        treats the whole batch as unlabelled
        :return: tuple of (samples used for updates, weights of the updated nodes,
                 relevance of the updated nodes, summed masked distance per input sample,
                 winner index per input sample after the update)
        '''
        targets = self._resolve_targets(input, targets)

        # defaults live on the map's device, like the tensors rules_and_update returns
        backward_winners = torch.tensor([], dtype=torch.long, device=self.device)
        backward_samples = torch.tensor([], device=self.device)

        (wnrs_act, wnrs,
         sec_wnrs, sec_wnrs_mask, _, all_wnrs, all_wnrs_mask) = self.get_1st_2nd_all_winners(input, targets)

        # every node ages on each call; updated nodes are restored to full life
        self.life -= self.ld

        bool_high_at = wnrs_act >= self.at
        samples_high_at = input[bool_high_at]
        targets_high_at = targets[bool_high_at]
        winners_high_at = wnrs[bool_high_at]
        if len(winners_high_at) > 0:
            backward_samples, backward_winners = self.rules_and_update(samples=samples_high_at,
                                                                       targets=targets_high_at,
                                                                       winners=winners_high_at,
                                                                       second_winners=sec_wnrs[bool_high_at],
                                                                       second_winners_mask=sec_wnrs_mask[bool_high_at],
                                                                       all_nn=all_wnrs[bool_high_at],
                                                                       all_nn_mask=all_wnrs_mask[bool_high_at])

        bool_low_at = wnrs_act < self.at
        samples_low_at = input[bool_low_at]
        targets_low_at = targets[bool_low_at]
        nodes_low_at = wnrs[bool_low_at]

        # if there is nodes to be inserted and positions available in the map
        if len(nodes_low_at) > 0 and self.node_control[self.node_control == 0].size(0) > 0:
            create_samples_low_at, create_targets_low_at, _ = vectorize(samples_low_at,
                                                                        references=targets_low_at)
            self.add_node(create_samples_low_at, create_targets_low_at)

        self.remove_nodes()

        final_sum_dists, final_winners = self.get_map_dists(input, targets)
        return (backward_samples, self.weights[backward_winners], self.relevance[backward_winners],
                final_sum_dists, final_winners)

    def rules_and_update(self, samples, targets, winners,
                         second_winners, second_winners_mask, all_nn, all_nn_mask):
        '''
        Apply the labelled update rules to the highly activated samples, winner by winner.

        For each distinct winner its samples fall into exactly one of four groups:
          * duplicate: labelled sample, unlabelled winner -> copy the winner once per
            distinct label, train the copies, then free the unlabelled original
          * positive: same class, or unlabelled sample -> regular update
          * diff: different class but an acceptable second winner -> update the winner
            with ``self.lr_push`` and the second winner with ``self.lr``
          * create: none of the above -> insert the per-label sample means as new nodes

        :param samples: highly activated samples, one per row
        :param targets: label of each sample
        :param winners: winner slot of each sample
        :param second_winners: second ranked slot of each sample
        :param second_winners_mask: whether that second winner is acceptable for the sample
        :param all_nn: ranked node indexes per sample (``topk_idx`` of the caller)
        :param all_nn_mask: per-sample node mask from ``get_1st_2nd_all_winners``
        :return: ``(return_samples, return_winners)``: the samples used in updates and the
                 nodes they were applied to, concatenated over all winners
        '''
        unique_winners = torch.unique(winners)

        return_winners = torch.tensor([], dtype=torch.long, device=self.device)
        return_samples = torch.tensor([], device=self.device)

        for winner in unique_winners:
            winner_array_like = winner.unsqueeze(-1)
            self.life[winner] = 1.

            winner_rows = winners == winner
            samples_per_winner = samples[winner_rows]
            targets_per_winner = targets[winner_rows]

            samples_second_winners = second_winners[winner_rows]
            samples_second_winners_mask = second_winners_mask[winner_rows]
            all_winner_nn = all_nn[winner_rows]
            all_winner_mask_nn = all_nn_mask[winner_rows]

            # If a sample has a defined class but the node does not, "simulate" a labeling
            # by duplicating the original winner with the label of this input sample.
            duplicate_update_mask = ((targets_per_winner != self.classes[winner]) &
                                     (self.classes[winner] == self.no_class) &
                                     (targets_per_winner != self.no_class))

            duplicate_samples = samples_per_winner[duplicate_update_mask]
            duplicate_targets = targets_per_winner[duplicate_update_mask]
            unq_samples_dup, unq_targets_dup, unq_targets_counts_dup = vectorize(duplicate_samples,
                                                                                 references=duplicate_targets)

            # One copy of the winner per distinct label. The labels handed to add_node
            # must be the per-copy unique labels: passing the per-sample labels would
            # label the copies with whatever the first samples happened to carry.
            duplicate_w = self.weights[winner].repeat(len(unq_samples_dup), 1)
            duplicate_node, duplicate_rows = self.add_node(duplicate_w, unq_targets_dup)
            duplicate_node = duplicate_node.reshape(-1)

            # If both sample and node have the same class, then just update as usual.
            positive_update_mask = (targets_per_winner == self.classes[winner]) | (targets_per_winner == self.no_class)

            # For samples that do not fit the previous conditions, try to find a second
            # winner with no_class or the same class.
            diff_update_mask = ~(duplicate_update_mask | positive_update_mask) & samples_second_winners_mask

            # For samples that do not fit the previous conditions and do not have a
            # second winner, create new nodes.
            create_mask = ~(duplicate_update_mask | positive_update_mask | diff_update_mask)

            with torch.no_grad():
                # neighborhood around the original node (that will be replaced by the
                # new labelled ones)
                # NOTE(review): all_nn is ordered by activation rank while all_nn_mask is
                # indexed by node slot; the mask is applied positionally here -- confirm
                nn_dup_nodes_mask = all_winner_mask_nn[duplicate_update_mask]
                nn_dup_nodes = all_winner_nn[duplicate_update_mask][nn_dup_nodes_mask]
                nn_dup_nodes = torch.unique(torch.cat((winner_array_like, nn_dup_nodes)))

                duplicated = duplicate_samples.size(0) > 0 and duplicate_node.nelement() > 0
                if duplicated:
                    # constant learning rate, one entry per distinct label
                    samples_duplicate_lr = unq_targets_counts_dup * 0 + self.lr

                    # train each copy with the mean sample of the label it was given;
                    # duplicate_rows are the rows add_node actually inserted
                    samples_dup_limit = unq_samples_dup[duplicate_rows]
                    samples_duplicate_lr = samples_duplicate_lr[0: duplicate_node.nelement()]
                    self.update_node(samples_dup_limit,
                                     duplicate_node,
                                     samples_duplicate_lr.repeat(duplicate_samples.size(1), 1).t())

                    return_winners = torch.cat((return_winners, duplicate_node))
                    return_samples = torch.cat((return_samples, samples_dup_limit))

                    nn_function_in = torch.arange(0, len(nn_dup_nodes), device=self.device)
                    samples_nn_dup_lr = self.lr * self.neighborhood_function(nn_function_in)
                    self.update_node(samples_dup_limit.mean(0).repeat(nn_dup_nodes.nelement(), 1),
                                     nn_dup_nodes,
                                     samples_nn_dup_lr.repeat(samples_dup_limit.size(1), 1).t())

                    self.dup_update += 1

                samples_positive = samples_per_winner[positive_update_mask]
                positive_nodes_mask = all_winner_mask_nn[positive_update_mask]
                positive_nodes = all_winner_nn[positive_update_mask][positive_nodes_mask]
                positive_nodes = torch.unique(torch.cat((winner_array_like, positive_nodes)))
                if samples_positive.size(0) > 0:
                    nn_function_in = torch.arange(0, len(positive_nodes), device=self.device)

                    samples_positive_lr = self.lr * self.neighborhood_function(nn_function_in)
                    w_update = samples_positive.mean(0).repeat(positive_nodes.nelement(), 1)
                    self.update_node(w_update,
                                     positive_nodes,
                                     samples_positive_lr.repeat(samples_positive.size(1), 1).t())

                    self.life[winner_array_like] = 1.

                    return_winners = torch.cat((return_winners, positive_nodes))
                    return_samples = torch.cat((return_samples, w_update))
                    self.positive_update += 1

                samples_diff = samples_per_winner[diff_update_mask]
                if samples_diff.size(0) > 0:
                    # the wrongly labelled winner is updated with the (negative) push rate
                    self.update_node(samples_diff.mean(0), winner_array_like, self.lr_push)

                    unq_samples_diff, unq_2nd_nodes_diff, unq_2nd_nodes_counts_diff = vectorize(
                        samples_diff,
                        references=samples_second_winners[diff_update_mask])

                    # constant learning rate, one entry per distinct second winner
                    samples_diff_lr = unq_2nd_nodes_counts_diff * 0 + self.lr
                    self.update_node(unq_samples_diff,
                                     unq_2nd_nodes_diff,
                                     samples_diff_lr.repeat(samples_diff.size(1), 1).t())

                    self.life[unq_2nd_nodes_diff] = 1.

                    return_winners = torch.cat((return_winners, unq_2nd_nodes_diff))
                    return_samples = torch.cat((return_samples, unq_samples_diff))

                    self.diff_update += 1

                samples_create = samples_per_winner[create_mask]
                if samples_create.size(0) > 0:
                    unq_samples_create, unq_targets_create, _ = vectorize(samples_create,
                                                                          references=targets_per_winner[create_mask])
                    self.add_node(unq_samples_create, unq_targets_create)

                    self.create += 1

                # the unlabelled original is dropped once labelled copies replaced it
                if duplicated:
                    self.remove_nodes_by_idx(winner_array_like)

        return return_samples, return_winners

    def cluster(self, dataloader):
        '''
        Assign every sample yielded by ``dataloader`` to its best matching node and
        report that node's class as the predicted label.

        :param dataloader: iterable of ``(inputs, targets)`` batches
        :return: ``(predicted_clusters, predicted_labels, true_labels, clustering)`` where
                 ``clustering`` is a DataFrame with columns ``sample_ind``, ``cluster``, ``class``
        '''
        # rows are collected in a list and turned into a frame once:
        # DataFrame.append was removed in pandas 2.0 and copied the frame per row
        rows = []
        predicted_clusters = []
        predicted_labels = []
        true_labels = []

        for batch_idx, (inputs, targets) in enumerate(dataloader):
            _, bmus = self.get_winners(inputs.to(self.device), targets.to(self.device))

            for i in range(inputs.size(0)):
                winner = bmus[i].item()

                predicted_label = self.classes[winner].item()
                # NOTE(review): 'sample_ind' records the batch index, which equals the
                # sample index only for batch size 1 -- confirm this is intended
                rows.append({'sample_ind': batch_idx,
                             'cluster': winner,
                             'class': predicted_label})

                predicted_labels.append(predicted_label)
                predicted_clusters.append(winner)
                true_labels.append(targets[i].item())

        clustering = pd.DataFrame(rows, columns=['sample_ind', 'cluster', 'class'])

        return predicted_clusters, predicted_labels, true_labels, clustering
