import torch
import torch.nn as nn
from thop import profile
import numpy as np
from unit_testing import tester
from log import log_class
from hyperparameters import search_space
import importlib
import math

from collections import defaultdict


def get_layer_metric_array(net, metric, mode):
    """Apply ``metric`` to every Conv1d / Conv2d / Linear layer of ``net``.

    Layers are visited in ``net.modules()`` order. When ``mode == 'channel'``,
    any module carrying a ``dont_ch_prune`` attribute is skipped.

    Returns:
        A list with one ``metric(layer)`` result per selected layer.
    """
    scored_types = (nn.Conv2d, nn.Linear, nn.Conv1d)
    honour_skip_marker = mode == 'channel'
    return [
        metric(module)
        for module in net.modules()
        if not (honour_skip_marker and hasattr(module, 'dont_ch_prune'))
        and isinstance(module, scored_types)
    ]

def synflow(layer):
    """SynFlow saliency of one layer: ``|weight * weight.grad|`` element-wise.

    This is the per-layer metric used to select the gradients for search/prune.
    If no gradient has been computed for ``layer.weight`` yet, a zero tensor of
    the weight's shape is returned instead.
    """
    weight = layer.weight
    gradient = weight.grad
    if gradient is None:
        return torch.zeros_like(weight)
    return torch.abs(weight * gradient)

def compute_synflow_per_weight(net , input_data , targets=None , mode=None , split_data=1 , loss_fn=None):
    """Compute the per-layer SynFlow saliency of ``net``.

    Every entry of ``net.state_dict()`` is temporarily replaced by its absolute
    value. A single all-ones sample, shaped like one element of ``input_data``,
    is pushed through ``net.forward``. The output is summed in float64 and
    back-propagated, and ``|w * grad|`` is collected for every Conv1d/Conv2d/
    Linear layer via ``get_layer_metric_array(net, synflow, mode)``. The
    original signs are restored before returning, even if the forward or
    backward pass raises.

    Args:
        net: the ``nn.Module`` to score. Its gradients are zeroed and then
            overwritten by the backward pass.
        input_data: a batched tensor, or a tuple of batched tensors passed
            positionally to ``net.forward``. Only the per-sample shape (all
            dimensions after the first), the dtype and the device are used.
        targets, split_data, loss_fn: unused; kept for API compatibility.
        mode: forwarded to ``get_layer_metric_array`` (``'channel'`` skips
            modules marked with ``dont_ch_prune``).

    Returns:
        A list with one saliency tensor per scored layer.

    Raises:
        TypeError: if ``input_data`` is neither a tensor nor a tuple. This is
            checked before any weight of ``net`` is touched.
    """
    signed_int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64)

    def ones_sample(batch):
        # One all-ones sample with the same per-sample shape as ``batch``.
        # Signed integer inputs (e.g. indices) are probed with int64 ones; every
        # other dtype is kept as-is so the probe matches what ``net`` already
        # accepts. Forcing float inputs to float64 would mismatch a float32 net.
        probe_dtype = torch.long if batch.dtype in signed_int_dtypes else batch.dtype
        return torch.ones([1] + list(batch.shape[1:]), dtype=probe_dtype, device=batch.device)

    # Build the probe inputs first, so invalid input fails before any mutation.
    if isinstance(input_data, torch.Tensor):
        probe_inputs = (ones_sample(input_data),)
    elif isinstance(input_data, tuple):
        probe_inputs = tuple(ones_sample(single_input) for single_input in input_data)
    else:
        raise TypeError(
            "input_data must be a torch.Tensor or a tuple of tensors, "
            f"got {type(input_data).__name__}"
        )

    # convert params to their abs. Keep sign for converting it back.
    @torch.no_grad()
    def linearize(net):
        signs = {}
        for name, param in net.state_dict().items():
            signs[name] = torch.sign(param)
            param.abs_()
        return signs

    # convert to orig values
    @torch.no_grad()
    def nonlinearize(net, signs):
        for name, param in net.state_dict().items():
            if 'weight_mask' not in name:
                param.mul_(signs[name])

    # keep signs of all params
    signs = linearize(net)
    try:
        # Compute gradients with input of 1s
        net.zero_grad()
        output = net.forward(*probe_inputs)

        if type(output) is tuple:
            # (feat, logits) = output -> score the logits only
            _, output = output

        # Sum in float64 to reduce the risk of overflow in the reduction.
        torch.sum(output.double()).backward()

        # Must be computed while the weights are still abs'd.
        grads_abs = get_layer_metric_array(net, synflow, mode)
    finally:
        # apply signs of all params, even when the pass above failed
        nonlinearize(net, signs)

    return grads_abs


