import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import random

from models import *
from utils import progress_bar

from Tensor2DS import *

from collections import OrderedDict

# Command-line interface of the training script.
parser = argparse.ArgumentParser()

### Tensor2DS related argument
parser.add_argument(
    '--order',
    type=str,
    default='CPD_svd',
    choices=['DW_PW', 'PW_DW', 'CPD_svd', 'CPD_tpm'],
    help='decide the type of decomposition',
)
parser.add_argument(
    '--rank',
    type=int,
    default=4,
    help='decide the rank for CP decomposition',
)
parser.add_argument(
    '--init',
    type=str,
    default='cpd',
    choices=['random', 'cpd'],
    help='whether or not to use CPD initialization for PDP',
)

# general
parser.add_argument(
    '--epoch',
    type=int,
    default=300,
    help="Training epochs",
)
parser.add_argument(
    '--lr',
    type=float,
    default=0.1,
    help='learning rate',
)
parser.add_argument(
    '--resume', '-r',
    action='store_true',
    help='resume from checkpoint',
)
parser.add_argument('--seed', type=int, default=2022)
parser.add_argument(
    '--model',
    type=str,
    default='vgg16',
    choices=['vgg16', 'resnet34'],
)
args = parser.parse_args()

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Module-level training state; both values are overwritten when resuming.
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch


def set_seed(seed):
    """Seed every random number generator this script draws from.

    Seeds Python's ``random`` module, NumPy's global generator, PyTorch's
    default generator and the generators of all CUDA devices.

    Args:
        seed: Integer seed applied to each generator.
    """
    # The file has no explicit NumPy import; importing it here keeps the
    # function from depending on a star-import happening to expose ``np``.
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(args.seed)

# Data
print('==> Preparing data..')

# Per-channel mean / std passed to Normalize for both splits.
norm_mean = (0.4914, 0.4822, 0.4465)
norm_std = (0.2023, 0.1994, 0.2010)

# Training images are randomly cropped (after 4px padding) and flipped
# before being converted to tensors and normalised.
train_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
transform_train = transforms.Compose(train_steps)

# Test images are only converted and normalised.
test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
transform_test = transforms.Compose(test_steps)

trainset = torchvision.datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
)

testset = torchvision.datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=128,
    shuffle=False,
    num_workers=4,
)

classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# Model
print('==> Building model..')
# Constructors are wrapped in lambdas so only the selected one is evaluated.
model_builders = {
    'vgg16': lambda: VGG('VGG16'),
    'resnet34': lambda: ResNet34(),
}
if args.model not in model_builders:
    raise NotImplementedError
net = model_builders[args.model]()

# Trainable parameter count of the unmodified network; used later to report
# the compression ratio.
original = sum(
    param.numel() for param in net.parameters() if param.requires_grad
)
print("Number of parameter (Baseline): %d" % original)

def replace_layers(model, pretrained, init='cpd'):
    """Recursively replace every 3x3 ``nn.Conv2d`` inside ``model``, in place.

    ``model`` and ``pretrained`` are walked in lock-step by child name, so
    both must expose the same attribute names.

    Args:
        model: Module whose 3x3 convolutions are swapped out.
        pretrained: Module supplying, under the same child names, the layer
            each replacement is derived from.
        init: ``'cpd'`` builds the replacement with ``Tensor2DS_decomp``
            (using ``args.order`` and ``args.rank``); any other value builds
            a freshly initialised 1x1 -> grouped kxk -> 1x1 stack of width
            ``args.rank``.
    """
    for child_name, child in model.named_children():
        reference = getattr(pretrained, child_name)

        # Containers are descended into before the child itself is examined.
        has_children = next(child.children(), None) is not None
        if has_children:
            replace_layers(child, reference, init=init)

        # Only 3x3 convolutions are rewritten; everything else is kept.
        if not isinstance(child, nn.Conv2d):
            continue
        if child.kernel_size != (3, 3):
            continue

        print('Replacing layer ...')
        if init == 'cpd':
            replacement = Tensor2DS_decomp(
                reference, order=args.order, rank=args.rank)
        else:
            in_ch = reference.in_channels
            ksize = reference.kernel_size[0]
            out_ch = reference.out_channels
            # The replacement carries biases only if the source layer has one.
            use_bias = reference.bias is not None
            width = args.rank
            replacement = nn.Sequential(
                nn.Conv2d(in_ch, width, 1, bias=use_bias),
                # groups == channels: one spatial filter per channel, keeping
                # the source layer's padding and stride.
                nn.Conv2d(width, width, ksize, groups=width,
                          padding=reference.padding,
                          stride=reference.stride, bias=use_bias),
                nn.Conv2d(width, out_ch, 1, bias=use_bias),
            )
        setattr(model, child_name, replacement)
            


# Swap the 3x3 convolutions of ``net`` for their factorised replacements,
# either derived from a trained baseline ('cpd') or freshly initialised.
# Indices inside ``features`` of the VGG layers that get replaced.
layers = [0,3,7,10,14,17,20,24,27,30,34,37,40]

if args.init == 'cpd':
    # load params for each layer
    ckpt = os.path.join('checkpoint', args.model, 'baseline/ckpt.pth')
    # Load onto the CPU: the weights are only copied into a CPU-resident
    # reference model, and this keeps CUDA-saved checkpoints loadable on
    # machines without a GPU.
    checkpoint = torch.load(ckpt, map_location='cpu')
    state_dict = checkpoint['net']

    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
    # "module."; strip it only when present so checkpoints saved without the
    # wrapper (e.g. from a CPU run) keep their keys intact.
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    print("Load pretrained weights..")

    if args.model == 'vgg16':
        pretrained = VGG('VGG16')
        pretrained.load_state_dict(new_state_dict)
        for layer in layers:
            print(f'Replacing layer {layer}...')
            net.features[layer] = Tensor2DS_decomp(pretrained.features[layer], order=args.order, rank=args.rank)
    elif args.model == 'resnet34':
        pretrained = ResNet34()
        pretrained.load_state_dict(new_state_dict)
        replace_layers(net, pretrained)
