import os, sys
import torch
from torch import nn
import torch.nn.functional as F
import torch_pruning as tp
import timm
import warnings
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
import torchvision
from tqdm import tqdm
import scipy
import numpy as np
from pruning import importances
import argparse

from pruning import utils

def forward(self, x):
    """Attention forward pass that tolerates a pruned inner width.

    Intended to be bound onto an attention module as its ``forward``. It
    reads ``qkv``, ``num_heads``, ``head_dim``, ``q_norm``, ``k_norm``,
    ``fused_attn``, ``attn_drop``, ``scale``, ``proj`` and ``proj_drop``
    from ``self``.

    Args:
        x: Tensor unpacked as ``(B, N, C)``.

    Returns:
        The output of ``self.proj_drop(self.proj(...))`` applied to the
        attention result reshaped to ``(B, N, -1)``.
    """
    B, N, _ = x.shape
    # One projection yields q, k and v; split into (3, B, heads, N, head_dim).
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)
    q, k = self.q_norm(q), self.k_norm(k)

    if self.fused_attn:
        # scaled_dot_product_attention applies dropout whenever dropout_p > 0,
        # independent of module train/eval state, so it must be disabled
        # explicitly outside training to match the non-fused branch (where
        # nn.Dropout is a no-op in eval mode).
        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p if self.training else 0.0,
        )
    else:
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = attn @ v

    # The last dimension is inferred (-1) rather than forced to the input
    # width C, so the merged head width (num_heads * head_dim) may differ from
    # C. Original implementation: x = x.transpose(1, 2).reshape(B, N, C)
    x = x.transpose(1, 2).reshape(B, N, -1)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x



def prepare_imagenet(train_batch_size=64, val_batch_size=128, num_workers=18, data_root='PATH/TO/DATA'):
    """Build train/validation ``DataLoader``s over ``ImageFolder`` directories.

    Both splits are read with the same deterministic preprocessing (resize to
    256, center-crop to 224, tensor conversion, per-channel normalisation);
    no augmentation is applied to the training split. The training loader
    shuffles, the validation loader does not.

    Args:
        train_batch_size: Batch size of the training loader.
        val_batch_size: Batch size of the validation loader.
        num_workers: Worker process count used by both loaders.
        data_root: Directory containing the ``train`` and ``val`` folders.

    Returns:
        ``(train_loader, valid_loader)``.
    """
    print('Parsing dataset...')

    # A single transform object is shared by both datasets.
    # interpolation=3 is PIL's integer code for bicubic resampling.
    eval_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(256, interpolation=3),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])

    train_dst = ImageFolder(f'{data_root}/train', transform=eval_transform)
    valid_dst = ImageFolder(f'{data_root}/val', transform=eval_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dst, batch_size=train_batch_size, shuffle=True, num_workers=num_workers)
    valid_loader = torch.utils.data.DataLoader(
        valid_dst, batch_size=val_batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, valid_loader

def validate_model(model, val_loader, device):
    """Compute top-1 accuracy and mean cross-entropy over ``val_loader``.

    Puts ``model`` in eval mode and runs it under ``torch.no_grad()``.

    Args:
        model: Callable module mapping an image batch to per-class scores.
        val_loader: Iterable of ``(images, labels)`` batches exposing a
            ``dataset`` attribute whose length is the sample count.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(accuracy, mean_loss)``, both divided by ``len(val_loader.dataset)``.
    """
    model.eval()
    n_correct = 0
    total_loss = 0
    with torch.no_grad():
        for images, labels in tqdm(val_loader):
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)
            # Sum (not mean) per batch so the final division by the dataset
            # size is exact even when the last batch is smaller.
            total_loss += F.cross_entropy(logits, labels, reduction='sum').item()
            predicted = logits.max(dim=1).indices
            n_correct += (predicted == labels).sum().item()
    n_samples = len(val_loader.dataset)
    return n_correct / n_samples, total_loss / n_samples

