import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import argparse
import os

from dataloading import MNISTDataset, torch_train_val_split
from models import vgg_bn_drop, vgg_bn_drop_100
from training import train, eval
from pruning import pruning_experiment

# Device on which models and tensors are placed ('cuda' when available).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

parser = argparse.ArgumentParser(description='Run third pruning experiment.')
# The model-selection code further down only recognises "ConvCIFAR-VGG" and
# raises ValueError for any other name, so that name is the default here
# (the previous default, "CIFAR-VGG", made a plain run fail immediately).
parser.add_argument('--model', type=str, default='ConvCIFAR-VGG',
                    help='CNN model. Possible Options: "ConvCIFAR-VGG"')
parser.add_argument('--dataset', default='CIFAR10',
                    help='Dataset used. Possible options: "CIFAR10", "CIFAR100"')
parser.add_argument('--image_size', type=int, default=32,
                    help='Input image size. Default 32')
parser.add_argument('--epochs', type=int, default=28,
                    help='Number of training epochs. Default 28')
args = parser.parse_args()

################# Configuration  ######################
# Dataset selection
DATASET = args.dataset
IMAGE_SIZE = args.image_size

# Training parameters
BATCH_SIZE = 64
EPOCHS = args.epochs
EXECUTIONS = 5  # How many models to train in total
PATH = 'models/'  # Directory prefix under which checkpoints are stored
DEVICE = device  # Upper-case alias of `device`; both names are used below

# Model Parameters
MODEL = args.model


