import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import random
import os

# === Load CIFAR-100 and sample fixed-size batch ===
def load_data(num_samples=5000, data_root='/workspace/data'):
    """Build CIFAR-100 loaders that each yield their data as a single batch.

    The training loader draws a class-balanced random subset of the training
    split: ``num_samples // num_classes`` images from every class. The test
    loader covers the whole test split. Images are converted to tensors and
    normalized per channel with mean 0.5 and std 0.5.

    Args:
        num_samples: Requested size of the training subset. The actual size is
            ``(num_samples // num_classes) * num_classes``.
        data_root: Directory the dataset is downloaded to / read from.

    Returns:
        ``(trainloader, testloader)``. ``trainloader`` yields the selected
        subset in one batch (in random order); ``testloader`` yields the full
        test split in one batch.

    Raises:
        ValueError: If ``num_samples`` is smaller than the number of classes,
            or (from ``random.sample``) if a class holds fewer images than
            the per-class quota.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR100(root=data_root, train=True, download=True, transform=transform)
    testset = torchvision.datasets.CIFAR100(root=data_root, train=False, download=True, transform=transform)

    num_classes = len(trainset.classes)
    samples_per_class = num_samples // num_classes
    if samples_per_class < 1:
        raise ValueError(
            f"num_samples must be at least {num_classes} (one image per class), got {num_samples}")

    # Read labels straight from the dataset's target list: iterating the
    # dataset itself would decode and transform every image just to get them.
    class_indices = [[] for _ in range(num_classes)]
    for idx, label in enumerate(trainset.targets):
        class_indices[label].append(idx)

    selected_indices = []
    for indices in class_indices:
        selected_indices.extend(random.sample(indices, samples_per_class))

    # batch_size equals the subset size so one iteration returns everything.
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=len(selected_indices),
        sampler=torch.utils.data.SubsetRandomSampler(selected_indices))
    testloader = torch.utils.data.DataLoader(testset, batch_size=len(testset))

    return trainloader, testloader

# === Accuracy computation ===
def compute_accuracy(model, inputs, labels, device):
    """Return the top-1 accuracy of ``model`` on one batch, in percent.

    The model is evaluated in eval mode without gradient tracking, and its
    previous train/eval mode is restored afterwards (also if the forward pass
    raises), so calling this never changes the caller's mode.

    Args:
        model: Module mapping ``inputs`` to per-class scores along dim 1.
        inputs: Batch of inputs, already on the same device as ``model``.
        labels: Integer class labels, one per input, on the same device.
        device: Unused; kept so existing callers keep working.

    Returns:
        Accuracy as a float in ``[0, 100]``; ``0.0`` for an empty batch.
    """
    total = labels.size(0)
    if total == 0:
        # Accuracy over zero samples is undefined; avoid dividing by zero.
        return 0.0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(inputs)
            _, predicted = torch.max(outputs, dim=1)
            correct = (predicted == labels).sum().item()
    finally:
        # Restore whatever mode the caller had instead of forcing train mode.
        model.train(was_training)
    return 100 * correct / total

# === Training loop with PyTorch Adam ===
def train_model(model, trainloader, testloader, criterion, num_epochs=200, device='cuda', learning_rate=0.001):
    """Train ``model`` with full-batch Adam and record update-direction alignment.

    Only the first batch of each loader is used: the first training batch is
    the training set for every optimization step, and the first test batch is
    the evaluation set. After each step the normalized parameter update is
    compared with the previous one via cosine similarity.

    Args:
        model: Network to train; moved to ``device`` and put in train mode.
        trainloader: Loader whose first batch is used for every step.
        testloader: Loader whose first batch is used for test accuracy.
        criterion: Loss called as ``criterion(outputs, labels)``.
        num_epochs: Number of optimization steps.
        device: Device for the model and both batches.
        learning_rate: Adam learning rate.

    Returns:
        ``(model, history)`` where ``history`` is a dict with:
        ``loss_values`` (loss computed before each step),
        ``train_accuracy_values`` (accuracy measured after each step),
        ``test_accuracy_values`` (measured every 10th step and on the last
        step; the latest measurement is carried forward in between),
        ``alphas`` (the learning rate, once per step), and
        ``debug_cos`` (cosine between consecutive normalized updates; it has
        fewer entries than steps because the first update has no predecessor
        and updates without a positive norm are skipped).

    Raises:
        ValueError: If either loader yields no batch.

    Side effects:
        Saves ``debug_cos`` and ``loss_values`` with ``torch.save`` as two
        ``.pt`` files in the current working directory.
    """
    model.to(device)
    model.train()

    def first_batch(loader, name):
        # Fail with a clear message instead of a NameError on an empty loader.
        try:
            batch_inputs, batch_labels = next(iter(loader))
        except StopIteration:
            raise ValueError(f"{name} yielded no batches") from None
        return batch_inputs.to(device), batch_labels.to(device)

    all_inputs, all_labels = first_batch(trainloader, "trainloader")
    test_inputs, test_labels = first_batch(testloader, "testloader")

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    loss_values = []
    train_accuracy_values = []
    test_accuracy_values = []
    alphas = []
    debug_cos = []

    v_prev = None

    def params_to_vector(params_list):
        # Detach so the snapshot (and every delta/norm/dot derived from it) is
        # not tracked by autograd; torch.cat already returns a fresh copy, and
        # reshape also handles non-contiguous parameters.
        return torch.cat([p.detach().reshape(-1) for p in params_list if p.requires_grad])

    prev_params = params_to_vector(model.parameters())

    pbar = tqdm(range(num_epochs), desc="Training Progress", ncols=100)

    for t in pbar:
        optimizer.zero_grad()
        outputs = model(all_inputs)
        loss = criterion(outputs, all_labels)
        loss.backward()
        optimizer.step()

        # Direction of the update just applied, normalized to unit length.
        current_params = params_to_vector(model.parameters())
        v_current = current_params - prev_params
        v_norm = torch.norm(v_current)
        if v_norm > 0:
            v_current = v_current / v_norm

            if v_prev is not None:
                debug_cos.append(torch.dot(v_prev, v_current).item())

            v_prev = v_current

        current_loss = loss.item()
        loss_values.append(current_loss)

        train_acc = compute_accuracy(model, all_inputs, all_labels, device)
        train_accuracy_values.append(train_acc)

        # The test pass is only run periodically; reuse the last value otherwise.
        if t % 10 == 0 or t == num_epochs - 1:
            test_acc = compute_accuracy(model, test_inputs, test_labels, device)
        else:
            test_acc = test_accuracy_values[-1] if test_accuracy_values else 0
        test_accuracy_values.append(test_acc)

        alphas.append(learning_rate)
        prev_params = current_params

        pbar.set_postfix({'Loss': f'{current_loss:.4f}', 'Train Acc': f'{train_acc:.2f}%', 'Test Acc': f'{test_acc:.2f}%'})

    save_path = "."
    lr_tag = str(learning_rate).replace('.', '_')
    torch.save(debug_cos, f"{save_path}/resnet34_adam_cosine_cifar100_lr{lr_tag}.pt")
    torch.save(loss_values, f"{save_path}/resnet34_adam_loss_values_cifar100_lr{lr_tag}.pt")

    return model, {
        'loss_values': loss_values,
        'train_accuracy_values': train_accuracy_values,
        'test_accuracy_values': test_accuracy_values,
        'alphas': alphas,
        'debug_cos': debug_cos,
    }

# === Plotting ===
def create_plots(loss_values, train_accuracy_values=None, test_accuracy_values=None,
                 alphas=None, debug_cos=None, debug_w_cos=None, debug_w_cos2=None,
                 tag="", save_path="."):
    """Save the loss curve and, if given, the cosine-similarity curve as PNGs.

    Args:
        loss_values: Sequence of loss values to plot.
        train_accuracy_values: Accepted but not plotted.
        test_accuracy_values: Accepted but not plotted.
        alphas: Accepted but not plotted.
        debug_cos: Optional sequence of cosine similarities; when ``None`` the
            alignment figure is skipped.
        debug_w_cos: Accepted but not plotted.
        debug_w_cos2: Accepted but not plotted.
        tag: Label used in the legends and in the output file names.
        save_path: Output directory; created if it does not exist.
    """
    os.makedirs(save_path, exist_ok=True)
    grid_style = {'linestyle': '--', 'alpha': 0.7}

    # Loss curve.
    loss_fig, loss_ax = plt.subplots(figsize=(10, 3.6))
    loss_ax.plot(loss_values, color='blue', linewidth=2, label=f"Loss ({tag})")
    loss_ax.set_xlabel('Iterations', fontsize=14)
    loss_ax.set_ylabel('Loss Value', fontsize=14)
    loss_ax.legend(fontsize=14)
    loss_ax.grid(True, **grid_style)
    loss_fig.savefig(f"{save_path}/resnet34_adam_loss_cifar100_{tag}.png")
    plt.close(loss_fig)

    if debug_cos is None:
        return

    # Alignment curve, with a dashed reference line at cosine == 1.
    cos_fig, cos_ax = plt.subplots(figsize=(10, 3.6))
    cos_ax.plot(debug_cos, color='red', linewidth=2, label=f"Cosine ({tag})")
    cos_ax.axhline(y=1, color='black', linestyle='--')
    cos_ax.set_xlabel('Iterations', fontsize=14)
    cos_ax.set_ylabel('Cosine Similarity', fontsize=14)
    cos_ax.grid(True, **grid_style)
    cos_ax.legend(fontsize=14)
    cos_fig.savefig(f"{save_path}/resnet34_adam_align_cifar100_{tag}.png")
    plt.close(cos_fig)

# === Main function ===
def main():
    """Train a fresh ResNet-34 per learning rate on a CIFAR-100 subset and plot each run.

    The same data loaders are shared by all runs. For every learning rate a
    newly initialized ResNet-34 with a 100-way output layer is trained, its
    final train/test accuracy is printed, and its curves are saved to the
    current directory.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    trainloader, testloader = load_data(num_samples=5000)
    num_epochs = 2000

    for lr in [0.001, 0.005, 0.01, 0.05, 0.1]:
        print(f"\n=== Training resnet34 with PyTorch adam on CIFAR-100 (lr={lr}) ===")

        # Replace the classification head so it outputs 100 CIFAR-100 classes.
        model = torchvision.models.resnet34()
        model.fc = nn.Linear(model.fc.in_features, 100)

        loss_fn = nn.CrossEntropyLoss()
        model, history = train_model(
            model, trainloader, testloader, loss_fn,
            num_epochs=num_epochs, device=device, learning_rate=lr)

        final_train_acc = history['train_accuracy_values'][-1]
        final_test_acc = history['test_accuracy_values'][-1]
        print(f"\n=== Final Accuracy (lr={lr}) ===")
        print(f"Train Accuracy: {final_train_acc:.2f}%")
        print(f"Test Accuracy: {final_test_acc:.2f}%")

        lr_label = str(lr).replace('.', '_')
        create_plots(
            history['loss_values'],
            train_accuracy_values=history['train_accuracy_values'],
            test_accuracy_values=history['test_accuracy_values'],
            alphas=history['alphas'],
            debug_cos=history['debug_cos'],
            tag=f"lr_{lr_label}",
            save_path=".",
        )

# Run the learning-rate sweep only when this file is executed as a script.
if __name__ == "__main__":
    main()