import torch
from torch import nn
import torchvision
import numpy as np
import scipy
from tqdm import tqdm

import torch_pruning as tp

from fvcore.nn import FlopCountAnalysis
from fvcore.nn import parameter_count

from pathlib import Path
import random
import datetime
import time
from collections import defaultdict, deque
import yaml

from pruning.collectors import InputCorrCollector, DeitInputCorrCollector
from pruning import importances


import timm




class MetricLogger:
    """Collect named meters and print periodic progress lines for a loop.

    Meters are created lazily by ``update`` (one ``SmoothedValue`` per
    keyword) and can be read back as attributes, e.g. ``logger.loss``.
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed one scalar per keyword into the meter of the same name.

        Tensor values are converted with ``.item()``; anything that is not
        an ``int`` or ``float`` afterwards trips the assertion.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Python calls __getattr__ only after normal lookup failed, so the
        # attribute is known not to be a regular instance attribute here.
        # ``meters`` is read through ``__dict__`` on purpose: going through
        # ``self.meters`` on an instance that has no ``meters`` yet (e.g.
        # while ``copy``/``pickle`` rebuild it) would re-enter this method
        # and recurse until RecursionError.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        """Return ``name: meter`` pairs joined by the delimiter."""
        return self.delimiter.join(f"{name}: {str(meter)}" for name, meter in self.meters.items())

    def synchronize_between_processes(self):
        """Call ``synchronize_between_processes()`` on every meter.

        Every registered meter must therefore provide that method.
        """
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register (or replace) ``meter`` under ``name``."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while printing progress.

        A line is printed on every ``print_freq``-th iteration (starting
        with the first) and a total-time line once the iterable is
        exhausted. ``iterable`` must support ``len()``.
        """
        if not header:
            header = ""

        total = len(iterable)
        use_cuda = torch.cuda.is_available()

        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        # Pad the iteration counter to the width of the total count.
        space_fmt = ":" + str(len(str(total))) + "d"
        fields = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if use_cuda:
            fields.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(fields)
        MB = 1024.0 * 1024.0

        for i, obj in enumerate(iterable):
            # Time spent waiting for the item vs. the full iteration.
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (total - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                extra = {"memory": torch.cuda.max_memory_allocated() / MB} if use_cuda else {}
                print(
                    log_msg.format(
                        i,
                        total,
                        eta=eta_string,
                        meters=str(self),
                        time=str(iter_time),
                        data=str(data_time),
                        **extra,
                    )
                )
            end = time.time()

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(f"{header} Total time: {total_time_str}")


class SmoothedValue:
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    ``window_size`` bounds the deque used for ``median``/``avg``/``max``/
    ``value``; ``total`` and ``count`` cover every update ever made.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Record ``value``; it enters the window once but counts ``n`` times
        towards the global total and count."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum ``count`` and ``total`` across distributed workers.

        Does nothing unless ``torch.distributed`` is available and
        initialised. The window deque is NOT synchronised, so only
        ``global_avg`` reflects all workers afterwards.
        """
        import torch.distributed as dist

        if not (dist.is_available() and dist.is_initialized()):
            return
        # NCCL can only reduce CUDA tensors; other backends work on CPU.
        device = "cuda" if dist.get_backend() == "nccl" else "cpu"
        stats = torch.tensor([self.count, self.total], dtype=torch.float64, device=device)
        dist.barrier()
        dist.all_reduce(stats)
        count, total = stats.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        """Median of the values currently in the window."""
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window."""
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        """Mean over every update (``total / count``)."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
        )



def reconstruction(layer, n_kept):
    """Rewrite ``layer.weight`` in place from the layer's stored factors.

    Nothing happens unless ``layer.lls`` is truthy and the layer carries both
    ``M_inv`` and ``pivots``. Otherwise the weight is contracted along its
    second axis with ``(M_inv[:, :n_kept] @ inverse(M_inv)[:n_kept])`` after
    that matrix has been re-indexed by ``pivots`` on both axes.
    """
    if not getattr(layer, 'lls', False):
        return
    if not (hasattr(layer, 'M_inv') and hasattr(layer, 'pivots')):
        return

    if layer.weight.shape[1] != len(layer.M_inv):
        n_kept = n_kept // 49  # something more general is needed!

    factor = layer.M_inv.clone()
    weight = layer.weight.clone()

    kept_rows = torch.inverse(factor)[:n_kept]
    kept_cols = factor[:, :n_kept]
    mixing = (kept_cols @ kept_rows)[layer.pivots][:, layer.pivots]

    updated = torch.einsum('ij..., jk->ik...', weight, mixing)
    layer.weight.data = updated.clone()


def ldl(C, epsilon=1e-6):
    """Return ``(L, D)`` derived from the Cholesky factor of ``C + epsilon*I``.

    ``L`` is the lower Cholesky factor with each column divided by its
    diagonal entry (so its diagonal is all ones); ``D`` holds the squared
    diagonal of the Cholesky factor. ``C`` itself is not modified.
    """
    jitter = epsilon * torch.eye(len(C), dtype=C.dtype, device=C.device)
    chol = torch.linalg.cholesky(C + jitter, upper=False)
    pivots = chol.diagonal()
    return chol / pivots, pivots ** 2


def sqrtm_scipy(M, epsilon=1e-12):
    """Matrix square root of ``M + epsilon*I`` computed with SciPy on the CPU.

    This used to be a first ``def sqrtm`` that was immediately rebound by the
    eigendecomposition-based ``sqrtm`` below; it now has its own name so it
    stays callable. The result is moved back to ``M``'s device and dtype.
    """
    # `import scipy` alone does not guarantee the `linalg` submodule is loaded.
    import scipy.linalg

    return torch.tensor(
        scipy.linalg.sqrtm(
            M.detach().cpu().numpy() + np.eye(len(M)) * epsilon
        )
    ).to(device=M.device, dtype=M.dtype)


def sqrtm(M, epsilon=1e-12):
    """Matrix square root of ``M``, delegated to ``mpow(M, 1/2, epsilon)``."""
    return mpow(M, 1/2, epsilon)

def mpow(M, power, epsilon=1e-8):
    """Raise ``M + epsilon*I`` to ``power`` through its ``eigh`` decomposition.

    The eigenvalues are raised element-wise to ``power`` and the matrix is
    reassembled as ``Q diag(lambda**power) Q^H``.
    """
    ridge = epsilon * torch.eye(len(M), dtype=M.dtype, device=M.device)
    eigvals, eigvecs = torch.linalg.eigh(M + ridge)
    powered = torch.diag_embed(eigvals ** power)
    return eigvecs @ powered @ eigvecs.mH


def anymul(M1, M2):
    """Contract the second axis of ``M1`` with the first axis of ``M2``.

    Any trailing axes of ``M1`` are carried through to the result unchanged.
    """
    contraction = 'ij..., jk->ik...'
    return torch.einsum(contraction, M1, M2)

def rev_cumsum(x):
    """Return suffix sums along axis 0: ``out[i] = x[i] + x[i+1] + ...``."""
    reversed_x = torch.flip(x, dims=[0])
    running = torch.cumsum(reversed_x, dim=0)
    return torch.flip(running, dims=[0])


def _run_collector(collector, net, dataloader, device):
    """Forward one batch through ``net`` and return what ``collector`` gathered.

    The collector's hooks are attached (``nn.AdaptiveAvgPool2d`` is passed as
    the second argument of ``attach_hooks``), the first batch of
    ``dataloader`` is pushed through the network in eval mode without
    gradients, and ``collector.collect()`` is returned. Only the first batch
    is used. The network is switched to train mode afterwards.

    The hooks are detached even if the forward pass or ``collect`` raises, so
    a failure cannot leave stale hooks on the model.
    """
    collector.attach_hooks(net, nn.AdaptiveAvgPool2d)
    try:
        net.eval()
        with torch.no_grad():
            for x, _ in tqdm(dataloader):
                net(x.to(device))
                break  # a single batch is all that gets collected
        net.train()
        return collector.collect()
    finally:
        collector.detach_hooks()

def get_input_stats(path, net, dataloader, device):
    """Load cached input statistics from ``path`` or compute and cache them.

    If the file exists it is loaded with ``map_location=device`` and returned.
    Otherwise the parent directory is created, a ``DeitInputCorrCollector``
    is run over the data via ``_run_collector``, and the result is saved to
    ``path`` and returned.

    ``~`` in ``path`` is expanded once up front so the existence check, the
    directory creation and the save/load all refer to the same file.
    """
    stats_file = Path(path).expanduser()
    if stats_file.exists():
        return torch.load(stats_file, map_location=device)

    stats_file.parent.mkdir(parents=True, exist_ok=True)

    collector = DeitInputCorrCollector()
    Cs = _run_collector(collector, net, dataloader, device)

    torch.save(Cs, stats_file)
    return Cs



def load_data(batch_size, v2=True, data_root="/PATH/TO/DATA", num_workers=18):
    """Build ImageNet-style train/val loaders from ``ImageFolder`` directories.

    Args:
        batch_size: Batch size of both loaders.
        v2: If true, use the preprocessing bundled with
            ``ResNet50_Weights.IMAGENET1K_V2``; otherwise resize to 256,
            centre-crop to 224 and normalise with the mean/std below.
        data_root: Directory containing the ``train`` and ``val`` folders.
            The default is the placeholder the function used to hard-code.
        num_workers: Worker processes per loader.

    Returns:
        ``(train_dataloader, test_dataloader)``. Both use the same evaluation
        transform and neither shuffles.
    """
    if v2:
        test_transform = torchvision.models.ResNet50_Weights.IMAGENET1K_V2.transforms()
    else:
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
        test_transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(256),
                torchvision.transforms.CenterCrop(224),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(mean, std),
            ]
        )

    root = Path(data_root).expanduser()
    train_dataset = torchvision.datasets.ImageFolder(root=str(root / "train"), transform=test_transform)
    test_dataset = torchvision.datasets.ImageFolder(root=str(root / "val"), transform=test_transform)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size, shuffle=False, num_workers=num_workers
    )
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size, shuffle=False, num_workers=num_workers, drop_last=False
    )

    return train_dataloader, test_dataloader


def load_cifar10(batch_size=128, root="~/data", num_workers=18):
    """Build CIFAR-10 train/test loaders.

    Args:
        batch_size: Batch size of both loaders (previously fixed at 128).
        root: Dataset directory handed to ``torchvision.datasets.CIFAR10``
            (previously fixed at ``"~/data"``). The data must already be
            there; ``download`` stays ``False``.
        num_workers: Worker processes of the training loader. The test loader
            keeps the ``DataLoader`` default, as before.

    Returns:
        ``(train_loader, test_loader)``. Training data is augmented with a
        padded random crop and horizontal flip, shuffled, and the last
        incomplete batch is dropped; test data is only normalised.
    """
    transforms = torchvision.transforms
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    training_data = torchvision.datasets.CIFAR10(
        root=root,
        train=True,
        download=False,
        transform=transform_train,
    )
    test_data = torchvision.datasets.CIFAR10(
        root=root,
        train=False,
        download=False,
        transform=transform_test,
    )

    train_loader = torch.utils.data.DataLoader(
        training_data, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers
    )
    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=batch_size, shuffle=False, drop_last=False
    )

    return train_loader, test_loader


def get_network_stats(net, example_input, device):
    """Return ``(parameter_count, flops)`` for ``net`` on ``example_input``.

    The network is put in eval mode, the example is moved to ``device``, and
    fvcore's ``FlopCountAnalysis`` / ``parameter_count`` are queried with
    gradients disabled. The analyser's unsupported-op and uncalled-module
    warnings are switched off.
    """
    net.eval()
    sample = example_input.to(device)
    with torch.no_grad():
        flop_counter = FlopCountAnalysis(net, sample)
        flop_counter.unsupported_ops_warnings(enabled=False)
        flop_counter.uncalled_modules_warnings(enabled=False)
        n_flops = flop_counter.total()
        # The empty-string key is the entry read for the whole-model count.
        n_params = parameter_count(net)['']

    return n_params, n_flops


def fix_settings(seed, fltype, allow_grad=False):
    """Seed every RNG in use, force deterministic cuDNN and set global defaults.

    Seeds torch (CPU and all CUDA devices), NumPy and ``random`` with
    ``seed``, makes ``fltype`` the default torch dtype and, unless
    ``allow_grad`` is true, disables gradient computation globally.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Default floating point precision for newly created tensors.
    torch.set_default_dtype(fltype)

    if allow_grad:
        return
    # Gradients are not needed by callers that leave allow_grad off.
    torch.set_grad_enabled(False)

