from math import sqrt
from torch.nn import Conv2d, Linear
from torch.nn.init import calculate_gain
from torch import where, rand, topk, long, empty, zeros, no_grad, tensor
import torch
import sys
from .gnt import AdamGnT



def get_layer_bound(layer, init, gain):
    """Return the half-width ``b`` of the uniform range ``U(-b, b)`` used to (re)initialise ``layer``.

    Args:
        layer: the layer whose incoming weights are being initialised.
        init: initialisation scheme name; only consulted for ``Linear`` layers.
        gain: activation gain; used by the 'xavier' scheme and the fallback (kaiming-style) scheme.

    Returns:
        For ``Conv2d``: ``sqrt(1 / fan_in)`` with ``fan_in = in_channels * kh * kw`` (``init``/``gain`` ignored).
        For ``Linear``:
            'default' -> ``sqrt(1 / fan_in)``
            'xavier'  -> ``gain * sqrt(6 / (fan_in + fan_out))``
            'lecun'   -> ``sqrt(3 / fan_in)``
            otherwise -> ``gain * sqrt(3 / fan_in)``
        ``None`` for any other layer type.
    """
    if isinstance(layer, Conv2d):
        kernel_h, kernel_w = layer.kernel_size[0], layer.kernel_size[1]
        conv_fan_in = layer.in_channels * kernel_h * kernel_w
        return sqrt(1 / conv_fan_in)

    if not isinstance(layer, Linear):
        # Unsupported layer kind: callers receive None.
        return None

    fan_in, fan_out = layer.in_features, layer.out_features
    if init == 'default':
        return sqrt(1 / fan_in)
    if init == 'xavier':
        return gain * sqrt(6 / (fan_in + fan_out))
    if init == 'lecun':
        return sqrt(3 / fan_in)
    # Any other name (e.g. 'kaiming') falls through to the gain-scaled fan-in bound.
    return gain * sqrt(3 / fan_in)