def hooklogdet(K, labels=None):
    """Return ``log|det(K)|`` computed with ``np.linalg.slogdet``.

    The sign of the determinant is discarded. ``labels`` is accepted but unused.
    A singular ``K`` yields ``-inf``.
    """
    _sign, log_abs_det = np.linalg.slogdet(K)
    return log_abs_det

def safe_hooklogdet(K):
    """Like ``hooklogdet`` but maps a singular ``K`` (sign 0, logdet -inf) to 0."""
    sign, log_abs_det = np.linalg.slogdet(K)
    is_singular = sign == 0 and np.isneginf(log_abs_det)
    if is_singular:
        return 0
    return log_abs_det

def compute_naswot(net, inputs, layerwise=False, return_Kmats=False):
    """Compute the NASWOT score of ``net`` on one batch.

    A forward hook is temporarily attached to every Conv1d/Conv2d/Linear layer.
    For each layer, the output is flattened per sample (first dimension kept)
    and binarised as ``x = (out > 0)``. The kernel
    ``K = x @ x.T + (1 - x) @ (1 - x).T`` is then built; ``K[i, j]`` counts the
    units on which samples ``i`` and ``j`` agree. The hooks are always removed
    before returning, and ``net`` is left in eval mode.

    Args:
        net: the ``nn.Module`` to score.
        inputs: a tensor passed as ``net(inputs)``, or a tuple passed as
            ``net(*inputs)``.
        layerwise: if True, keep one kernel and one logdet per hooked layer
            call. Otherwise sum the kernels of all layers into a single matrix.
        return_Kmats: also return the kernel matrix/matrices.

    Returns:
        ``layerwise=False``: ``K_mat_logdet`` or ``(K_mat, K_mat_logdet)``.
        ``layerwise=True``: ``K_mats_logdet`` or ``(K_mats, K_mats_logdet)``.
        Logdets come from ``safe_hooklogdet`` (0 for a singular kernel).

    Raises:
        TypeError: if ``inputs`` is neither a tensor nor a tuple.
        AssertionError: if some hooked layer was never executed.
        ValueError: if ``net`` has no Conv1d/Conv2d/Linear layer
            (``layerwise=False`` only).
    """
    net.eval()

    K_layer_names = []  # names of the hooked layers, in call order
    K_mats = []  # layer-wise kernels (layerwise=True)
    K_mats_logdet = []  # layer-wise kernel logdets (layerwise=True)
    K_mat = 0.  # summed kernel (layerwise=False); 0. broadcasts on first add

    def make_hook(layer_name):
        # Bind the layer name in a closure instead of mutating the module.
        def counting_forward_hook(module, inp, out):
            nonlocal K_mat
            # reshape (not view) so non-contiguous outputs are handled too
            flat = out.reshape(out.size(0), -1)
            x = (flat > 0).float()
            K = x @ x.t()
            K2 = (1. - x) @ (1. - x.t())
            matrix = K + K2

            K_layer_names.append(layer_name)
            if layerwise:
                K_mats.append(matrix)
                K_mats_logdet.append(safe_hooklogdet(matrix.cpu().numpy()))
            else:
                K_mat = K_mat + matrix
        return counting_forward_hook

    # register forward hook fn; keep the handles so they can be removed
    registered_layers = []
    hook_handles = []
    for name, module in net.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear, nn.Conv1d)):
            hook_handles.append(module.register_forward_hook(make_hook(name)))
            registered_layers.append(name)

    # forward pass; hooks are removed even if the forward raises
    try:
        with torch.no_grad():
            if isinstance(inputs, torch.Tensor):
                net(inputs)
            elif isinstance(inputs, tuple):
                net(*inputs)
            else:
                raise TypeError(
                    "inputs must be a torch.Tensor or a tuple of tensors, "
                    f"got {type(inputs).__name__}"
                )
    finally:
        for handle in hook_handles:
            handle.remove()

    # Compare as sets: under some conditions the call order differs from the
    # registration order. Explicit raise so the check survives `python -O`.
    if set(registered_layers) != set(K_layer_names):
        raise AssertionError('Not all module forward hook fn were triggered successfully')

    if layerwise:
        return (K_mats, K_mats_logdet) if return_Kmats else K_mats_logdet

    if not isinstance(K_mat, torch.Tensor):
        raise ValueError("compute_naswot: net has no Conv1d/Conv2d/Linear layer to hook")
    K_mat_logdet = safe_hooklogdet(K_mat.cpu().numpy())
    return (K_mat, K_mat_logdet) if return_Kmats else K_mat_logdet