def evaluate(model, loader, device):
    """Average a per-batch top-1 score of ``model`` over ``loader``.

    Puts ``model`` in eval mode and runs each batch under CUDA autocast with
    gradient tracking disabled.

    Args:
        model: Callable module mapping an image batch to per-class scores.
        loader: Sized iterable of ``(images, target)`` batches.
        device: Device the batches are moved to.

    Returns:
        The mean over batches of ``acc1 / batch_size``, where ``acc1`` is the
        first value returned by ``utils.accuracy(output, target, topk=(1, 5))``.
        NOTE(review): if ``utils.accuracy`` already returns a fraction or a
        percentage, dividing by the batch size normalises twice - confirm
        against its implementation.
    """
    model.eval()

    accs = 0.
    # Pure evaluation: no autograd graph is needed, so avoid recording one
    # (consistent with validate_model).
    with torch.no_grad():
        for images, target in loader:
            images = images.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            with torch.cuda.amp.autocast():
                output = model(images)
                acc1, _ = utils.accuracy(output, target, topk=(1, 5))

            accs += acc1 / images.shape[0]

    return accs / len(loader)

def prune_deit_network(model, Cs, imp, example_input, ratio, args):
    """Annotate and prune a timm vision transformer with torch_pruning.

    Steps, in order:
      1. Every timm ``Attention`` module gets the module-level ``forward``
         bound as its ``forward``; its head count is recorded under its
         ``qkv`` layer.
      2. Each ``nn.Linear``/``nn.Conv2d`` (in ``named_modules()`` order) is
         paired with the next entry of ``Cs``, stored on it as attribute ``C``.
      3. When ``args.reconstruction`` is set or ``args.heuristic == 'var'``:
         for each dependency group (visited in reverse), channels are ranked
         by ``imp(group)`` and the ranking (``order``) plus its inverse
         permutation (``pivots``) are stored on the input-pruned
         Linear/Conv2d targets. Then ``utils.ldl`` of the reordered ``C`` is
         stored as ``M_inv``/``D`` on every annotated layer except the first
         prunable one.
      4. One ``MetaPruner`` step is executed, after which each attention
         module's ``num_heads`` and ``head_dim`` are refreshed.

    Args:
        model: Model exposing ``patch_embed`` and ``head`` (both excluded
            from pruning).
        Cs: Per-layer statistics, one per Linear/Conv2d module; indexed as
            ``C[order][:, order]`` (presumably square matrices - TODO confirm).
        imp: Callable scoring a dependency group; also used as the pruner
            importance unless the heuristic is ``'var'``.
        example_input: Example input used to trace dependencies.
        ratio: Pruning ratio handed to the pruner.
        args: Namespace providing ``reconstruction`` and ``heuristic``.

    Returns:
        None; the work is done through side effects on ``model``.
    """
    attention_cls = timm.models.vision_transformer.Attention
    linear_like = (nn.Linear, nn.Conv2d)
    input_pruners = ('prune_in_channels', 'prune_in_features')

    ignored_layers = [model.patch_embed, model.head]
    num_heads = {}
    for module in model.modules():
        if not isinstance(module, attention_cls):
            continue
        # Bind the replacement forward as a method of this instance only:
        # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module
        module.forward = forward.__get__(module, attention_cls)
        num_heads[module.qkv] = module.num_heads

    # ---- SNP: attach input statistics to every prunable layer ----
    prunable_layers = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, linear_like)
    ]
    for (_, layer), C in zip(prunable_layers, Cs):
        layer.C = C

    if args.reconstruction or args.heuristic == 'var':
        DG = tp.DependencyGraph().build_dependency(
            model,
            example_inputs=example_input,
            ignored_layers=[],
            ignored_params=[],
        )
        groups = list(DG.get_all_groups(
            ignored_layers=ignored_layers,
            root_module_types=[nn.Conv2d, nn.Linear],
        ))

        for group in reversed(groups):
            score = imp(group)
            order = torch.argsort(score, descending=True)
            # argsort of a permutation yields its inverse permutation.
            order_inv = torch.argsort(order)

            for item in group:
                dep = item[0]
                target_module = dep.target.module
                if not isinstance(target_module, linear_like):
                    continue
                # Only layers pruned on their input side are annotated.
                if dep.handler.__name__ in input_pruners:
                    target_module.pivots = order_inv
                    target_module.order = order

        # The first prunable layer is skipped.
        for _, layer in prunable_layers[1:]:
            if hasattr(layer, 'order'):
                perm = layer.order
                M_inv, D = utils.ldl(layer.C[perm][:, perm])

                if args.reconstruction:
                    layer.lls = True
                layer.M_inv = M_inv
                layer.D = D
    # ---- end SNP ----

    use_var = args.heuristic == 'var'
    pruner = tp.pruner.MetaPruner(
        model,
        example_input,
        global_pruning=use_var,
        importance=importances.VarImportance() if use_var else imp,
        pruning_ratio=ratio,
        ignored_layers=ignored_layers,
        num_heads=num_heads,
        prune_num_heads=False,
        prune_head_dims=True,
        head_pruning_ratio=0.5,
        round_to=1,
    )
    pruner.step()

    # Update each attention module's head count and per-head size after pruning.
    for module in model.modules():
        if isinstance(module, attention_cls):
            module.num_heads = pruner.num_heads[module.qkv]
            module.head_dim = module.qkv.out_features // (3 * module.num_heads)


