import torch
import torch.nn as nn
from utils.math_tools import unnormalize_image
import os
import torchvision.models as models
from cifar.model_cifar import (ShellNetwork, GCNNShell_cifar, Simple_Shell, LowRankHyper,
                               HyperNetwork_cifar, HyperNetwork_Custom, FunctionalFullNetwork
                                 , HyperNetwork_Head, ResNet18_cifar10)
from utils.model_tools import ShellParser, param_matching_loss
from utils.math_tools import RotateTransform, find_closest_divisor
from utils.model_tools import plot
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import tqdm
import numpy as np
import datetime
import random


def param_count(model):
    """Print and return how many trainable scalar parameters ``model`` has.

    Only parameters with ``requires_grad`` set are counted.
    """
    trainable = [weight for weight in model.parameters() if weight.requires_grad]
    n_trainable = sum(weight.numel() for weight in trainable)
    print("Total number of parameters: ", n_trainable)
    return n_trainable

def _build_hypernetwork(shell, device, lora, head, shared_layer, inter_dim,
                        hyper_choice, shared_choice):
    """Construct the hypernetwork variant selected by ``head`` / ``lora`` and move it to ``device``."""
    if head:
        print("using hypernetwork, with one head per layer. (all conv layers and one linear layers)")
        # Only forward shared_choice when the caller supplied one, so that
        # HyperNetwork_Head's own default (if it has one) applies otherwise.
        head_kwargs = {} if shared_choice is None else {"shared_choice": shared_choice}
        return HyperNetwork_Head(shell, shared_layer=shared_layer, lora=lora,
                                 inter=inter_dim, **head_kwargs).to(device)
    if lora:
        return LowRankHyper(shell, intermediate_dim=inter_dim,
                            hyper_choice=hyper_choice).to(device)
    return HyperNetwork_Custom(shell, choice=hyper_choice).to(device)


def _cifar10_loaders(batch_size=32, root='./data'):
    """Return ``(train_loader, test_loader)`` for CIFAR-10, downloading into ``root`` if needed."""
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = datasets.CIFAR10(root=root, train=True, download=True, transform=train_transform)
    test_dataset = datasets.CIFAR10(root=root, train=False, download=True, transform=test_transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader


def _train_one_epoch(network, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``; return the mean batch loss (NaN if empty)."""
    network.train()
    losses = []
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # The second positional argument is passed as True during training and
        # False during evaluation; its exact meaning is defined by FunctionalFullNetwork.
        outputs = network(images, True)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return float(np.mean(losses)) if losses else float('nan')


def _evaluate(network, loader, device):
    """Return top-1 accuracy (in percent) of ``network`` on ``loader``; 0.0 if the loader is empty."""
    network.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = network(images, False)
            _, predicted = torch.max(outputs, -1)
            total += labels.size(0)
            # .item() keeps the running count a Python int instead of a device tensor.
            correct += (predicted == labels).sum().item()
    return 100.0 * correct / total if total else 0.0


def ablation(lora=False, n=0, head=False, shared_layer=3, params_only=True,
             l_r=1e-3, weight_decay=0, shared_choice=None, epochs=10,
             full_save_location=os.path.join('checkpoints', 'ablation_full.pth'),
             hyper_save_location=os.path.join('checkpoints', 'ablation_hyper.pth')):
    """Build a hypernetwork for ``Simple_Shell`` and optionally train it on CIFAR-10.

    Args:
        lora: use ``LowRankHyper`` (or, with ``head``, pass ``lora`` to ``HyperNetwork_Head``).
        n: forwarded positionally to ``FunctionalFullNetwork`` (meaning defined there).
        head: use ``HyperNetwork_Head`` (one head per layer).
        shared_layer: forwarded to ``HyperNetwork_Head`` when ``head`` is True.
        params_only: if True, only report parameter counts and return ``None``.
        l_r: Adam learning rate. The 1e-3 default is Adam's own default; the
            original code read this from an undefined global.
        weight_decay: Adam weight decay. It also controls the returned rate
            (see Returns).
        shared_choice: forwarded to ``HyperNetwork_Head`` when given.
        epochs: number of training epochs.
        full_save_location: path for the full network's ``state_dict``.
        hyper_save_location: path for the hypernetwork's ``state_dict``.

    Returns:
        ``None`` when ``params_only``. Otherwise it returns the learning rate
        for a subsequent run: ``l_r * 0.95`` when ``weight_decay == 0``, and
        ``l_r`` otherwise.
    """
    # Reproducibility: cudnn.benchmark auto-tunes and may pick different
    # algorithms between runs, which defeats cudnn.deterministic.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    hyper_choice = 'simple'
    full_type = 'one'  # alternative mode used elsewhere: 'batch'
    out_class = 10
    inter_dim = 7

    print("Using 13-layer CNN, batch of full networks, and using easier ppg.")
    shell = Simple_Shell(num_classes=out_class)

    parameters, _, _ = ShellParser(shell, force_stride=False)
    linear_info = ShellParser(shell, 2)
    print("Linear info:", linear_info)
    total_params = sum(parameters)
    print("Total number of parameters: ", total_params)

    hyper = _build_hypernetwork(shell, device, lora=lora, head=head,
                                shared_layer=shared_layer, inter_dim=inter_dim,
                                hyper_choice=hyper_choice, shared_choice=shared_choice)
    hyper_total_params = sum(p.numel() for p in hyper.parameters() if p.requires_grad)
    print("Hypernetwork trainable parameters: ", hyper_total_params)
    print("Hypernetwork minus shell parameters: ", hyper_total_params - total_params)

    if params_only:
        return

    learning_rate = l_r
    network = FunctionalFullNetwork(hyper, shell, n, mode=full_type, head=head,
                                    reflection=False).to(device)
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate,
                                 weight_decay=weight_decay)
    current_time = datetime.datetime.now().strftime("%m-%d %H:%M:%S")
    print("Starting Time:", current_time)
    criterion = nn.CrossEntropyLoss()
    train_loader, test_loader = _cifar10_loaders(batch_size=32)

    bar = tqdm.tqdm(range(epochs))
    for _ in bar:
        avg_loss = _train_one_epoch(network, train_loader, optimizer, criterion, device)
        accuracy = _evaluate(network, test_loader, device)
        bar.set_description(
            "Loss: {:.4f}, Test Acc: {:.2f}%".format(avg_loss, accuracy))

    for path in (full_save_location, hyper_save_location):
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
    torch.save(network.state_dict(), full_save_location)
    torch.save(network.hypernetwork.state_dict(), hyper_save_location)
    print("saved")
    # NOTE(review): the returned rate appears intended for a caller that re-runs
    # ablation ("do it again"); it is only decayed when weight decay is off.
    if weight_decay == 0:
        return learning_rate * 0.95
    return learning_rate


if __name__ == "__main__":
    # Script entry point. With ablation()'s default params_only=True, no
    # training is performed.
    ablation()
