import torch
import torch.nn as nn
from utils.math_tools import unnormalize_image
import os
import torchvision.models as models
from cifar.model_cifar import (ShellNetwork, GCNNShell_cifar, Simple_Shell, LowRankHyper,
                               HyperNetwork_cifar, HyperNetwork_Custom, FunctionalFullNetwork
                                 , HyperNetwork_Head, ResNet18_cifar10)
from utils.model_tools import ShellParser, param_matching_loss
from utils.math_tools import RotateTransform, find_closest_divisor
from utils.model_tools import plot
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import tqdm
import numpy as np
import datetime
import time
import random


def param_count(model):
    """Print and return the number of trainable parameters in ``model``.

    Only parameters with ``requires_grad`` set are counted; frozen ones are
    skipped. The total is printed with thousands separators.

    Args:
        model: Any object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        int: total element count over all trainable parameters.
    """
    total = 0
    for tensor in model.parameters():
        if not tensor.requires_grad:
            continue
        total += tensor.numel()
    print("Total number of parameters: {:,}".format(total))
    return total

def ablation(lora=False,  n=0, head=False, shared_layer=3, params_only=False):
    """Train one hypernetwork variant on CIFAR-10 and report timing.

    Builds a ``Simple_Shell`` target network and a hypernetwork for it:
    ``HyperNetwork_Head`` when ``head`` is set, else ``LowRankHyper`` when
    ``lora`` is set, else ``HyperNetwork_Custom``. Both are wrapped in a
    ``FunctionalFullNetwork``, which is trained with Adam on the augmented
    CIFAR-10 training split.

    Args:
        lora: Select the low-rank hypernetwork. This flag is also forwarded to
            ``HyperNetwork_Head`` when ``head`` is set.
        n: Forwarded unchanged to ``FunctionalFullNetwork``.
        head: Use ``HyperNetwork_Head``. This takes precedence over ``lora``
            when choosing the hypernetwork class.
        shared_layer: Forwarded to ``HyperNetwork_Head``; unused otherwise.
        params_only: If True, print the hypernetwork's trainable parameter
            count and return it. No data is downloaded and no training runs.

    Returns:
        The hypernetwork's trainable parameter count when ``params_only`` is
        True, otherwise None.
    """
    # NOTE: deterministic=True restricts cuDNN to deterministic conv
    # algorithms. benchmark=True still autotunes among them, which (per the
    # PyTorch reproducibility notes) may pick different algorithms per run.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    hyper_choice = 'simple'

    print("Using 13-layer CNN, batch of full networks, and using easier ppg.")
    full_type = 'one'  # alternative mode previously tried: 'batch'
    out_class = 10

    shell = Simple_Shell(num_classes=out_class)

    # Return values are currently unused below; the call is kept unchanged.
    parameters, layers, layers_info = ShellParser(shell, force_stride=False)
    linear_info = ShellParser(shell, 2)
    print("Linear info:", linear_info)

    inter_dim = 7
    shared_choice = 4
    # Every branch is moved to the device so the three variants are handled
    # identically.
    if head:
        print("using hypernetwork, with one head per layer. (all conv layers and one linear layers)")
        hyper = HyperNetwork_Head(shell, shared_layer=shared_layer, lora=lora, inter=inter_dim,
                                  shared_choice=shared_choice).to(device)
    elif lora:
        hyper = LowRankHyper(shell, intermediate_dim=inter_dim).to(device)
    else:
        hyper = HyperNetwork_Custom(shell, choice=hyper_choice).to(device)

    hyper_params = param_count(hyper)
    if params_only:
        # Skip dataset download and training entirely.
        return hyper_params

    learning_rate = 0.0001
    network = FunctionalFullNetwork(hyper, shell, n, mode=full_type, head=head,
                                    reflection=False).to(device)
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

    epochs = 20
    start_time = time.time()
    bar = tqdm.tqdm(range(epochs))
    for epoch in bar:
        epoch_start = time.time()
        # Accumulate the loss on-device so there is no host sync every batch;
        # it is read back once per epoch.
        running_loss = torch.zeros((), device=device)
        num_batches = 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = network(images, True)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.detach()
            num_batches += 1
        mean_loss = running_loss.item() / max(num_batches, 1)
        bar.set_postfix(loss=f"{mean_loss:.4f}")
        # Duration of this epoch only; the cumulative total is printed below.
        print("Epoch time:", time.time() - epoch_start, " for epoch:", epoch + 1)
    # Measured after the loop so it stays defined even when epochs == 0.
    total_time = time.time() - start_time
    print("Training finished.")
    print("For total epochs:", epochs, "time is", str(datetime.timedelta(seconds=total_time)))

    # TODO: no evaluation pass yet. A test split would reuse the same
    # Normalize without the flip/crop augmentation.


if __name__ == '__main__':
    # Baseline: trainable parameters of a plain shell with no hypernetwork.
    param_count(Simple_Shell(num_classes=10))
    print()

    # Each (label, kwargs) pair is announced, run, then followed by a blank
    # line. The "Z4-LoRA w/o head" variant (lora=True, head=False) is
    # currently disabled.
    experiments = (
        ("Z4-LoRA w/ head", {"lora": True, "head": True}),
        ("Z4-Full", {"lora": False, "head": False}),
    )
    for label, options in experiments:
        print(label)
        ablation(**options)
        print()