@torch.no_grad()
def measure_perf(net, loss_fn, dataloader, device):
    """Evaluate ``net`` on ``dataloader`` and return ``(mean loss, accuracy)``.

    The network is put in eval mode. Each batch loss is weighted by the batch
    length so the result is a per-sample mean; accuracy is the fraction of
    samples whose arg-max over axis 1 equals the label.
    """
    net.eval()

    loss_sum = 0.
    n_correct = 0
    n_seen = 0

    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        logits = net(inputs)
        batch = len(inputs)

        loss_sum += loss_fn(logits, labels).item() * batch
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += batch

    return loss_sum / n_seen, n_correct / n_seen


def get_perplexity(model, encodings):
    """Sliding-window perplexity of ``model`` over ``encodings.input_ids``.

    Windows of at most ``model.config.n_positions`` tokens start every 512
    tokens. In each window only the tokens not covered by the previous window
    keep their labels (the rest are set to -100), the model's loss is
    recorded, and the result is ``exp`` of the plain mean of those losses.
    """
    model.eval()
    window = model.config.n_positions
    stride = 512
    total_len = encodings.input_ids.size(1)

    losses = []
    scored_upto = 0
    for start in tqdm(range(0, total_len, stride)):
        stop = min(start + window, total_len)
        n_new = stop - scored_upto  # may be different from stride on last loop
        input_ids = encodings.input_ids[:, start:stop].to(model.device)
        labels = input_ids.clone()
        labels[:, :-n_new] = -100

        with torch.no_grad():
            # The loss comes from the model itself, computed from the labels
            # passed in; positions labelled -100 are the masked context.
            window_loss = model(input_ids, labels=labels).loss

        losses.append(window_loss)

        scored_upto = stop
        if stop == total_len:
            break

    return torch.exp(torch.stack(losses).mean())