class ConvGnT(object):
    """
    Generate-and-Test algorithm for ConvNets, maturity threshold based tester, accumulates probability of replacement,
    with various measures of feature utility.

    ``net`` is indexed as ``[layer_0, act_0, layer_1, act_1, ..., output_layer]``: hidden layer ``i`` is
    ``net[2 * i]`` (a ``Conv2d`` or ``Linear``) and the layer that consumes its output is ``net[2 * i + 2]``.
    Per-feature statistics (utilities, ages, running activation means) are created with ``zeros`` and only ever
    updated with ``.cpu()`` values, i.e. they are kept on the CPU while the network weights may live on ``device``.
    """
    def __init__(self, net, hidden_activation, opt, decay_rate=0.99, replacement_rate=1e-4, init='kaiming',
                 num_last_filter_outputs=4, util_type='contribution', maturity_threshold=100, device='cpu'):
        """
        Args:
            net: indexable container of layers laid out as described in the class docstring.
            hidden_activation: activation name used for the init gain ('swish'/'elu' are treated as 'relu';
                'selu' forces the 'lecun' scheme).
            opt: optimizer; when it is an ``AdamGnT`` its per-feature state is reset on replacement.
            decay_rate: decay of the running utility and activation averages.
            replacement_rate: fraction of each hidden layer's features to replace per call (accumulated over calls).
            init: weight-init scheme for replaced features, passed to ``get_layer_bound``.
            num_last_filter_outputs: used to map the filters of a Conv2d that feeds a Linear layer onto that
                Linear layer's input columns (see the NOTE(review) comments in update_utility / test_features).
            util_type: 'contribution', 'zero_contribution', 'adaptable_contribution', 'adaptation' or 'random'.
            maturity_threshold: a feature must be older than this many calls before it can be replaced.
            device: device of the network weights; freshly sampled weights are moved there.

        Raises:
            TypeError: if a hidden layer ``net[2 * i]`` is neither ``Conv2d`` nor ``Linear``.
        """
        super(ConvGnT, self).__init__()

        self.net = net
        self.num_hidden_layers = len(self.net) // 2
        self.util_type = util_type
        self.device = device

        self.opt = opt
        self.opt_type = 'AdamGnT' if isinstance(self.opt, AdamGnT) else 'sgd'

        # Hyper-parameters of the algorithm
        self.replacement_rate = replacement_rate
        self.decay_rate = decay_rate
        self.num_last_filter_outputs = num_last_filter_outputs
        self.maturity_threshold = maturity_threshold

        # Per-feature statistics: one 1-D tensor of length out_channels / out_features per hidden layer.
        widths = [self._layer_width(self.net[i * 2]) for i in range(self.num_hidden_layers)]
        self.util = [zeros(width) for width in widths]
        self.bias_corrected_util = [zeros(width) for width in widths]
        self.ages = [zeros(width) for width in widths]
        self.mean_feature_act = [zeros(width) for width in widths]
        self.mean_abs_feature_act = [zeros(width) for width in widths]

        # Fractional replacement counts carried over between calls.
        self.accumulated_num_features_to_replace = [0 for _ in range(self.num_hidden_layers)]
        self.m = torch.nn.Softmax(dim=1)  # NOTE(review): not used anywhere in this class

        # Bounds of the uniform distribution used to initialise replaced features.
        if hidden_activation == 'selu':
            init = 'lecun'
        self.bounds = self.compute_bounds(hidden_activation=hidden_activation, init=init)

        # Expected (fractional) number of features to replace per layer per update.
        self.num_new_features_to_replace = [self.replacement_rate * width for width in widths]

    @staticmethod
    def _layer_width(layer):
        """Return the number of features (output channels / units) of a supported hidden layer."""
        if isinstance(layer, Conv2d):
            return layer.out_channels
        if isinstance(layer, Linear):
            return layer.out_features
        raise TypeError('ConvGnT only supports Conv2d and Linear hidden layers, got %s' % type(layer).__name__)

    def compute_bounds(self, hidden_activation, init='kaiming'):
        """
        Return the init bound of every hidden layer, followed by the bound of the output layer ``net[-1]``.

        'swish' and 'elu' are mapped to 'relu' for the gain computation; the output layer always uses gain 1.
        """
        if hidden_activation in ['swish', 'elu']:
            hidden_activation = 'relu'
        gain = calculate_gain(nonlinearity=hidden_activation)
        bounds = [get_layer_bound(layer=self.net[i * 2], init=init, gain=gain) for i in range(self.num_hidden_layers)]
        bounds.append(get_layer_bound(layer=self.net[-1], init=init, gain=1))
        return bounds

    def update_utility(self, layer_idx=0, features=None):
        """
        Update the running (decayed) utility of every feature in hidden layer ``layer_idx``.

        Args:
            layer_idx: index of the hidden layer (``self.net[2 * layer_idx]``).
            features: activations of that layer for the current batch, batch-first. They are averaged over
                dims (0, 2, 3) when a Conv2d feeds another Conv2d; TODO confirm the layout expected when a
                Conv2d feeds a Linear layer.

        The new per-feature utility depends on ``self.util_type``:
            'adaptation':             1 / mean |incoming weight|
            'contribution':           mean |outgoing weight| * mean |activation|
            'zero_contribution':      as above, with activations centred by their bias-corrected running mean
            'adaptable_contribution': 'zero_contribution' divided by mean |incoming weight|
            'random':                 bias-corrected utility replaced by uniform random numbers
        Otherwise ``util = decay * util + (1 - decay) * new_util`` and
        ``bias_corrected_util = util / (1 - decay ** age)``.
        """
        with torch.no_grad():
            self.util[layer_idx] *= self.decay_rate
            # Running averages start at zero, so early values are corrected by 1 - decay ** age.
            bias_correction = 1 - self.decay_rate ** self.ages[layer_idx]
            step = 1 - self.decay_rate

            current_layer = self.net[layer_idx * 2]
            next_layer = self.net[layer_idx * 2 + 2]

            # mean |outgoing weight|: one entry per input column (Linear) / input channel (Conv2d) of next_layer
            if isinstance(next_layer, Linear):
                output_wight_mag = next_layer.weight.data.abs().mean(dim=0)
            elif isinstance(next_layer, Conv2d):
                output_wight_mag = next_layer.weight.data.abs().mean(dim=(0, 2, 3))

            self.mean_feature_act[layer_idx] *= self.decay_rate
            self.mean_abs_feature_act[layer_idx] *= self.decay_rate
            if isinstance(current_layer, Linear):
                input_wight_mag = current_layer.weight.data.abs().mean(dim=1)
                # .cpu() keeps the update compatible with the CPU statistics when features are on a GPU
                # (the Conv2d branches below already did this; the Linear branch did not).
                self.mean_feature_act[layer_idx] += step * features.mean(dim=0).cpu()
                self.mean_abs_feature_act[layer_idx] += step * features.abs().mean(dim=0).cpu()
            elif isinstance(current_layer, Conv2d):
                input_wight_mag = current_layer.weight.data.abs().mean(dim=(1, 2, 3))
                if isinstance(next_layer, Conv2d):
                    self.mean_feature_act[layer_idx] += step * features.mean(dim=(0, 2, 3)).cpu()
                    self.mean_abs_feature_act[layer_idx] += step * features.abs().mean(dim=(0, 2, 3)).cpu()
                else:
                    # NOTE(review): `view(self.num_last_filter_outputs, -1).mean(dim=1)` yields
                    # `num_last_filter_outputs` values, which only lines up with these per-channel statistics when
                    # `num_last_filter_outputs == current_layer.out_channels`; test_features instead divides it by
                    # out_channels, treating it as a total. Confirm the intended meaning before relying on this path.
                    self.mean_feature_act[layer_idx] += step * features.mean(dim=0).view(
                        self.num_last_filter_outputs, -1).mean(dim=1).cpu()
                    self.mean_abs_feature_act[layer_idx] += step * features.abs().mean(dim=0).view(
                        self.num_last_filter_outputs, -1).mean(dim=1).cpu()

            bias_corrected_act = self.mean_feature_act[layer_idx] / bias_correction

            if self.util_type == 'adaptation':
                new_util = 1 / input_wight_mag
            elif self.util_type in ['contribution', 'zero_contribution', 'adaptable_contribution']:
                if self.util_type == 'contribution':
                    bias_corrected_act = 0
                else:
                    if isinstance(current_layer, Conv2d):
                        if isinstance(next_layer, Conv2d):
                            bias_corrected_act = bias_corrected_act.view(1, -1, 1, 1)
                        else:
                            # NOTE(review): shapes only agree with the flattened features for particular values
                            # of num_last_filter_outputs (see the conv->linear note above).
                            bias_corrected_act = bias_corrected_act.repeat_interleave(
                                self.num_last_filter_outputs).view(1, -1)
                    # The running mean lives on the CPU; move it next to the features before subtracting
                    # (previously a device mismatch when features are on a GPU).
                    bias_corrected_act = bias_corrected_act.to(features.device)
                abs_deviation = (features - bias_corrected_act).abs()
                if isinstance(next_layer, Linear):
                    if isinstance(current_layer, Linear):
                        new_util = output_wight_mag * abs_deviation.mean(dim=0)
                    elif isinstance(current_layer, Conv2d):
                        # NOTE(review): output_wight_mag has one entry per Linear input column, while the grouped
                        # activation term has num_last_filter_outputs entries; confirm these shapes match.
                        new_util = output_wight_mag * abs_deviation.mean(dim=0).view(
                            self.num_last_filter_outputs, -1).mean(dim=1)
                elif isinstance(next_layer, Conv2d):
                    new_util = output_wight_mag * abs_deviation.mean(dim=(0, 2, 3))
                if self.util_type == 'adaptable_contribution':
                    new_util = new_util / input_wight_mag

            if self.util_type == 'random':
                self.bias_corrected_util[layer_idx] = rand(self.util[layer_idx].shape)
            else:
                self.util[layer_idx] += step * new_util.cpu()
                # correct the bias in the utility computation
                self.bias_corrected_util[layer_idx] = self.util[layer_idx] / bias_correction

    def test_features(self, features):
        """
        Age every feature, update utilities and pick the lowest-utility mature features to replace.

        Args:
            features: list with the activations of each hidden layer (``features[i]`` for layer ``i``).
        Returns:
            (input indices per layer, output indices per layer, number of features to replace per layer).
            Input indices select rows/filters of ``net[2 * i]``; output indices select the matching input
            columns/channels of ``net[2 * i + 2]``.
        """
        features_to_replace_input_indices = [empty(0, dtype=long) for _ in range(self.num_hidden_layers)]
        features_to_replace_output_indices = [empty(0, dtype=long) for _ in range(self.num_hidden_layers)]
        num_features_to_replace = [0 for _ in range(self.num_hidden_layers)]
        if self.replacement_rate == 0:
            return features_to_replace_input_indices, features_to_replace_output_indices, num_features_to_replace

        for i in range(self.num_hidden_layers):
            self.ages[i] += 1
            self.update_utility(layer_idx=i, features=features[i])

            # Only features older than the maturity threshold may be replaced.
            eligible_feature_indices = where(self.ages[i] > self.maturity_threshold)[0]
            num_eligible = eligible_feature_indices.shape[0]
            if num_eligible == 0:
                continue

            # Fractional counts accumulate across calls until at least one whole feature is due.
            self.accumulated_num_features_to_replace[i] += self.num_new_features_to_replace[i]
            num_new_features_to_replace = int(self.accumulated_num_features_to_replace[i])
            self.accumulated_num_features_to_replace[i] -= num_new_features_to_replace
            # topk raises if asked for more elements than exist, so never request more than are eligible.
            num_new_features_to_replace = min(num_new_features_to_replace, num_eligible)
            if num_new_features_to_replace == 0:
                continue

            # Lowest bias-corrected utility among the eligible features.
            new_features_to_replace = topk(-self.bias_corrected_util[i][eligible_feature_indices],
                                           num_new_features_to_replace)[1]
            new_features_to_replace = eligible_feature_indices[new_features_to_replace]

            # Reset statistics of the features about to be replaced.
            self.util[i][new_features_to_replace] = 0
            self.mean_feature_act[i][new_features_to_replace] = 0.
            self.mean_abs_feature_act[i][new_features_to_replace] = 0.

            num_features_to_replace[i] = num_new_features_to_replace
            features_to_replace_input_indices[i] = new_features_to_replace
            features_to_replace_output_indices[i] = new_features_to_replace
            if isinstance(self.net[i * 2], Conv2d) and isinstance(self.net[i * 2 + 2], Linear):
                # Each filter feeds a contiguous block of Linear input columns.
                # NOTE(review): here num_last_filter_outputs is treated as the total number of flattened conv
                # outputs; a value smaller than out_channels gives 0 outputs per filter, so no outgoing weights
                # would be reset. update_utility interprets the parameter differently — confirm.
                outputs_per_filter = self.num_last_filter_outputs // self.net[i * 2].out_channels
                features_to_replace_output_indices[i] = \
                    (new_features_to_replace * outputs_per_filter).repeat_interleave(outputs_per_filter) + \
                    torch.arange(outputs_per_filter).repeat(new_features_to_replace.size()[0])

        return features_to_replace_input_indices, features_to_replace_output_indices, num_features_to_replace

    def update_optim_params(self, features_to_replace_input_indices, features_to_replace_output_indices, num_features_to_replace):
        """
        Reset the AdamGnT state (``exp_avg``, ``exp_avg_sq``, ``step``) of the replaced features.

        Incoming weights (rows / filters) and biases of ``net[2 * i]`` and outgoing weights (columns / input
        channels) of ``net[2 * i + 2]`` are reset. Does nothing for other optimizers.
        """
        if self.opt_type != 'AdamGnT':
            return
        for i in range(self.num_hidden_layers):
            # Compare this layer's count (the whole list was compared to 0 before, so nothing was ever skipped).
            if num_features_to_replace[i] == 0:
                continue
            current_layer = self.net[i * 2]
            next_layer = self.net[i * 2 + 2]
            input_indices = features_to_replace_input_indices[i]
            output_indices = features_to_replace_output_indices[i]

            # input weights (exp_avg of the weight was previously never reset)
            weight_state = self.opt.state[current_layer.weight]
            weight_state['exp_avg'][input_indices, :] = 0.0
            weight_state['exp_avg_sq'][input_indices, :] = 0.0
            weight_state['step'][input_indices, :] = 0
            if current_layer.bias is not None:
                bias_state = self.opt.state[current_layer.bias]
                bias_state['exp_avg'][input_indices] = 0.0
                bias_state['exp_avg_sq'][input_indices] = 0.0
                bias_state['step'][input_indices] = 0

            # output weights
            next_state = self.opt.state[next_layer.weight]
            next_state['exp_avg'][:, output_indices] = 0.0
            next_state['exp_avg_sq'][:, output_indices] = 0.0
            next_state['step'][:, output_indices] = 0

    def gen_new_features(self, features_to_replace_input_indices, features_to_replace_output_indices, num_features_to_replace):
        """
        Generate new features: re-sample the incoming weights of low-utility features from ``U(-bound, bound)``.

        Their biases, outgoing weights and ages are set to zero. Values are assigned directly, so NaN/inf
        weights are replaced too (the previous ``*= 0`` followed by ``-= -x`` kept them).
        """
        with torch.no_grad():
            for i in range(self.num_hidden_layers):
                num_new = num_features_to_replace[i]
                if num_new == 0:
                    continue
                current_layer = self.net[i * 2]
                next_layer = self.net[i * 2 + 2]
                input_indices = features_to_replace_input_indices[i]
                output_indices = features_to_replace_output_indices[i]

                # One fresh row (Linear) or filter (Conv2d) per replaced feature, sampled on CPU then moved.
                new_weight_shape = [num_new] + list(current_layer.weight.shape[1:])
                current_layer.weight.data[input_indices] = \
                    empty(new_weight_shape).uniform_(-self.bounds[i], self.bounds[i]).to(self.device)
                if current_layer.bias is not None:
                    current_layer.bias.data[input_indices] = 0.0

                # Zero outgoing weights so the new features do not perturb the network's current output.
                next_layer.weight.data[:, output_indices] = 0.0
                self.ages[i][input_indices] = 0

    def gen_and_test(self, features):
        """
        Perform generate-and-test.

        :param features: list with the activation of each hidden layer of the network
        :raises TypeError: if ``features`` is not a list
        """
        if not isinstance(features, list):
            raise TypeError('features passed to generate-and-test should be a list')
        features_to_replace_input_indices, features_to_replace_output_indices, num_features_to_replace = \
            self.test_features(features=features)
        self.gen_new_features(features_to_replace_input_indices, features_to_replace_output_indices, num_features_to_replace)
        self.update_optim_params(features_to_replace_input_indices, features_to_replace_output_indices, num_features_to_replace)