import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import time
import argparse
import os 
from models.shiftresnet import ShiftResNet20, ShiftConv
from shift_pruning import ShiftPruning
from collections import OrderedDict
import random

# Command-line interface for the prune-and-finetune run.
parser = argparse.ArgumentParser()

# Run identity: --model and --seed become part of the run name; --out_dir is an
# optional extra directory level under output/<dataset>/.
parser.add_argument('--model', type=str, default="shiftresnet20")
parser.add_argument('--seed', type=int, default=2022)
parser.add_argument('--out_dir', type=str, default=None)

# Optimisation and data-loading settings.
parser.add_argument('--batch-size', default=1024, type=int)
parser.add_argument('--wd', default=0.01, type=float)
parser.add_argument('--clip-norm', action='store_true')
parser.add_argument('--epochs', default=200, type=int)
parser.add_argument('--lr-max', default=0.01, type=float)
parser.add_argument('--workers', default=8, type=int)
parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'cifar100'])

# Pruning settings. --expansion is divided by 9 (the expansion the baseline
# model is built with) to obtain the density handed to the pruned layers.
parser.add_argument('--expansion', default=4.5, type=float)
# --prune is mandatory: the run name is assembled with '_'.join(...) over
# args.prune, which raises TypeError when the value is None. Requiring the flag
# makes the script fail immediately with a usage message instead of after the
# datasets and checkpoint have already been loaded.
parser.add_argument('--prune', type=str, choices=['even', 'uneven'], required=True,
                    help="pruning mode applied to the ShiftConv layers")
parser.add_argument('--fold_bn', default=False, action='store_true')

args = parser.parse_args()

# Seed every random number generator this script relies on (Python's random,
# NumPy, torch, and torch's CUDA generators) with the same --seed value.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(args.seed)

# Normalisation statistics keyed by dataset name; each entry is a 3-tuple that
# is passed to transforms.Normalize as the mean / std argument.
mean = dict(
    cifar10=(0.4914, 0.4822, 0.4465),
    cifar100=(0.5071, 0.4867, 0.4408),
)

std = dict(
    cifar10=(0.2471, 0.2435, 0.2616),
    cifar100=(0.2675, 0.2565, 0.2761),
)

# Training pipeline: random 32x32 crop (padding 4) and random horizontal flip,
# followed by tensor conversion and normalisation with the selected dataset's
# statistics.
_train_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean[args.dataset], std[args.dataset]),
]
train_transform = transforms.Compose(_train_steps)

# Evaluation pipeline: tensor conversion and normalisation only.
_test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean[args.dataset], std[args.dataset]),
]
test_transform = transforms.Compose(_test_steps)

# Choose the torchvision dataset class and its number of classes from
# --dataset; anything other than 'cifar10' falls through to CIFAR-100.
if args.dataset == 'cifar10':
    dataset_cls, n_classes = torchvision.datasets.CIFAR10, 10
else:
    dataset_cls, n_classes = torchvision.datasets.CIFAR100, 100

# Both splits live under ./data and are downloaded on demand.
trainset = dataset_cls(root='./data', train=True,
                       download=True, transform=train_transform)
testset = dataset_cls(root='./data', train=False,
                      download=True, transform=test_transform)


# The two loaders share batch size and worker count; only the training loader
# shuffles its samples.
loader_kwargs = dict(batch_size=args.batch_size, num_workers=args.workers)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)


def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable parameters.

    Parameters whose ``requires_grad`` flag is false are not counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total



def replace_layers(model, old, new, density=0.5, fold_bn=False, uneven=False):
    """Recursively replace every ``old``-typed submodule of ``model`` in place.

    Traversal is depth-first: a child's own descendants are processed before
    the child itself is tested, so ``new`` receives a module whose subtree has
    already been converted. ``model`` itself is never replaced, only its
    descendants.

    Args:
        model: Root ``nn.Module`` whose descendants are rewritten.
        old: Module class (or tuple of classes) to look for.
        new: Callable invoked as ``new(module, density=..., fold_bn=...,
            uneven=...)``; its return value is installed under the name of
            the module it replaces.
        density: Forwarded unchanged to ``new``.
        fold_bn: Forwarded unchanged to ``new``.
        uneven: Forwarded unchanged to ``new``.
    """
    # Iterate over a snapshot: setattr() below rewrites the parent's child
    # registry, and that should not happen underneath a live generator.
    for child_name, child in list(model.named_children()):
        if len(list(child.children())) > 0:
            replace_layers(child, old, new, density=density, fold_bn=fold_bn, uneven=uneven)

        if isinstance(child, old):
            setattr(model, child_name, new(child, density=density, fold_bn=fold_bn, uneven=uneven))

# Load the baseline checkpoint for the selected dataset. The model is still on
# the CPU at this point, so map the stored tensors to the CPU as well instead
# of materialising them on whatever device they were saved from.
checkpoint = torch.load(f'output/{args.dataset}/baseline/ckpt.pth', map_location='cpu')
state_dict = checkpoint['model']

# State dicts saved from an nn.DataParallel wrapper carry a leading 'module.'
# on every key. Strip only that leading prefix; a plain str.replace() would
# also rewrite keys that merely contain 'module.' somewhere in the middle.
_dp_prefix = 'module.'
new_state_dict = OrderedDict(
    (k[len(_dp_prefix):] if k.startswith(_dp_prefix) else k, v)
    for k, v in state_dict.items()
)