def accuracy(output, target, topk=(1,)):
    """Return top-k accuracies (in percent) for every ``k`` in ``topk``.

    A 2-D ``target`` is first reduced to class indices with an arg-max over
    axis 1. The result is a list of scalar tensors, one per ``k``.
    """
    with torch.inference_mode():
        deepest = max(topk)
        n_samples = target.size(0)
        if target.ndim == 2:
            target = target.max(dim=1)[1]

        top_indices = output.topk(deepest, 1, True, True)[1]
        # Shape (deepest, batch): row r marks samples whose r-th guess is right.
        hits = top_indices.t().eq(target[None])

        return [
            hits[:k].flatten().sum(dtype=torch.float32) * (100.0 / n_samples)
            for k in topk
        ]


def save(*objects, **options):
    """Forward every argument unchanged to ``torch.save``."""
    torch.save(*objects, **options)


def make_paths(args):
    """Derive the result directories from ``args`` and create them.

    ``args.path`` gets a trailing slash if it lacks one. ``args.stats_path``
    becomes ``<path>results/<model>/<samples>/`` and ``args.output_dir``
    extends it with ``<order>/<heuristic>/<reconstruct|hands-off>/`` plus
    ``<ratio>/`` when ``args`` has a ``ratio``. Both directories are created
    if they do not exist yet.
    """
    if args.path[-1] != '/':
        args.path += '/'

    stats_dir = args.path + f'results/{args.model}/{args.samples}/'
    mode = 'reconstruct' if args.reconstruction else 'hands-off'
    output_dir = stats_dir + f'{args.order}/{args.heuristic}/{mode}/'
    if hasattr(args, 'ratio'):
        output_dir += str(args.ratio) + '/'

    args.stats_path = stats_dir
    args.output_dir = output_dir

    for target in (stats_dir, output_dir):
        if Path(target).exists():
            continue
        Path(target).expanduser().mkdir(parents=True, exist_ok=True)