############# Datasets and Dataloaders ################
# Training pipeline: resize to IMAGE_SIZE, random horizontal flip as
# augmentation, tensor conversion, then per-channel normalisation.
# NOTE(review): the mean/std constants are presumably CIFAR-10 channel
# statistics and are applied to both datasets -- confirm this is intended.
transform_train = transforms.Compose([
    transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Evaluation pipeline: identical to training but without the random flip.
transform_test = transforms.Compose([
    transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Pick the torchvision dataset class and the matching number of classes.
if DATASET == "CIFAR10":
    dataset_cls, output_size = torchvision.datasets.CIFAR10, 10
elif DATASET == "CIFAR100":
    dataset_cls, output_size = torchvision.datasets.CIFAR100, 100
else:
    # Fail early with a clear message; otherwise train_set / test_set would
    # simply be undefined and the script would die later with a NameError.
    raise ValueError(
        'Unsupported dataset "{}". Possible options: "CIFAR10", "CIFAR100"'.format(DATASET))

train_set = dataset_cls(
    root='./data', train=True, download=True, transform=transform_train)
test_set = dataset_cls(
    root='./data', train=False, download=True, transform=transform_test)

# Alias of the full training set; passed as `dataset=` to the 'thinet'
# pruning run later in the script.
train_data = train_set

# Defining DataLoaders
# The training set is handed to the project helper with val_size=.25 to get
# separate train / validation loaders.
train_loader, val_loader = torch_train_val_split(train_set,
                                                 BATCH_SIZE,
                                                 val_size=.25,
                                                 shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


############ Model, Criterion, Optimizer ##############
# torch.save does not create missing directories, so make sure the
# checkpoint folder exists before the first model is written.
os.makedirs(PATH, exist_ok=True)

for i in range(EXECUTIONS):  # We execute the experiment in many models to examine its stability

    if MODEL == 'ConvCIFAR-VGG':
        model = vgg_bn_drop(pretrained=True) if DATASET == "CIFAR10" else vgg_bn_drop_100(pretrained=False)
        # for param in model.features.parameters():
        #     param.requires_grad = False

    else:
        raise ValueError("Invalid model's name")

    # move the model weights to cpu or gpu
    model = model.to(device)
    print(model)

    # Criterion and optimizer selection
    criterion = nn.CrossEntropyLoss()
    parameters = model.parameters()
    optimizer = torch.optim.SGD(parameters, lr=0.01, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

###################### Training  ######################
    # One checkpoint per (model, dataset, image size, run index).
    checkpoint_path = PATH + '{}_trained_in_{}_imsize_{}_version_{}.pkl'.format(MODEL, DATASET, IMAGE_SIZE, i)

    if os.path.exists(checkpoint_path):
        # Reuse an earlier run.  map_location lets a checkpoint written on a
        # GPU machine be opened on a CPU-only one (and vice versa).
        # NOTE: torch.load unpickles arbitrary objects -- only load
        # checkpoints produced by this project.
        model = torch.load(checkpoint_path, map_location=device)
    else:
        model = train(model,
                      EPOCHS,
                      optimizer,
                      criterion,
                      scheduler,
                      train_loader,
                      val_loader)

        torch.save(model, checkpoint_path)

    accuracy = eval(model, test_loader, criterion)
    print('Test Accuracy of model is {:.2f}'.format(100 * accuracy))

from models import *

def fuse_layer(layer, batch_norm):
    """Fold a BatchNorm layer into the Linear/Conv2d layer that precedes it.

    With fixed running statistics BatchNorm is a per-output-channel affine
    map, ``y = (x - running_mean) / sqrt(running_var + eps) * weight + bias``,
    so ``batch_norm(layer(x))`` can be reproduced by one layer of the same
    type whose weights are rescaled and whose bias is shifted accordingly.

    Args:
        layer: The ``nn.Linear`` or ``nn.Conv2d`` whose output feeds
            ``batch_norm``. It is not modified.
        batch_norm: Module providing ``running_mean``, ``running_var``,
            ``eps`` and (optionally ``None``) ``weight`` / ``bias``.

    Returns:
        A new layer of the same type as ``layer``, always with a bias, that
        computes ``batch_norm(layer(x))`` using the running statistics.

    Raises:
        ValueError: If ``layer`` is neither ``nn.Linear`` nor ``nn.Conv2d``,
            or if ``batch_norm`` has no running statistics to fold.
    """
    if isinstance(layer, nn.Linear):
        # Initialize the new Linear layer with the same dimensions
        new_layer = nn.Linear(layer.in_features, layer.out_features)
    elif isinstance(layer, nn.Conv2d):
        # Initialize the new Conv2d layer with the same dimensions
        new_layer = nn.Conv2d(
            in_channels=layer.in_channels,
            out_channels=layer.out_channels,
            kernel_size=layer.kernel_size,
            stride=layer.stride,
            padding=layer.padding,
            dilation=layer.dilation,
            groups=layer.groups,
            bias=True,  # Bias should be true for the fused layer
            padding_mode=layer.padding_mode,
        )
    else:
        raise ValueError("Unsupported layer type")

    mean = batch_norm.running_mean
    var = batch_norm.running_var
    if mean is None or var is None:
        raise ValueError("BatchNorm layer has no running statistics to fuse")

    # Pure parameter arithmetic: keep autograd out of it so the fused
    # parameters are plain leaf tensors.
    with torch.no_grad():
        std = torch.sqrt(var + batch_norm.eps)
        # affine=False BatchNorm has no weight/bias: identity scale, zero shift.
        coeff = batch_norm.weight if batch_norm.weight is not None else torch.ones_like(mean)
        shift = batch_norm.bias if batch_norm.bias is not None else torch.zeros_like(mean)

        # Compute the scaling factor and new bias
        new_mult = coeff / std
        new_bias = shift - mean * coeff / std

        # The scale is per OUTPUT channel, i.e. along dim 0 of the weight, so
        # reshape it to (out, 1) for Linear and (out, 1, 1, 1) for Conv2d.
        # Multiplying by the flat (out,) vector would broadcast over the last
        # (input) dimension instead.
        per_output = new_mult.view([-1] + [1] * (layer.weight.dim() - 1))
        fused_weight = layer.weight * per_output

        if layer.bias is not None:
            fused_bias = layer.bias * new_mult + new_bias
        else:
            fused_bias = new_bias

    # Set the weights and biases of the new layer
    new_layer.weight.data = fused_weight
    new_layer.bias.data = fused_bias

    return new_layer

def approximate_conv_bn(conv_layer, bn_layer, num_epochs=500, lr=0.01,
                        batch_size=16, image_size=32):
    """Train a single Conv2d to imitate ``bn_layer(conv_layer(x))``.

    A fresh convolution with the same geometry as ``conv_layer`` is fitted
    with Adam and an MSE loss on Gaussian-noise inputs until the loss drops
    below 1e-4 or ``num_epochs`` steps have been taken.

    The reference pair is evaluated in eval mode, so BatchNorm uses its
    running statistics and those statistics are not overwritten by the noise
    batches. The original training/eval modes are restored before returning,
    and the ``requires_grad`` flags of the caller's layers are left alone.

    Args:
        conv_layer: The ``nn.Conv2d`` to imitate (together with ``bn_layer``).
        bn_layer: The BatchNorm module applied to ``conv_layer``'s output.
            It is moved to ``conv_layer``'s device.
        num_epochs: Maximum number of optimisation steps.
        lr: Learning rate for Adam.
        batch_size: Number of noise samples per step.
        image_size: Height and width of the square noise inputs.

    Returns:
        The trained ``nn.Conv2d`` (with bias) on ``conv_layer``'s device.
    """
    target_device = conv_layer.weight.device

    # Create a new convolutional layer with the same parameters as the original conv layer
    new_conv_layer = nn.Conv2d(
        in_channels=conv_layer.in_channels,
        out_channels=conv_layer.out_channels,
        kernel_size=conv_layer.kernel_size,
        stride=conv_layer.stride,
        padding=conv_layer.padding,
        dilation=conv_layer.dilation,
        groups=conv_layer.groups,
        bias=True,
        padding_mode=conv_layer.padding_mode,
    ).to(target_device)

    # Reference model (conv followed by batch norm) and the student model
    model_with_bn = nn.Sequential(conv_layer, bn_layer).to(target_device)
    model_new_conv = nn.Sequential(new_conv_layer).to(target_device)

    # Only the new layer is optimised; the reference output is produced under
    # torch.no_grad(), so no gradient ever reaches the original parameters and
    # they do not need to be frozen.
    optimizer = torch.optim.Adam(new_conv_layer.parameters(), lr=lr)
    criterion = nn.MSELoss()

    # In training mode BatchNorm normalises with batch statistics and updates
    # its running buffers on every forward pass (even under no_grad). Switch
    # the reference to eval mode for the fit and restore the modes afterwards.
    conv_was_training = conv_layer.training
    bn_was_training = bn_layer.training
    model_with_bn.eval()
    try:
        # Training loop
        for epoch in range(num_epochs):
            # Generate random noise input
            input_data = torch.randn(
                batch_size, conv_layer.in_channels, image_size, image_size
            ).to(target_device)

            # Forward pass through the original model
            with torch.no_grad():
                target_output = model_with_bn(input_data)

            # Forward pass through the new model
            output = model_new_conv(input_data)

            # Compute the loss
            loss = criterion(output, target_output)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            if epoch % 100 == 0:
                print(f'Epoch [{epoch}/{num_epochs}], Loss: {loss_value:.4f}')

            # Stop early once the imitation is close enough
            if loss_value < 0.0001:
                break
    finally:
        conv_layer.train(conv_was_training)
        bn_layer.train(bn_was_training)

    return new_conv_layer

def expand_convbnrelu(module):
    """Flatten ``module`` into an ``nn.Sequential`` with BatchNorm folded away.

    Every direct child that is a ``ConvBNReLU`` is replaced by two entries:
    the result of ``fuse_layer(child.conv, child.bn)`` followed by the
    child's own ``relu``. All other children are carried over unchanged and
    in their original order.
    """
    flattened = []
    for _, block in module.named_children():
        if not isinstance(block, ConvBNReLU):
            # Anything that is not a ConvBNReLU block is kept as-is.
            flattened.append(block)
            continue
        # Conv + BN collapse into one fused layer; the activation is reused.
        fused = fuse_layer(block.conv, block.bn)
        flattened.extend((fused, block.relu))
    return nn.Sequential(*flattened)

def expand_convbnrelu_withbn(module):
    """Flatten ``module`` into an ``nn.Sequential`` while keeping BatchNorm.

    Every direct child that is a ``ConvBNReLU`` is unpacked into its three
    parts (``conv``, ``bn``, ``relu``) in that order. All other children are
    carried over unchanged and in their original order.
    """
    flattened = []
    for _, block in module.named_children():
        if not isinstance(block, ConvBNReLU):
            # Anything that is not a ConvBNReLU block is kept as-is.
            flattened.append(block)
            continue
        flattened.extend((block.conv, block.bn, block.relu))
    return nn.Sequential(*flattened)

class NewVGGBnDrop(nn.Module):
    """Wrapper around a model with ``features`` and ``classifier`` parts.

    ``features`` is rebuilt through ``expand_convbnrelu`` (ConvBNReLU blocks
    become fused layer + relu); ``classifier`` is the very same module object
    as the wrapped model's classifier.
    """

    def __init__(self, model):
        super().__init__()
        self.features = expand_convbnrelu(model.features)
        self.classifier = model.classifier

    def forward(self, x):
        """Run the feature extractor, flatten per sample, then classify."""
        feature_maps = self.features(x)
        # Collapse everything but the batch dimension for the classifier.
        batch = feature_maps.size(0)
        flat = feature_maps.view(batch, -1)
        return self.classifier(flat)
    
class NewVGGBnDrop_withbn(nn.Module):
    """Wrapper around a model with ``features`` and ``classifier`` parts.

    ``features`` is rebuilt through ``expand_convbnrelu_withbn`` (ConvBNReLU
    blocks are unpacked into conv, bn, relu); ``classifier`` is the very same
    module object as the wrapped model's classifier.
    """

    def __init__(self, model):
        super().__init__()
        self.features = expand_convbnrelu_withbn(model.features)
        self.classifier = model.classifier

    def forward(self, x):
        """Run the feature extractor, flatten per sample, then classify."""
        feature_maps = self.features(x)
        # Collapse everything but the batch dimension for the classifier.
        batch = feature_maps.size(0)
        flat = feature_maps.view(batch, -1)
        return self.classifier(flat)

############## Compression Experiment  ###############
# Settings handed to every pruning_experiment call.
info = {
    'name' : MODEL,
    'dataset' : DATASET,
    'imsize' : IMAGE_SIZE,
    'ratios' : [1, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7],
    'repetitions' : 5 # How many times to repeat compression algorithm
}

# Checkpoints read here are stored under the "CIFAR-VGG" name.  A dedicated
# constant is used for the file names instead of temporarily overwriting the
# global MODEL (which is already "ConvCIFAR-VGG" whenever this point is
# reached, because the model selection above rejects every other name).
PRETRAINED_NAME = "CIFAR-VGG"

for i in range(EXECUTIONS): # We execute the experiment in many models to examine its stability

    original_model_path = PATH + '{}_trained_in_{}_imsize_{}_version_{}.pkl'.format(PRETRAINED_NAME, DATASET, IMAGE_SIZE, i)
    formatted_model_path = PATH + '{}-Formatted_trained_in_{}_imsize_{}_version_{}.pkl'.format(PRETRAINED_NAME, DATASET, IMAGE_SIZE, i)

    # NOTE: torch.load unpickles arbitrary objects -- only load checkpoints
    # produced by this project.
    if os.path.exists(formatted_model_path):
        # The BatchNorm-fused model was cached by an earlier run; the original
        # checkpoint is not needed in that case.
        model = torch.load(formatted_model_path, map_location=DEVICE)
    else:
        # Build the fused model from the original checkpoint and cache it.
        model = torch.load(original_model_path, map_location=DEVICE)
        model = NewVGGBnDrop(model)
        torch.save(model, formatted_model_path)

    # model = NewVGGBnDrop_withbn(model)

    # move the model weights to cpu or gpu
    model = model.to(device)
    print(model)

    # `criterion` is the loss object created in the training loop above.
    # NOTE(review): the same `model` object is passed to every method below;
    # if pruning_experiment modifies it in place, later methods start from an
    # already-compressed network -- confirm it works on a copy.
    pruning_experiment(model, train_loader, val_loader, test_loader, criterion, info,
                       method='neural_path_kmeans')
    # pruning_experiment(model, train_loader, val_loader, test_loader, criterion, info,
    #                    method='tropnnc')
    pruning_experiment(model, train_loader, val_loader, test_loader, criterion, info,
                       method='iterative_tropnnc')
    pruning_experiment(model, train_loader, val_loader, test_loader, criterion, info,
                       method='thinet', dataset=train_data, w2_rescale=True)
    pruning_experiment(model, train_loader, val_loader, test_loader, criterion, info,
                       method='random_structured')
    pruning_experiment(model, train_loader, val_loader, test_loader, criterion, info,
                       method='l1_structured')