def network_weight_gaussian_init(net: nn.Module):
    """Re-initialise ``net`` in place (without autograd tracking) and return it.

    - Conv2d / Linear: weight ~ N(0, 1); bias (if any) set to 0.
    - BatchNorm2d / GroupNorm with affine parameters: weight 1, bias 0.
    - Every other module (including Conv1d) is left untouched.
    """
    with torch.no_grad():
        for module in net.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(module.weight)
                if getattr(module, 'bias', None) is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)) and module.weight is not None:
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
    return net

def sum_arr(arr):
    """Sum every element of every entry of ``arr`` and return a Python scalar.

    If ``arr`` has a length, each ``arr[i]`` is converted to a tensor when it is
    not exactly a ``torch.Tensor``, reduced with ``torch.sum`` and accumulated.
    An ``arr`` without a length is returned as-is. ``.item()`` is applied when
    the result has it, so tensors and numpy scalars become Python numbers.
    """
    if not hasattr(arr, '__len__'):
        total = arr
    else:
        total = 0.
        for index in range(len(arr)):
            element = arr[index]
            if type(element) is not torch.Tensor:
                element = torch.tensor(element)
            total += torch.sum(element)
    if hasattr(total, 'item'):
        return total.item()
    return total


def compute_synflow(model, input_tensor):
    """Total SynFlow score of ``model``: the per-layer saliencies summed to a scalar."""
    per_layer_scores = compute_synflow_per_weight(model, input_tensor)
    return sum_arr(per_layer_scores)

def compute_params(model):
    """Count all parameter elements of ``model`` (trainable or not)."""
    total = 0
    for parameter in model.parameters():
        total += parameter.numel()
    return total


def compute_flops(model, input_tensor):
    """Count the forward-pass operations of ``model`` with ``thop.profile``.

    Args:
        model: the ``nn.Module`` to profile.
        input_tensor: a single tensor, or a tuple of tensors passed
            positionally to the model.

    Returns:
        The first value returned by ``thop.profile`` (the operation count).

    Raises:
        TypeError: if ``input_tensor`` is neither a tensor nor a tuple. The old
            code raised an opaque UnboundLocalError here.
    """
    if isinstance(input_tensor, torch.Tensor):
        inputs = (input_tensor,)
    elif isinstance(input_tensor, tuple):
        inputs = input_tensor
    else:
        raise TypeError(
            "input_tensor must be a torch.Tensor or a tuple of tensors, "
            f"got {type(input_tensor).__name__}"
        )

    flops, _params = profile(model, inputs=inputs, verbose=False)
    return flops