# ---------------------------------------------------------------------------
# Script driver: parse CLI options, gather input statistics on one model
# instance, prune a second fresh instance once at --ratio, then report and
# save the pruned model.
# NOTE(review): this runs at import time (no `if __name__ == "__main__":`
# guard).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="Pruning Parser", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument("--seed", type=int, default=13, required=False)
parser.add_argument("--order", type=str, help="saw or zca", required=True)
parser.add_argument("--heuristic", type=str, help="var or uniform", required=True)
# BooleanOptionalAction: accepts --reconstruction / --no-reconstruction; one of them must be given.
parser.add_argument("--reconstruction", action=argparse.BooleanOptionalAction, required=True)
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--ratio", type=float, required=True)
args = parser.parse_args()

# The device is fixed to the first CUDA GPU.
args.device = torch.device('cuda:0')
# args.stats_path and args.model are read below but not defined by the parser;
# presumably they are populated by one of these utils calls - TODO confirm.
utils.add_config_to(args)
utils.fix_settings(args.seed, torch.float32, allow_grad=True)
model = utils.get_model(args) 
train_loader, test_loader = prepare_imagenet()
# NOTE(review): this baseline loss/acc is overwritten by the post-pruning
# measurement below without being read in between.
loss, acc = utils.measure_perf(model, nn.CrossEntropyLoss(), test_loader, args.device)

# Dummy single-image input used for dependency tracing and network statistics.
example_input = torch.zeros(1, 3, 224, 224).to(args.device)
imp = utils.get_imp(args)
utils.make_paths(args)
# Per-layer input statistics gathered with the first model instance and the train loader.
Cs = utils.get_input_stats(args.stats_path + 'cross_corrs.pt', model, train_loader, args.device)

# NOTE(review): `ratios` and `stats` are not read anywhere in the visible
# script; only the single --ratio value is pruned.
ratios = np.linspace(0, 1, 101)
stats = {
    'mac': np.zeros_like(ratios),
    'pc': np.zeros_like(ratios),
    'loss': np.zeros_like(ratios),
    'acc': np.zeros_like(ratios),
    'mac_ratio': np.zeros_like(ratios),
}

# A fresh model instance is created for pruning; the one above was used for
# the baseline measurement and statistics collection.
model = utils.get_model(args)
# NOTE(review): the pre-pruning `pc` / `flop` values are not used afterwards.
pc, flop = utils.get_network_stats(model, example_input, args.device)
prune_deit_network(model, Cs, imp, example_input, args.ratio, args)

pruned_pcs, pruned_flops = utils.get_network_stats(model, example_input, args.device)
loss, acc = utils.measure_perf(model, nn.CrossEntropyLoss(), test_loader, args.device)
print('Pruned', pruned_flops, pruned_pcs, loss, acc)

# The whole module object (not just a state_dict) is saved to the current
# working directory; 'lls-' marks runs made with --reconstruction.
lls_spec = 'lls-' if args.reconstruction else ''
torch.save(model, f'{args.model}-{args.order}-{args.heuristic}-{lls_spec}{args.ratio}.pt')
