import torch
from torch import nn

from fvcore.nn import FlopCountAnalysis
from fvcore.nn import parameter_count

import numpy as np
import random
import scipy
from tqdm import tqdm

import os
from types import SimpleNamespace

from pruning.nethooks import InputCorrCollector, InputMeanCollector
from pruning.modules import FactPrunableDemeanConv2d, FactPrunableLayer, ElementwisePruningLayer, Prunable



def fix_settings(seed, fltype, allow_grad=False):
    """Seed all random number generators and pin global PyTorch settings.

    Args:
        seed: Seed passed to Python's ``random``, NumPy, and PyTorch
            (CPU generator and all CUDA devices).
        fltype: Floating point dtype installed as PyTorch's default dtype.
        allow_grad: If False (default), gradient computation is disabled
            globally. If True, the current gradient mode is left untouched.
    """
    # One seed for every generator in use.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Deterministic cuDNN kernels, benchmark mode off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Default floating point precision for newly created tensors.
    torch.set_default_dtype(fltype)

    if allow_grad:
        return
    # Gradients are not needed: switch autograd off process-wide.
    torch.set_grad_enabled(False)


def measure_perf(net, loss_fn, dataloader, device):
    """Evaluate ``net`` on ``dataloader`` and return mean loss and accuracy.

    The network is put into eval mode (and left there). Both numbers are
    averaged over the number of batches (``len(dataloader)``), so batches of
    different sizes carry equal weight.

    Args:
        net: Model called as ``net(x)``; its output is arg-maxed over dim 1.
        loss_fn: Callable ``loss_fn(out, y)`` returning a scalar tensor.
        dataloader: Sized iterable yielding ``(x, y)`` pairs.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_batch_loss, mean_batch_accuracy)`` of Python floats.
    """
    net.eval()

    loss_total = 0.
    acc_total = 0.
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        with torch.no_grad():
            logits = net(inputs)
            loss_total += loss_fn(logits, targets).item()
            correct = (logits.argmax(dim=1) == targets).to(torch.float32)
            acc_total += torch.mean(correct).item()

    n_batches = len(dataloader)
    return loss_total/n_batches, acc_total/n_batches

def get_perplexity(model, encodings, stride=512):
    """Estimate the perplexity of ``model`` on ``encodings`` with a sliding window.

    The token sequence is processed in windows of at most
    ``model.config.n_positions`` tokens whose start advances by ``stride``.
    In each window only the tokens not scored by the previous window keep
    their labels; the rest are set to -100 (presumably the label value the
    model ignores in its loss -- confirm against the model class in use).

    Args:
        model: Model exposing ``config.n_positions`` and ``device``, callable
            as ``model(input_ids, labels=...)`` with a ``.loss`` on the result.
        encodings: Object whose ``input_ids`` is a 2-D tensor indexed as
            ``[:, begin:end]``.
        stride: Step between consecutive window starts. Defaults to 512, the
            value that used to be hard-coded.

    Returns:
        Scalar tensor ``exp(mean(window losses))``. The mean is an unweighted
        mean over windows, not over tokens.
    """
    model.eval()
    max_length = model.config.n_positions
    seq_len = encodings.input_ids.size(1)

    nlls = []
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(model.device)
        target_ids = input_ids.clone()
        # Mask everything except the last trg_len tokens of the window.
        target_ids[:, :-trg_len] = -100

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)

            # loss is calculated using CrossEntropyLoss which averages over valid labels
            # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
            # to the left by 1.
            neg_log_likelihood = outputs.loss

        nlls.append(neg_log_likelihood)

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    ppl = torch.exp(torch.stack(nlls).mean())
    return ppl


def ldl(C):
    """Run SciPy's LDL factorization on ``C`` and return the factors as tensors.

    Args:
        C: Square tensor; it is detached and moved to the CPU for SciPy.

    Returns:
        Tuple ``(R_inv, d)`` on ``C.device``: ``R_inv`` is SciPy's outer
        factor with its rows reordered by the returned permutation, and ``d``
        is the main diagonal of the (row-reordered) block-diagonal factor.

    NOTE(review): the block-diagonal factor is row-permuted before its
    diagonal is read. For a non-identity permutation this looks like it picks
    off-diagonal entries -- confirm this is intended. The permutation itself
    is not returned to the caller.
    """
    outer, block_diag, perm = scipy.linalg.ldl(C.detach().cpu().numpy())

    # Apply SciPy's row permutation to both factors.
    outer = outer[perm, :]
    block_diag = block_diag[perm]

    diagonal = torch.tensor(np.diag(block_diag), device=C.device)
    R_inv = torch.tensor(outer).to(C.device)
    return R_inv, diagonal