def prune_network(model, Cs, imp, example_input, ratio, args, means=None):
    """Annotate the prunable layers of ``model`` and run one pruning step.

    Args:
        model: Network whose ``nn.Linear`` / ``nn.Conv2d`` modules are pruned.
        Cs: One matrix per prunable layer, paired with the layers in
            ``model.named_modules()`` order and stored on each as ``layer.C``.
            NOTE(review): presumably the input correlation matrices produced
            by the collectors -- confirm against the caller.
        imp: Importance callable; evaluated per dependency group below and,
            unless ``args.heuristic == 'var'``, also handed to the pruner.
        example_input: Input used to trace the dependency graph and pruner.
        ratio: Passed to the pruner as ``pruning_ratio``.
        args: Reads ``out_features``, ``reconstruction``, ``heuristic`` and,
            if present, ``skip_input_prune``.
        means: Optional per-layer tensors stored as ``layer.mean``.

    Nothing is returned; presumably ``pruner.step()`` modifies ``model`` in
    place -- confirm against the torch_pruning documentation.
    """
    # Every Linear/Conv2d module, as (name, module) pairs in traversal order.
    prunable_layers = list(filter(lambda m: isinstance(m[1], (nn.Linear, nn.Conv2d)), model.named_modules()))
    # Attach a private copy of the matching statistic(s) to each layer.
    # zip() stops at the shorter of the two sequences.
    for i, ((name, layer), C) in enumerate(zip(prunable_layers, Cs)):
        setattr(layer, 'C', C.clone())
        if means is not None:
            setattr(layer, 'mean', means[i].clone())

    # Layers the pruner must leave alone: any Linear whose out_features equals
    # args.out_features and, optionally, Conv2d layers with 3 input channels.
    ignored_layers = []
    for name, m in model.named_modules():
        if isinstance(m, torch.nn.Linear) and m.out_features == args.out_features:
            ignored_layers.append(m)
        if isinstance(m, torch.nn.Conv2d) and m.in_channels == 3 and hasattr(args, 'skip_input_prune') and args.skip_input_prune:    # According to SAW, GReg-2 and TPP, skipping the first layer improves performance
            ignored_layers.append(m)

    # The graph itself is built without ignored layers; they are only applied
    # when the groups are enumerated below.
    DG = tp.DependencyGraph().build_dependency(
        model, 
        example_inputs=example_input, 
        ignored_params=[], 
        ignored_layers=[], 
    )

    if args.reconstruction or args.heuristic == 'var':
        # Groups are visited in reverse of the order get_all_groups yields them.
        for group in reversed(list(DG.get_all_groups(ignored_layers=ignored_layers, root_module_types=[nn.Conv2d, nn.Linear]))):
            score = imp(group)
            # `order` sorts indices by descending score; `order_inv` is the
            # inverse permutation (the rank of each original index).
            order = torch.argsort(score, descending=True)
            order_inv = torch.argsort(order)

            for g in group:
                # Only Linear/Conv2d modules pruned on their input side get
                # the permutation; both handler branches store the same thing.
                if isinstance(g[0].target.module, (nn.Linear, nn.Conv2d)):    
                    if g[0].handler.__name__ == 'prune_in_channels':
                        setattr(g[0].target.module, 'pivots', order_inv)
                        setattr(g[0].target.module, 'order', order)
                    elif g[0].handler.__name__ == 'prune_in_features':
                        setattr(g[0].target.module, 'pivots', order_inv)
                        setattr(g[0].target.module, 'order', order)

        for i, (_, layer) in enumerate(prunable_layers):
            # Layers that received no ordering above are skipped.
            if not hasattr(layer, 'order'): continue
            # Factor C with rows and columns permuted by `order`; ldl returns
            # the unit-diagonal lower factor and the squared Cholesky diagonal.
            M_inv, D = ldl(layer.C[layer.order][:, layer.order])

            # `lls` is the flag reconstruction() checks before touching weights.
            if args.reconstruction: setattr(layer, 'lls', True)
            setattr(layer, 'M_inv', M_inv)
            setattr(layer, 'D', D)

    # The 'var' heuristic swaps in VarImportance for the pruner itself.
    if args.heuristic == 'var':
        global_importance = importances.VarImportance()
    else:
        global_importance = imp

    pruner = tp.pruner.MetaPruner(
        model,
        example_input,
        importance=global_importance,
        pruning_ratio=ratio,
        ignored_layers=ignored_layers,
        # Global pruning is enabled only for these three heuristics.
        global_pruning=args.heuristic == 'var' or args.heuristic == 'emvar' or args.heuristic == 'varsum',
    )

    pruner.step()