# The baseline network is built with expansion=9; pruning later scales this
# down via density = args.expansion / 9.
model = ShiftResNet20(expansion=9, num_classes=n_classes)
print("Number of params: ", count_parameters(model))
model.load_state_dict(new_state_dict)

# Replace every ShiftConv with a ShiftPruning layer when a pruning mode was
# requested; 'uneven' only differs by setting the uneven flag. The density is
# the requested expansion relative to the baseline expansion of 9.
if args.prune in ('even', 'uneven'):
    replace_layers(
        model,
        ShiftConv,
        ShiftPruning,
        density=(args.expansion / 9),
        fold_bn=args.fold_bn,
        uneven=(args.prune == 'uneven'),
    )

# Re-initialise the first conv, its batch norm and the final linear layer.
for _layer in (model.conv1, model.bn1, model.linear):
    _layer.reset_parameters()
print(f'Number of params after pruning ({args.prune}): {count_parameters(model)}')

# Wrap for multi-GPU execution and move to CUDA before creating the optimizer.
model = nn.DataParallel(model)
model = model.cuda()

opt = optim.AdamW(model.parameters(), lr=args.lr_max, weight_decay=args.wd)


def adjust_learning_rate(epoch, lr):
    """Return the step-decayed learning rate for ``epoch``.

    ``lr`` is returned unchanged through epoch 81, divided by 10 through
    epoch 122, and divided by 100 afterwards.

    NOTE(review): the original comments equate these boundaries with 32k and
    48k iterations; that depends on the batch size in use -- confirm.
    """
    if epoch <= 81:
        return lr
    if epoch <= 122:
        return lr / 10
    return lr / 100

# Loss function and the gradient scaler used for mixed-precision training.
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

# Run name: <model>_prune_<mode>_<expansion>[_foldbn]_epoch<N>_seed<S>.
# It is built once, with the optional 'foldbn' tag inserted when --fold_bn is
# set, instead of duplicating the whole join in two branches.
_name_parts = [args.model, 'prune', args.prune, str(args.expansion)]
if args.fold_bn:
    _name_parts.append('foldbn')
_name_parts += ['epoch' + str(args.epochs), 'seed' + str(args.seed)]
name = '_'.join(_name_parts)

# Results go to output/<dataset>/[<out_dir>/]<name>.
if args.out_dir is not None:
    output_dir = os.path.join('output', args.dataset, args.out_dir, name)
else:
    output_dir = os.path.join('output', args.dataset, name)

# exist_ok avoids the check-then-create race of isdir() followed by makedirs()
# when another process creates the directory in between.
os.makedirs(output_dir, exist_ok=True)


# Per-epoch accuracy history and the best test accuracy seen so far.
train_acc_recorder = []
test_acc_recorder = []
best_acc = 0
for epoch in range(args.epochs):
    start = time.time()

    # The schedule depends only on the epoch, so the learning rate is set once
    # per epoch rather than recomputed for every batch. This also guarantees
    # `lr` is defined for the log line below.
    lr = adjust_learning_rate(epoch, args.lr_max)
    for param_group in opt.param_groups:
        param_group['lr'] = lr

    # ---- training pass ----
    model.train()
    train_loss, train_acc, n = 0, 0, 0
    for X, y in trainloader:
        X, y = X.cuda(), y.cuda()

        opt.zero_grad()
        with torch.cuda.amp.autocast():
            output = model(X)
            loss = criterion(output, y)

        # AMP step: backward on the scaled loss, optionally unscale and clip
        # the gradients, then let the scaler drive the optimizer step.
        scaler.scale(loss).backward()
        if args.clip_norm:
            # Gradients have to be unscaled before their norm is clipped.
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt)
        scaler.update()

        # Sample-weighted running totals. train_loss is accumulated but is not
        # part of the per-epoch log line.
        train_loss += loss.item() * y.size(0)
        train_acc += (output.max(1)[1] == y).sum().item()
        n += y.size(0)

    # ---- evaluation pass ----
    model.eval()
    test_acc, m = 0, 0
    with torch.no_grad():
        for X, y in testloader:
            X, y = X.cuda(), y.cuda()
            with torch.cuda.amp.autocast():
                output = model(X)
            test_acc += (output.max(1)[1] == y).sum().item()
            m += y.size(0)

    train_acc_avg = train_acc / n
    test_acc_avg = test_acc / m

    # Keep only the checkpoint with the highest test accuracy so far.
    if test_acc_avg > best_acc:
        print(f'Saving to {name} ...')
        state = {
            'model': model.state_dict(),
            'acc': test_acc_avg,
            'epoch': epoch,
        }
        torch.save(state, os.path.join(output_dir, 'ckpt.pth'))
        best_acc = test_acc_avg

    print(f'[{name}] Epoch: {epoch} | Train Acc: {train_acc_avg:.4f}, Test Acc: {test_acc_avg:.4f}, Time: {time.time() - start:.1f}, lr: {lr:.6f}')
    train_acc_recorder.append(train_acc_avg)
    test_acc_recorder.append(test_acc_avg)

# Persist the accuracy curves next to the checkpoint once training finishes.
torch.save(test_acc_recorder, os.path.join(output_dir, 'test_acc_recorder.pth'))
torch.save(train_acc_recorder, os.path.join(output_dir, 'train_acc_recorder.pth'))