class ZC_proxy_class():
    """Score the candidate models of a workspace with zero-cost proxies and rank them.

    On construction, the workspace's logged plans and search space are loaded
    and one random configuration is drawn. A simulated input (with
    ``batch_size=4`` in the argument string when that argument is present) is
    generated once and reused for every candidate. ``ZC_proxy_test`` computes
    ``[params, flops, naswot, synflow]`` for module ``modeling_0_<choice>`` and
    stores it in ``evaluation_history``. ``rank`` combines the per-metric ranks
    with ``self.weight`` and keeps the best ``filtering_threshold`` fraction.
    """

    def __init__(self, workspace, debug=False):
        self.workspace = workspace
        self.debug = debug

        self.simulated_data_file_name = "simulated_data_0_0"

        self.information_instance = log_class(workspace)
        self.plans=self.information_instance.read()["plans"]

        self.search_space_instance = search_space(workspace)
        self.search_space_instance.configure(mode="connector-specific hyperparameters", connector_choice=0, maximum_modeling_modules=10, maximum_data_preparation_modules=10)
        self.search_space_instance.load_logs()
        configuration,fixed_dimensions=self.search_space_instance.random_search()

        self.unit_test_instance = tester(workspace , self.plans , 0.2)
        self.unit_test_instance.configure([0,None,None],configuration=configuration,fixed_dimensions=fixed_dimensions)
        self.unit_test_instance.simulated_data_file_name=self.simulated_data_file_name
        self.variable_dimensions_arguments_string=self.unit_test_instance.variable_dimensions_arguments_string_generator()
        self.replace_batch_size()
        # With integer class labels the generator returns three values, otherwise two;
        # only the first (the input) is kept.
        if self.plans[0]["connector"][2]=="integer representation of class labels":
            self.simulated_input, _,_=self.unit_test_instance.simulated_data_generation(self.variable_dimensions_arguments_string)
        else:
            self.simulated_input, _=self.unit_test_instance.simulated_data_generation(self.variable_dimensions_arguments_string)

        # modeling_choice -> [params, flops, naswot, synflow]; a failed metric is None
        self.evaluation_history={}

        # per-metric weights for the combined rank (same order as the metrics)
        self.weight=[0.25,0.25,0.25,0.25]

        # NOTE(review): evaluation_history is still empty here, so this is always
        # 0.0 and rank() substitutes the metric VALUE 0.0 for failed metrics. If a
        # "middle rank" was intended, it has to be computed inside rank().
        self.none_ranking=len(self.evaluation_history) / 2

        # fraction of ranked candidates kept by rank()
        self.filtering_threshold=0.5

    def replace_batch_size(self , old_batch_string="batch_size=" , new_batch_size="4") :
        """Set the value of ``old_batch_string`` in the argument string to ``new_batch_size``.

        The value runs from just after ``old_batch_string`` up to the next comma
        (or the end of the string). A match inside a longer identifier, e.g.
        ``max_batch_size=``, is ignored. The updated string is stored on
        ``self`` and also returned; it is returned unchanged if nothing matches.
        """
        text = self.variable_dimensions_arguments_string
        start_index = text.find(old_batch_string)

        # Skip matches preceded by an identifier character (e.g. "max_batch_size=").
        if old_batch_string and (old_batch_string[0].isalnum() or old_batch_string[0] == "_"):
            while start_index > 0 and (text[start_index - 1].isalnum() or text[start_index - 1] == "_"):
                start_index = text.find(old_batch_string, start_index + 1)

        if start_index == -1 :
            return text

        end_index = text.find(',' , start_index)
        if end_index == -1 :
            end_index = len(text)

        self.variable_dimensions_arguments_string = (
            text[:start_index + len(old_batch_string)] + str(new_batch_size) + text[end_index:]
        )
        return self.variable_dimensions_arguments_string

    def load_model(self,modeling_choice):
        """(Re)import ``<workspace>.modeling_0_<choice>`` and build its model in eval mode.

        The module is reloaded on every call so that a freshly constructed,
        unmodified model is returned each time.
        """
        self.file_name="modeling_0_"+str(modeling_choice)
        module = importlib.import_module(f"{self.workspace}.{self.file_name}")
        module = importlib.reload(module)

        # SECURITY NOTE: the argument string is generated Python source and is
        # evaluated as code. Only use with trusted workspaces. The module is
        # passed via a private local mapping so nothing leaks into the module
        # globals, in particular no lingering reference to the last model.
        model_raw = eval(
            f"_zc_modeling_module.generate_model({self.variable_dimensions_arguments_string})",
            globals(),
            {"_zc_modeling_module": module},
        )

        return model_raw.eval()

    def rank(self) :
        """Return the best ``filtering_threshold`` fraction of evaluated modeling choices.

        For each metric, candidates are sorted by value in descending order and
        given ranks 1..n. The ranks are combined with ``self.weight``, and the
        candidates with the lowest weighted rank come first. Failed (None)
        metrics are replaced by ``self.none_ranking`` for ranking purposes
        only; ``evaluation_history`` itself is not modified. An empty history
        yields ``[]``.
        """
        if not self.evaluation_history:
            return []

        # Substitute None values on a copy so the stored history is preserved.
        history = {
            key: [self.none_ranking if v is None else v for v in values]
            for key, values in self.evaluation_history.items()
        }

        num_metrics = len(next(iter(history.values())))

        # Calculating individual ranks for each metric
        individual_ranks = defaultdict(list)
        for metric_idx in range(num_metrics) :
            sorted_keys = sorted(history, key=lambda k: history[k][metric_idx], reverse=True)
            ranks = {key: position for position, key in enumerate(sorted_keys, start=1)}
            for key in history:
                individual_ranks[key].append(ranks[key])

        weighted_combined_ranks = {
            key: sum(rank * weight for rank, weight in zip(ranks, self.weight))
            for key, ranks in individual_ranks.items()
        }

        # Sorting keys based on their weighted combined ranks (stable on ties)
        final_weighted_ranking = sorted(weighted_combined_ranks, key=weighted_combined_ranks.get)

        num_elements = math.ceil(len(final_weighted_ranking) * self.filtering_threshold)

        return final_weighted_ranking[:num_elements]

    def _evaluate_metric(self, metric_name, modeling_choice, metric_fn):
        """Load a fresh model for ``modeling_choice`` and return ``metric_fn(model)``, or None on failure.

        A fresh model is loaded for every metric so that one metric's
        modifications to the model cannot affect another. Only ``Exception``
        is caught, so KeyboardInterrupt/SystemExit still propagate.
        """
        try:
            model = self.load_model(modeling_choice)
            value = metric_fn(model)
        except Exception:
            if self.debug:
                print(f"{metric_name} failed")
            return None
        if self.debug:
            print(f"{metric_name}: {value}")
        return value

    def ZC_proxy_test(self , modeling_choice):
        """Compute ``(params, flops, naswot, synflow)`` for one modeling choice.

        Each metric that raises is recorded as None. The list of the four
        values is stored in ``evaluation_history[modeling_choice]`` and the
        tuple is returned.
        """
        if self.debug:
            print("variable_dimensions_arguments_string for ZC proxy", self.variable_dimensions_arguments_string)
            print("input data for ZC proxy", self.simulated_input)
            print(f"modeling_choice: {modeling_choice}")

        simulated_input = self.simulated_input
        params = self._evaluate_metric("params", modeling_choice, compute_params)
        flops = self._evaluate_metric(
            "flops", modeling_choice, lambda model: compute_flops(model, simulated_input))
        naswot = self._evaluate_metric(
            "naswot", modeling_choice, lambda model: compute_naswot(model, simulated_input))
        synflow_score = self._evaluate_metric(
            "synflow", modeling_choice, lambda model: compute_synflow(model, simulated_input))

        self.evaluation_history[modeling_choice]=[params,flops,naswot,synflow_score]

        return params,flops,naswot,synflow_score
