import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import sys
import random
import copy
import wandb
import yaml
from dataset.cifar10 import CIFAR10, LazyCIFAR10
from model.resnet import ResNet18WithCompression

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
# Everything below is driven by config.yaml, resolved relative to the current
# working directory.
with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Seed the PyTorch, Python and NumPy RNGs with the same configured value.
seed = config['seed']
for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    seed_fn(seed)

# Switch that decides whether a Weights & Biases run is started below.
use_wandb = config['use_wandb']

# Report the interpreter and library versions in use.
for label, version in (
    ("Python Version: ", sys.version),
    ("PyTorch Version: ", torch.__version__),
    ("Torchvision Version: ", torchvision.__version__),
    ("WandB version: ", wandb.__version__),
):
    print(label, version)

# Sub-sections of the configuration used by the rest of the script.
compression_config = config['compression_config']
training_config = config['training']

# `run` holds the W&B run object when tracking is enabled, otherwise None.
# The run config is the training section plus the compression settings and
# the seed.
run = None
if use_wandb:
    wandb_section = config['wandb']
    run = wandb.init(
        project=wandb_section['project'],
        name=wandb_section['name'],
        config=dict(training_config, compression_config=compression_config, seed=seed),
    )

# ---------------------------------------------------------------------------
# Data loaders
# ---------------------------------------------------------------------------
# The "AC-SGD" switch is forwarded to the datasets as `with_idx`.
# NOTE(review): train()/test() unpack every batch as (inputs, targets);
# confirm the datasets still yield 2-tuples when with_idx is truthy.
with_idx = training_config["AC-SGD"]
batch_size = training_config["batch_size"]

if training_config["lazy_sampling"]:
    lazy_schedule = training_config["lazy_sampling_params"]["schedule"]
    if lazy_schedule != "constant":
        # Previously this case silently left `train_dataset`/`trainloader`
        # undefined and the script only crashed later with a NameError
        # inside train(). Fail immediately with an actionable message.
        raise ValueError(
            f"Unsupported lazy_sampling_params.schedule {lazy_schedule!r}; "
            "only 'constant' is implemented"
        )
    # Constant schedule: the callable ignores its argument and returns the
    # configured p_t value.
    train_dataset = LazyCIFAR10(with_idx=with_idx, p_t=(lambda x: training_config["lazy_sampling_params"]["p_t"]))
else:
    train_dataset = CIFAR10(with_idx=with_idx)

# Shuffled training loader; a trailing partial batch is dropped.
trainloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True
)

# NOTE(review): drop_last=True also discards a trailing partial batch of the
# test split, so evaluation may skip some test samples -- confirm this is
# intended (e.g. because the model needs a fixed batch size).
testloader = torch.utils.data.DataLoader(
    CIFAR10(train=False, with_idx=with_idx),
    batch_size=batch_size, shuffle=False, num_workers=2, drop_last=True
)

# ---------------------------------------------------------------------------
# Device, model, loss, optimizer and LR schedule
# ---------------------------------------------------------------------------
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# torch.cuda.get_device_name() only accepts CUDA devices and raises on a
# CPU-only machine, so only query it when a GPU was actually selected.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(device))
else:
    print(device)

net = ResNet18WithCompression(compression_config)
net.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=training_config['learning_rate'],
                      momentum=0.9, weight_decay=5e-4)
# NOTE(review): T_max is hard-coded to 200 while the number of epochs comes
# from training_config['epochs']; confirm the two are meant to be decoupled.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Best test accuracies seen so far (uncompressed / compressed forward pass);
# updated by test().
best_acc = 0
best_acc_compressed = 0

def train(epoch):
    """Run one training epoch over ``trainloader`` with compression enabled.

    Uses the module-level ``net``, ``optimizer``, ``criterion``, ``device``
    and ``trainloader``. When lazy sampling is enabled, the training dataset
    is told the current epoch before iteration starts. Progress is printed
    every 100 batches, and the epoch's metrics are logged to Weights & Biases
    only when ``use_wandb`` is set.

    Args:
        epoch: Zero-based index of the current epoch.
    """
    print('\nEpoch: %d' % epoch)
    if training_config["lazy_sampling"]:
        train_dataset.update_epoch(epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs, compress=True)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Accumulate the running loss and top-1 accuracy counters.
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % 100 == 99:
            print('Train: Loss: %.3f | Acc: %.3f%% (%d/%d)'  % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    acc = 100.*correct/total
    # wandb.log requires an initialised run, which only exists when tracking
    # is enabled; calling it unconditionally crashed runs with use_wandb off.
    if use_wandb:
        # NOTE: 'train_loss' is the sum of the per-batch losses, not the mean
        # that the console output shows. commit=False leaves the step open so
        # a later wandb.log call can add to it.
        wandb.log({
            'train_loss': train_loss,
            'train_acc': acc
        }, commit=False)

def test(epoch):
    """Evaluate ``net`` on ``testloader`` with and without compression.

    Every test batch is passed through the network twice, once with
    ``compress=False`` and once with ``compress=True``; loss and top-1
    accuracy are tracked separately for both. Results are printed, logged to
    Weights & Biases when ``use_wandb`` is set, and the module-level
    ``best_acc`` / ``best_acc_compressed`` are raised when improved upon.

    Args:
        epoch: Zero-based index of the current epoch (logged as 'epoch').
    """
    global best_acc, best_acc_compressed
    net.eval()
    test_loss = 0
    test_loss_compressed = 0
    correct = 0
    correct_compressed = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)

            # no compression test
            outputs = net(inputs, compress=False)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()

            # compressed test
            outputs = net(inputs, compress=True)
            loss = criterion(outputs, targets)
            test_loss_compressed += loss.item()
            _, predicted = outputs.max(1)
            correct_compressed += predicted.eq(targets).sum().item()

            # Both passes see the same samples, so count them once.
            total += targets.size(0)
    acc = 100.*correct/total
    acc_compressed = 100.*correct_compressed/total
    print('Test Uncompressed: Loss: %.3f | Acc: %.3f%% (%d/%d)' % (test_loss/len(testloader), acc, correct, total))
    print('Test Compressed: Loss: %.3f | Acc: %.3f%% (%d/%d)' % (test_loss_compressed/len(testloader), acc_compressed, correct_compressed, total))

    # wandb.log requires an initialised run, which only exists when tracking
    # is enabled; calling it unconditionally crashed runs with use_wandb off.
    if use_wandb:
        # NOTE: the logged losses are sums over batches, whereas the console
        # output above shows per-batch means.
        wandb.log({
            'test_loss': test_loss,
            'test_acc': acc,
            'test_loss_compressed': test_loss_compressed,
            'test_acc_compressed': acc_compressed,
            'epoch': epoch
        })
    if acc > best_acc:
        best_acc = acc
        print(f'New best acc {acc}')
    if acc_compressed > best_acc_compressed:
        best_acc_compressed = acc_compressed
        print(f'New best acc_compressed {acc_compressed}')

# Main training loop: for each configured epoch, train, evaluate, advance the
# learning-rate schedule, and report the wall-clock time taken.
for epoch in range(training_config['epochs']):
    epoch_start = time.time()
    train(epoch)
    test(epoch)
    scheduler.step()
    print('Time: ', time.time() - epoch_start)

# When tracking is enabled, store the best accuracies in the run summary and
# close the run.
if use_wandb:
    summary = wandb.run.summary
    summary['best_test_accuracy'] = best_acc
    summary['best_test_accuracy_compressed'] = best_acc_compressed
    wandb.finish()