def add_config_to(args):
    """Copy every top-level key of the YAML file ``args.config`` onto ``args``.

    Existing attributes with the same name are overwritten. An empty config
    file (for which ``yaml.safe_load`` returns ``None``) leaves ``args``
    untouched instead of failing on ``None.items()``.
    """
    config = yaml.safe_load(Path(args.config).read_text())
    if config is None:
        return
    for key, value in config.items():
        setattr(args, key, value)


def get_model(args):
    """Build the pretrained network named by ``args.model`` on ``args.device``.

    torchvision models are returned as constructed; the timm DeiT models are
    additionally switched to eval mode. Unknown names raise
    ``NotImplementedError``.
    """
    name = args.model

    # Builders are wrapped in lambdas so only the requested model is created.
    torchvision_builders = {
        'resnet50': lambda: torchvision.models.resnet50(
            weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1),
        'resnet18': lambda: torchvision.models.resnet18(
            weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1),
        'resnet50_v2': lambda: torchvision.models.resnet50(
            weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V2),
        'vgg16': lambda: torchvision.models.vgg16(
            weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1),
    }
    timm_names = {
        'deit_tiny': "deit_tiny_patch16_224",
        'deit_small': "deit_small_patch16_224",
    }

    if name in torchvision_builders:
        return torchvision_builders[name]().to(args.device)
    if name in timm_names:
        return timm.create_model(timm_names[name], pretrained=True).eval().to(args.device)
    raise NotImplementedError('Simply add a model to extend this framework.')
    