def combine_conv_weights(w1, w2):
    """Compose two convolution kernels into one equivalent kernel.

    Order of w1 and w2 is flipped compared to matmuls: for 1x1 kernels the
    result equals ``w2 @ w1`` over the channel dimensions.

    Args:
        w1: Kernel of shape ``(c_mid, c_in, kh1, kw1)``.
        w2: Kernel of shape ``(c_out, c_mid, kh2, kw2)``. It may be
            non-square; each spatial axis is padded by its own size.

    Returns:
        Kernel of shape ``(c_out, c_in, kh1 + kh2 - 1, kw1 + kw2 - 1)``.
    """
    # "Full" convolution needs kernel_size - 1 padding on each spatial axis.
    # Taking the padding from the last axis only is wrong for non-square w2.
    pad_h = w2.shape[-2] - 1
    pad_w = w2.shape[-1] - 1
    return torch.conv2d(
        w1.permute(1, 0, 2, 3),
        w2.flip(-1, -2),  # convolution vs cross-correlation
        padding=(pad_h, pad_w)
    ).permute(1, 0, 2, 3)


def combine_weights(mod, R_comb):
    """Fold the matrix ``R_comb`` into the input side of ``mod.weight`` in place.

    Three cases are handled:
      * weight with more than two dims: ``R_comb`` is treated as a 1x1 kernel
        and composed with the weight;
      * 2-D weight whose ``in_features`` equals ``len(R_comb)``: plain
        ``W @ R_comb``;
      * 2-D weight with a different ``in_features``: the weight is reshaped to
        ``(-1, len(R_comb), side, side)`` with a square spatial size
        (presumably undoing a flatten after a conv layer), composed as in the
        first case, and flattened again.

    Args:
        mod: Module with a ``weight`` (and ``in_features`` for 2-D weights).
        R_comb: 2-D tensor applied to the input channels/features.
    """
    weight = mod.weight.clone()

    if weight.dim() > 2:
        merged = combine_conv_weights(R_comb[..., None, None], weight)
    elif len(R_comb) == mod.in_features:
        merged = weight @ R_comb
    else:
        # Spatial side length of the input before the flatten operation.
        side = int(np.sqrt(mod.in_features // len(R_comb)))
        # Give the weight back the pre-flatten input layout.
        weight = torch.reshape(weight, (-1, len(R_comb), side, side))
        merged = combine_conv_weights(R_comb[..., None, None], weight)
        merged = torch.flatten(merged, start_dim=1)

    mod.weight.data = merged.clone()


def run_collector(collector, net, dataloader, device):
    """Push every batch of ``dataloader`` through ``net`` and return what ``collector`` gathered.

    The collector's hooks are attached before the pass and detached after the
    result has been collected. The network is evaluated in eval mode without
    gradients and is unconditionally switched to train mode afterwards.

    Args:
        collector: Object providing ``attach_hooks(net, cls)``, ``collect()``
            and ``detach_hooks()``.
        net: Model to run.
        dataloader: Iterable of ``(x, _)`` pairs; only ``x`` is used.
        device: Device the inputs are moved to.

    Returns:
        Whatever ``collector.collect()`` returns.
    """
    # The role of the second argument is defined by the collector (not visible here).
    collector.attach_hooks(net, nn.AdaptiveAvgPool2d)

    net.eval()
    with torch.no_grad():
        for batch, _ in tqdm(dataloader):
            net(batch.to(device))
    net.train()

    gathered = collector.collect()
    collector.detach_hooks()
    return gathered


def get_input_stats(net, dataloader, device, demean=False, args=SimpleNamespace()):
    """Return cached input statistics, computing and caching them if missing.

    The cache file is ``args.stats_path + f'{args.pruning_algo}.pt'``.

    Args:
        net: Model whose layer inputs are analysed.
        dataloader: Data used for the collection passes.
        device: Device for the passes and ``map_location`` when loading.
        demean: If True, input means are collected first and handed to the
            correlation collector; otherwise ``None`` is used.
        args: Namespace providing ``stats_path`` and ``pruning_algo``.

    Returns:
        ``(means, Cs)`` as a tuple when freshly computed, or whatever
        ``torch.load`` yields for the cached ``[means, Cs]`` list.
    """
    path = args.stats_path + f'{args.pruning_algo}.pt'

    # Cache hit: reuse what an earlier run stored.
    if os.path.exists(path):
        return torch.load(path, map_location=device)

    means = run_collector(InputMeanCollector(), net, dataloader, device) if demean else None
    Cs = run_collector(InputCorrCollector(means), net, dataloader, device)

    torch.save([means, Cs], path)
    return means, Cs

def get_input_dims(device, args=SimpleNamespace()):
    """Read the cached statistics in ``args.stats_path + 'dec.pt'`` and return their sizes.

    Args:
        device: ``map_location`` used when loading the file.
        args: Namespace providing ``stats_path``.

    Returns:
        List with ``len(C)`` for every entry of the second stored element, or
        ``None`` when the file does not exist.
    """
    path = args.stats_path + 'dec.pt'
    if not os.path.exists(path):
        return None

    _, Cs = torch.load(path, map_location=device)
    return list(map(len, Cs))



def factory(layer, R=None, mean=None, pivots=None, pre_dim=None, args=SimpleNamespace()):
    """Wrap ``layer`` in the prunable wrapper matching the supplied statistics.

    Args:
        layer: Layer to wrap.
        R: Factor passed to the factorized wrappers; ``None`` selects
            element-wise pruning.
        mean: Input mean; together with a ``nn.Conv2d`` layer it selects the
            demeaning wrapper.
        pivots: Ordering passed through to the wrapper.
        pre_dim: Passed to the element-wise wrapper only.
        args: Namespace providing ``prune_by_removal``.

    Returns:
        The wrapper instance.
    """
    by_removal = args.prune_by_removal

    if R is None:
        # troublesome for last layer (flatten -> classification)
        return ElementwisePruningLayer(layer, pivots, pre_dim=pre_dim, prune_by_removal=by_removal)

    if isinstance(layer, nn.Conv2d) and mean is not None:
        # cutting out computation with demeaning or with residual connections is not possible yet!
        return FactPrunableDemeanConv2d(layer, R, mean, pivots, prune_by_removal=by_removal)

    return FactPrunableLayer(layer, R, mean, pivots, prune_by_removal=by_removal)


def replace_layers(net, Rs=None, means=None, pivots=None, args=SimpleNamespace()):
    """Replace every ``nn.Linear`` / ``nn.Conv2d`` inside ``net`` with a prunable wrapper.

    Children are visited depth-first in registration order. The i-th layer
    found receives ``Rs[i]``, ``means[i]`` and ``pivots[i]`` (or ``None`` for
    any list that was not given).

    When ``args.pre_dims`` is ``None`` the wrapper's ``pre_dim`` is the
    ``len(weight)`` of the previously replaced layer (``None`` for the first
    one). Otherwise a running layer counter indexes ``args.pre_dims``.

    Args:
        net: Model modified in place.
        Rs, means, pivots: Optional per-layer sequences.
        args: Namespace providing ``pre_dims`` plus whatever ``factory`` reads.

    Returns:
        List of the newly created wrappers in visiting order.
    """
    prunable_layers = []

    def _entry(per_layer):
        # Item for the layer about to be wrapped, or None if no list was given.
        return per_layer[len(prunable_layers)] if per_layer is not None else None

    def _replace_layers(module, out_dim=None):
        for name, mod in module.named_children():
            if not isinstance(mod, (nn.Linear, nn.Conv2d)):
                out_dim = _replace_layers(mod, out_dim)
                continue

            R, mean, pivot = _entry(Rs), _entry(means), _entry(pivots)
            pre_dim = out_dim if args.pre_dims is None else args.pre_dims[out_dim]
            wrapped = factory(mod, R, mean, pivot, pre_dim=pre_dim, args=args)

            prunable_layers.append(wrapped)
            setattr(module, name, wrapped)
            out_dim = len(mod.weight) if args.pre_dims is None else out_dim+1

        return out_dim

    _replace_layers(net, None if args.pre_dims is None else 0)
    return prunable_layers


def revert_layers(net):
    """Undo layer wrapping: swap every ``Prunable`` child back for its ``.layer``.

    Non-``Prunable`` children are searched recursively. ``net`` is modified in
    place.
    """
    for name, mod in net.named_children():
        if not isinstance(mod, Prunable):
            revert_layers(mod)
            continue
        setattr(net, name, mod.layer)


def get_prunable_layers(net):
    """Collect all ``nn.Linear`` and ``nn.Conv2d`` descendants of ``net``.

    Children are visited depth-first in registration order. ``net`` itself is
    never included, and the search does not descend into a matching layer.

    Returns:
        List of the matching modules.
    """
    def _walk(module):
        for child in module.children():
            if isinstance(child, (nn.Linear, nn.Conv2d)):
                yield child
            else:
                yield from _walk(child)

    return list(_walk(net))


def saw_ordering(net):
    """Order the output units of each linear/conv layer by summed absolute weight.

    Returns:
        One index tensor per ``nn.Linear`` / ``nn.Conv2d`` found by
        ``net.modules()``, sorting that layer's output units from largest to
        smallest sum of absolute weights.
    """
    orderings = []
    for mod in net.modules():
        if not isinstance(mod, (nn.Linear, nn.Conv2d)):
            continue
        magnitude = torch.abs(mod.weight)
        # One score per output unit: sum over all remaining weight dims.
        unit_scores = torch.sum(torch.flatten(magnitude, start_dim=1), dim=1)
        orderings.append(torch.argsort(unit_scores, descending=True))
    return orderings


def saw_tilde_ordering(net, Cs):
    """Order units per layer by absolute weights scaled with the square root of ``C``.

    For every layer/``C`` pair, ``M = sqrtm(C + 1e-8 * I)`` is folded into the
    layer weight, the absolute values are summed per output unit, and
    ``1 / diag(inv(M))`` is added before sorting in descending order.

    NOTE(review): the summed weights have one entry per output unit while
    ``1 / diag(inv(M))`` has one entry per row of ``C``. The addition only
    lines up when those sizes agree or broadcast -- confirm the intended axis.

    Args:
        net: Model whose ``nn.Linear`` / ``nn.Conv2d`` modules are scored.
        Cs: One square tensor per layer, in ``net.modules()`` order.

    Returns:
        List of index tensors, one per layer/``C`` pair.
    """
    def _fold_into_weight(layer, M):
        weight = layer.weight.clone()
        if weight.dim() > 2:
            return combine_conv_weights(M[..., None, None], weight)
        if len(M) == layer.in_features:
            return weight @ M
        # Spatial side length of the input before the flatten operation.
        side = int(np.sqrt(layer.in_features // len(M)))
        # Give the weight back the pre-flatten input layout.
        weight = torch.reshape(weight, (-1, len(M), side, side))
        folded = combine_conv_weights(M[..., None, None], weight)
        return torch.flatten(folded, start_dim=1)

    layers = [mod for mod in net.modules() if isinstance(mod, (nn.Linear, nn.Conv2d))]

    # Matrix square roots of the slightly regularized C matrices.
    roots = []
    for C in Cs:
        regularized = C.detach().cpu() + 1e-8 * torch.eye(len(C))
        root = torch.tensor(scipy.linalg.sqrtm(regularized))
        roots.append(root.to(C.device).to(torch.float32))

    offsets = [1/torch.diag(torch.inverse(M)) for M in roots]

    orderings = []
    for layer, M, offset in zip(layers, roots, offsets):
        magnitude = torch.abs(_fold_into_weight(layer, M))
        score = torch.sum(torch.flatten(magnitude, start_dim=1), dim=1) + offset
        orderings.append(torch.argsort(score, descending=True))
    return orderings


def ZCA_ordering(Cs):
    """Order the units of each layer by ``1 / diag(inv(sqrtm(C)))**2``, largest first.

    Args:
        Cs: Iterable of square tensors. Each gets ``1e-8 * I`` added before
            the matrix square root is taken.

    Returns:
        List with one descending index tensor per ``C``, on ``C``'s device.
    """
    def _unit_scores(C):
        cov = C.detach().cpu()
        root = scipy.linalg.sqrtm(cov + 1e-8 * np.eye(len(cov)))
        inv_diag = np.diag(np.linalg.inv(root))
        squared = inv_diag * inv_diag
        return torch.tensor(1/squared).to(C.device)

    return [torch.argsort(_unit_scores(C), descending=True) for C in Cs]

def random_ordering(Cs):
    """Return one random permutation of ``range(len(C))`` per entry of ``Cs``.

    Each permutation is drawn with ``torch.randperm`` and moved to the device
    of its ``C``.
    """
    orderings = []
    for C in Cs:
        permutation = torch.randperm(len(C))
        orderings.append(permutation.to(C.device))
    return orderings

def C_ordering(Cs):
    """Order units by the absolute row sum of ``C`` divided by its diagonal entry.

    Returns:
        One index tensor per ``C``, sorted from the smallest ratio to the
        largest.
    """
    orderings = []
    for C in Cs:
        row_mass = torch.abs(C).sum(dim=1)
        ratio = row_mass / torch.diag(C)
        orderings.append(torch.argsort(ratio, descending=False))
    return orderings


def network_wide_saw(net):
    """Rank the output units of all but the last linear/conv layer on one common scale.

    Per layer, each output unit's summed absolute weight is divided by the
    layer's maximum. All ``(layer index, unit index, score)`` triples are then
    sorted by score in ascending order.

    Returns:
        NumPy array built from the sorted list of triples.
    """
    layers = [mod for mod in net.modules() if isinstance(mod, (nn.Linear, nn.Conv2d))][:-1]

    entries = []
    for layer_idx, layer in enumerate(layers):
        raw = torch.sum(torch.flatten(torch.abs(layer.weight), start_dim=1), dim=1)
        normalized = (raw / torch.max(raw)).detach().cpu().tolist()
        for unit_idx, score in enumerate(normalized):
            entries.append((layer_idx, unit_idx, score))

    entries.sort(key=lambda entry: entry[-1])
    return np.array(entries)


def network_wide_acc(net, accs):
    """Rank units of all but the last linear/conv layer by per-ratio scores.

    For each layer, ``accs[i]`` holds one value per entry of the fixed ratio
    grid 0.0, 0.1, ..., 1.0. The value for ratio ``r`` is repeated for the
    units between ``int(units * previous_r)`` and ``int(units * r)``, and the
    resulting per-unit vector is reversed. All ``(layer index, unit index,
    score)`` triples are then sorted by score in descending order.

    Args:
        net: Model whose ``nn.Linear`` / ``nn.Conv2d`` modules (minus the
            last one) are paired with ``accs``.
        accs: One sequence per layer, aligned with the ratio grid.

    Returns:
        NumPy array built from the sorted list of triples.
    """
    ratios = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    layers = [mod for mod in net.modules() if isinstance(mod, (nn.Linear, nn.Conv2d))][:-1]

    entries = []
    for layer_idx, (layer_accs, layer) in enumerate(zip(accs, layers)):
        units = layer.out_features if isinstance(layer, nn.Linear) else layer.out_channels

        # Units covered by each ratio step: difference of consecutive cutoffs.
        cutoffs = [int(units * r) for r in ratios]
        repeats = [hi - lo for lo, hi in zip([0] + cutoffs[:-1], cutoffs)]

        scores = np.flip(np.repeat(layer_accs, repeats))
        for unit_idx, score in enumerate(scores):
            entries.append((layer_idx, unit_idx, score))

    entries.sort(key=lambda entry: entry[-1], reverse=True)
    return np.array(entries)

    
def invert_ordering(orders):
    """Return ``torch.argsort`` of every tensor in ``orders``.

    For a permutation this is its inverse permutation.
    """
    inverted = []
    for order in orders:
        inverted.append(torch.argsort(order))
    return inverted


def prune_pre_layer(mod, pruning_indices):
    """Zero the selected output units of ``mod`` and return the previous values.

    The entries of ``mod.weight`` (and ``mod.bias``, if present) indexed by
    ``pruning_indices`` along the first dimension are set to zero.

    Args:
        mod: Module with ``weight`` and ``bias`` attributes.
        pruning_indices: Index into the first dimension of the parameters.

    Returns:
        List of ``(parameter, copy_of_old_values)`` pairs: one for the weight,
        followed by one for the bias when the module has one.
    """
    def _zero_entries(param):
        previous = param.clone()
        pruned = param.clone()
        pruned[pruning_indices] = 0.
        param.data = pruned
        return (param, previous)

    backups = [_zero_entries(mod.weight)]
    if mod.bias is not None:
        backups.append(_zero_entries(mod.bias))
    return backups


def prune_post_layer(mod, indices, pre_units=None):
    """Zero the input connections of ``mod`` that read from pruned units.

    Previously this function read ``W_post`` before assigning it and therefore
    raised on every call. The weight is now taken from ``mod`` itself.

    Args:
        mod: Module whose ``weight`` has its inputs on dimension 1.
        indices: Indices of the pruned units of the preceding layer.
        pre_units: Optional number of units (channels) of the preceding
            layer. Pass it when a flatten sits between the two layers, i.e.
            ``mod.weight`` is 2-D with ``pre_units * k`` input features. The
            ``k`` consecutive features of each pruned unit are then zeroed.
            When omitted, ``indices`` address dimension 1 directly.

    Raises:
        ValueError: If ``pre_units`` does not divide the number of input
            features of a 2-D weight.
    """
    W_post = mod.weight.detach().clone()

    flattened_input = (
        pre_units is not None
        and W_post.dim() == 2
        and W_post.shape[1] != pre_units
    )
    if flattened_input:  # in case there is a flatten operation in-between
        n_features = W_post.shape[1]
        if pre_units <= 0 or n_features % pre_units != 0:
            raise ValueError(
                f'pre_units={pre_units} does not divide the {n_features} input features'
            )
        # View the inputs as (units, positions) so a whole unit can be zeroed.
        W_post = torch.reshape(W_post, (-1, pre_units, n_features // pre_units))
        W_post[:, indices] = 0.
        W_post = torch.flatten(W_post, start_dim=1, end_dim=2)
    else:
        W_post[:, indices] = 0.

    mod.weight.data = W_post


def get_n_nodes_for_variance_cutoff(variances, ratio):
    """Find where the cumulative variance share, summed from the right, reaches ``ratio``.

    ``variances`` is normalized to sum to one and accumulated starting from
    its last entry. The result is the 0-based position (counted from the
    right) of the first cumulative value that is ``>= ratio``; equivalently,
    the number of trailing entries whose combined share stays below ``ratio``.

    Args:
        variances: 1-D tensor of variances.
        ratio: Threshold on the cumulative share.

    Returns:
        NumPy integer. If the cumulative share never reaches ``ratio`` (for
        example ``ratio > 1``), the number of entries is returned. The
        previous code returned 0 in that case, because ``np.argmax`` of an
        all-False mask is 0.
    """
    # normalize:
    shares = variances / torch.sum(variances)

    # get all cumulative sums from the right
    cumulative_sum = np.cumsum(shares.detach().cpu().numpy()[::-1])

    reached = cumulative_sum >= ratio
    if not reached.any():
        # Threshold never met: every entry stays below the cutoff.
        return np.intp(reached.size)

    # select the first one exceeding the ratio
    return np.argmax(reached)


def get_network_stats(net, dataloader, device):
    """Return the parameter count and FLOP count of ``net`` for a single input sample.

    The network is put into eval mode (and left there). The first sample of
    the first batch of ``dataloader`` is used as the input for FLOP counting.

    Returns:
        Tuple ``(parameter_count, flops)``. The parameter count is the entry
        stored under the empty-string key of fvcore's ``parameter_count``
        result (presumably the total for the whole model -- confirm in the
        fvcore docs).
    """
    net.eval()

    first_batch, _ = next(iter(dataloader))
    sample = first_batch[:1].to(device)

    with torch.no_grad():
        analyzer = FlopCountAnalysis(net, sample)
        # Silence fvcore's warnings about unsupported ops and uncalled modules.
        analyzer.unsupported_ops_warnings(enabled=False)
        analyzer.uncalled_modules_warnings(enabled=False)
        flops = analyzer.total()
        n_params = parameter_count(net)['']

    return n_params, flops


def parameter_count_nonzero(net):
    """Count the entries of all parameters of ``net`` that are not exactly zero."""
    return sum(torch.sum(prm != 0).item() for prm in net.parameters())


def make_paths(args):
    """Derive the output directories for a run and make sure they exist.

    Sets, on ``args``:
      * ``base_dir``     = ``f'{args.output_path}{args.model}/'``
      * ``stats_path``   = ``base_dir + 'stats/'``
      * ``results_path`` = ``base_dir + 'results/'``

    Args:
        args: Namespace providing ``output_path`` (expected to end with a
            path separator, since it is concatenated directly) and ``model``.
    """
    args.base_dir = f'{args.output_path}{args.model}/'
    args.stats_path = args.base_dir + 'stats/'
    args.results_path = args.base_dir+'results/'

    # makedirs also creates missing parents (e.g. output_path itself) and
    # tolerates directories that already exist, so no exists() check is needed.
    for directory in (args.base_dir, args.stats_path, args.results_path):
        os.makedirs(directory, exist_ok=True)