else:
    if args.model == 'vgg16':
        for layer in layers:
            print(f'Replacing layer {layer}...')
            conv_layer = net.features[layer]
            c_in = conv_layer.in_channels
            k = conv_layer.kernel_size[0]
            c_out = conv_layer.out_channels
            # Mirror replace_layers(): keep biases only if the replaced
            # convolution had one.
            is_bias = conv_layer.bias is not None
            net.features[layer] = nn.Sequential(
                        nn.Conv2d(c_in, args.rank, 1, bias=is_bias),
                        nn.Conv2d(args.rank, args.rank, k, groups=args.rank, padding=conv_layer.padding, stride=conv_layer.stride, bias=is_bias),
                        nn.Conv2d(args.rank, c_out, 1, bias=is_bias))
    elif args.model == 'resnet34':
        # The network serves as its own reference: only layer shapes are
        # read from it when the replacements are randomly initialised.
        replace_layers(net, net, init=args.init)



net = net.to(device)
print(net)
# Trainable parameter count after the layer replacement, and the reduction
# relative to the baseline count in percent.
total = sum(param.numel() for param in net.parameters() if param.requires_grad)
print("Number of parameter (CPD): %d" % (total))
print("CR (%): ", ((original-total)/original)*100)


if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    # test() stores checkpoints under checkpoint/<model>/<run name>/ckpt.pth,
    # so look there first; the flat legacy location is kept as a fallback.
    if args.order in ('CPD_svd', 'CPD_tpm'):
        run_tag = args.order if args.init == 'cpd' else 'Random'
        run_name = '_'.join([run_tag, 'rank' + str(args.rank), 'seed' + str(args.seed)])
    else:
        run_name = '_'.join([args.order, 'seed' + str(args.seed)])
    resume_path = os.path.join('checkpoint', args.model, run_name, 'ckpt.pth')
    if not os.path.isfile(resume_path):
        resume_path = './checkpoint/ckpt.pth'
    # Explicit raise instead of assert so the check survives ``python -O``.
    if not os.path.isfile(resume_path):
        raise FileNotFoundError('Error: no checkpoint directory found!')
    # map_location lets a checkpoint saved on another device type be loaded.
    checkpoint = torch.load(resume_path, map_location=device)
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr,
                    momentum=0.9, weight_decay=5e-4)
# NOTE(review): checkpoints hold only the weights, accuracy and epoch, so a
# resumed run starts the cosine schedule again from args.lr.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epoch)

# Training
def train(epoch):
    """Run one training pass over ``trainloader`` and report running stats.

    Uses the module-level ``net``, ``optimizer``, ``criterion`` and
    ``device``; after every batch the running mean loss and accuracy are
    handed to ``progress_bar``.

    Args:
        epoch: Epoch index, used only for the printed header.
    """
    print('\nEpoch: %d' % epoch)
    net.train()

    running_loss = 0
    n_correct = 0
    n_seen = 0
    n_batches = len(trainloader)

    for step, (images, labels) in enumerate(trainloader):
        images = images.to(device)
        labels = labels.to(device)

        # Standard step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = net(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Accumulate statistics over the batches seen so far.
        running_loss += batch_loss.item()
        predicted = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += predicted.eq(labels).sum().item()

        stats = (running_loss / (step + 1), 100. * n_correct / n_seen,
                 n_correct, n_seen)
        progress_bar(step, n_batches,
                     'Loss: %.3f | Acc: %.3f%% (%d/%d)' % stats)


def test(epoch):
    """Evaluate ``net`` on ``testloader`` and persist the results.

    The accuracy of this evaluation is appended to the module-level
    ``acc_recorder`` list, which is saved after every call. When the accuracy
    exceeds ``best_acc`` the model state, accuracy and epoch are written to
    ``ckpt.pth`` and ``best_acc`` is updated. Both files live in
    ``checkpoint/<model>/<run name>``.

    Args:
        epoch: Epoch index stored in the checkpoint.
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))

    # Save checkpoint.
    acc = 100. * correct / total
    acc_recorder.append(acc)

    # Run name: CP-decomposition runs encode the rank and whether the layers
    # were CPD- or randomly initialised; other orders only carry the seed.
    if args.order in ['CPD_svd', 'CPD_tpm']:
        # Any non-'cpd' init is labelled 'Random' so the name is always bound.
        tag = args.order if args.init == 'cpd' else 'Random'
        name = ('_').join([tag, 'rank'+str(args.rank), 'seed'+str(args.seed)])
    else:
        name = ('_').join([args.order, 'seed'+str(args.seed)])

    output_dir = os.path.join('checkpoint', args.model, name)
    # Create the directory up front: the accuracy history below is written on
    # every call, including epochs that do not produce a new best checkpoint.
    os.makedirs(output_dir, exist_ok=True)

    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        torch.save(state, os.path.join(output_dir, 'ckpt.pth'))
        best_acc = acc

    torch.save(acc_recorder, os.path.join(output_dir, 'acc_recorder.pth'))


# Test accuracy of every epoch; test() appends to this list and saves it.
acc_recorder = []

# Train for args.epoch epochs starting at start_epoch, evaluating after each
# one and then advancing the learning-rate schedule.
last_epoch = start_epoch + args.epoch
for epoch in range(start_epoch, last_epoch):
    train(epoch)
    test(epoch)
    scheduler.step()