def get_data(args):
    """Return ``(train_loader, test_loader, example_input)`` for ``args.dataset``.

    Also records the number of classes on ``args.out_features``. ``imagenet``
    uses ``args.batch_size`` with the manual transform, ``imagenet_v2`` uses a
    batch size of 256 with the V2 transform, and ``cifar10`` uses the CIFAR
    loaders. Any other dataset name raises ``NotImplementedError``.
    """
    dataset = args.dataset

    if dataset == 'imagenet':
        loaders = load_data(batch_size=args.batch_size, v2=False)
        side, n_classes = 224, 1000
    elif dataset == 'imagenet_v2':
        loaders = load_data(batch_size=256, v2=True)
        side, n_classes = 224, 1000
    elif dataset == "cifar10":
        loaders = load_cifar10()
        side, n_classes = 32, 10
    else:
        raise NotImplementedError()

    train_loader, test_loader = loaders
    # Single random image used as an example input for the chosen dataset.
    example_input = torch.randn(1, 3, side, side).to(args.device)
    args.out_features = n_classes

    return train_loader, test_loader, example_input

def get_imp(args):
    """Return the importance object selected by ``args.order``.

    ``'saw'`` gives an L1 ``MagnitudeImportance`` without normaliser and with
    ``group_reduction="first"``; ``'zca'`` gives ``ZCAImportance``. Anything
    else raises ``NotImplementedError``.
    """
    order = args.order
    if order == 'zca':
        return importances.ZCAImportance()
    if order == 'saw':
        return tp.importance.MagnitudeImportance(p=1, normalizer=None, group_reduction="first")
    raise NotImplementedError()
    

def reinit(net):
    """Re-initialise every descendant of ``net`` that has ``reset_parameters``.

    A child exposing ``reset_parameters`` is reset directly and not descended
    into; any other child is treated as a container and searched recursively.

    The recursion goes into the container itself. Recursing into the
    container's children instead would only ever look at *their* children, so
    modules two levels below ``net`` would never be reset.
    """
    for layer in net.children():
        if hasattr(layer, 'reset_parameters'):
            layer.reset_parameters()
        else:
            reinit(layer)
