import torch
import torch.nn as nn
import torch.optim as optim
from src.data.data_utils import choose_dataset
import argparse

# Command-line interface. Each row is (flag, destination attribute, type,
# default); the parsed values are exposed on `options` under the destination
# attribute names.
_cli_arguments = (
    ("--r", "start_rank_percent", float, 0.4),
    ("--m", "model", str, 'resnet'),
    ("--fact", "factorization", str, 'tucker'),
    ("--d", "data", str, 'cifar10'),
)

parser = argparse.ArgumentParser()
for _flag, _dest, _type, _default in _cli_arguments:
    parser.add_argument(_flag, dest=_dest, type=_type, default=_default)
options = parser.parse_args()


# Pick the compute device: the first CUDA GPU when one is available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Training hyperparameters. Tiny ImageNet has its own preset; every other
# dataset uses the default preset.
if options.data.lower() == 'tiny_imagenet':
    batch_size, learning_rate, num_epochs = 128, 0.01, 50
    momentum, weight_decay = 0.9, 5e-3
else:
    batch_size, learning_rate, num_epochs = 64, 0.05, 100
    momentum, weight_decay = 0.9, 5e-4

# Publish the CLI choices as attributes of the low-rank package's __init__
# module. This is done before the model module is imported below.
# NOTE(review): presumably the low-rank layers read these values when they are
# imported/constructed -- confirm. Also note that importing `pkg.__init__`
# creates a module object distinct from the package `pkg` itself, so readers
# of these settings must import it the same way -- verify against the layers.
import src.low_rank_neural_networks.__init__ as g
g.factorization = options.factorization
g.glob_start_rank_perc = float(options.start_rank_percent)

# Import the module that provides the requested architecture family. The
# match is by substring, so e.g. 'resnet50' selects the ResNet module.
_requested_model = options.model.lower()
if 'resnet' in _requested_model:
    import src.low_rank_neural_networks.vanillas.vanilla_ResNet as lr_vanilla
elif 'vgg' in _requested_model:
    import src.low_rank_neural_networks.vanillas.vanilla_VGG as lr_vanilla
elif 'alexnet' in _requested_model:
    import src.low_rank_neural_networks.vanillas.vanilla_alexnet as lr_vanilla
else:
    # Fail fast with a clear message instead of leaving `lr_vanilla` unbound,
    # which would only surface later as a NameError.
    raise ValueError(
        f"Unknown model family '{options.model}'; "
        "expected a name containing 'resnet', 'vgg' or 'alexnet'."
    )

# -------- Dataset Selection -----------
# Known datasets map to (num_classes, datapath); any other dataset name falls
# back to the 10-class preset stored under ./data/.
# NOTE(review): num_classes is not used anywhere in the visible part of this
# script -- confirm whether the model constructors are meant to receive it.
_dataset_name = options.data.lower()
_dataset_presets = {
    'imagenet': (1000, "./imageNet/"),
    'tiny_imagenet': (200, "./tiny_imagenet/"),
}
num_classes, datapath = _dataset_presets.get(_dataset_name, (10, "./data/"))

train_loader, val_loader, test_loader = choose_dataset(
    dataset_name=_dataset_name,
    batch_size=batch_size,
    datapath=datapath,
)


# Initialize the model.
# Maps the exact (lower-cased) --m value to the constructor name looked up on
# `lr_vanilla` and the banner printed once the model has been built.
_model_builders = {
    'vgg': ('vgg16', "Train VGG16"),
    'resnet': ('resnet18', "Train ResNet18"),
    'resnet50': ('resnet50', "Train ResNet50"),
    'alexnet': ('alexnet', "Train AlexNet"),
}

_model_key = options.model.lower()
if _model_key not in _model_builders:
    # Names such as 'resnet18' or 'vgg16' match no entry here; raise a clear
    # error instead of leaving `model` unbound (NameError further down).
    raise ValueError(
        f"Unsupported model '{options.model}'; "
        f"expected one of: {', '.join(sorted(_model_builders))}."
    )

_builder_name, _banner = _model_builders[_model_key]
model = getattr(lr_vanilla, _builder_name)().to(device)
print(_banner)

print(f' factorization : {lr_vanilla.factorization},dataset {options.data.lower()},model {options.model}')

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum,
    weight_decay=weight_decay,
)

# Build exactly one learning-rate scheduler, so no throwaway scheduler is
# attached to the optimizer and then discarded.
if options.data.lower() == 'tiny_imagenet':
    # The one-cycle schedule is sized in optimizer steps
    # (steps_per_epoch * epochs), i.e. it is defined per batch, not per epoch.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=0.02,
        steps_per_epoch=len(train_loader),
        epochs=num_epochs,
        div_factor=10,
        final_div_factor=10,
        pct_start=10 / num_epochs,
    )
else:
    # Multiply the learning rate by 0.1 at epochs 25 and 40.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 40], gamma=0.1)

# Bookkeeping over the low-rank layers registered on the model module:
#   ranks           -- each layer's `rank` attribute
#   max_ranks       -- smallest first dimension among each layer's `Us`
#   lr_weights_UnCs -- every entry of `Us`, plus `C` unless the factorization
#                      is 'mat'
ranks, max_ranks, lr_weights_UnCs = [], [], []
_collect_core = options.factorization != 'mat'
for lr_layer in lr_vanilla.low_rank_layers:
    lr_weights_UnCs.extend(lr_layer.Us)
    if _collect_core:
        lr_weights_UnCs.append(lr_layer.C)
    ranks.append(lr_layer.rank)
    max_ranks.append(min(factor.shape[0] for factor in lr_layer.Us))

# Total number of scalar entries across all model parameters.
total_params_compressed = sum(param.numel() for param in model.parameters())

# Run summary printed once before training starts.
_summary_lines = (
    f'train vanilla with params {options.start_rank_percent}',
    f'ranks: {ranks}\n',
    f'max ranks {max_ranks}\n',
    f'total params :{total_params_compressed}',
    '=' * 100,
)
for _line in _summary_lines:
    print(_line)

# A OneCycleLR schedule is sized in optimizer steps (steps_per_epoch * epochs)
# and must be advanced after every batch; stepping it once per epoch would
# leave it stuck at the very start of its warm-up. Epoch-based schedulers
# (e.g. MultiStepLR) are still stepped once per epoch.
_step_scheduler_per_batch = isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR)

for epoch in range(num_epochs):
    # ---- Training pass over the full training set ----
    model.train()
    total_loss = 0.0  # summed batch losses for this epoch (currently not reported)
    for batch_idx, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        # Drop the gradients of the previous step (sets .grad to None).
        optimizer.zero_grad(set_to_none=True)

        # Forward pass, backward pass, parameter update
        outputs = model(data)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        if _step_scheduler_per_batch:
            scheduler.step()
        total_loss += float(loss.item())

        if (batch_idx + 1) % 100 == 0:
            print(
                f"Epoch [{epoch + 1}/{num_epochs}], Step [{batch_idx + 1}/{len(train_loader)}], Loss: {loss.item():.4f}")
    if not _step_scheduler_per_batch:
        scheduler.step()

    # ---- Evaluation after every epoch ----
    # NOTE(review): accuracy is measured on `val_loader`, although the message
    # below says "test images"; `test_loader` is never read in this loop.
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for data, targets in val_loader:
            data = data.to(device)
            targets = targets.to(device)

            outputs = model(data)
            # Index of the largest output along dim 1 is the predicted class.
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

        accuracy = 100 * correct / total
        print(f"Accuracy of the network on the test images: {accuracy}%")
print("Training finished.")
