from torchsummary import summary
from cifar100_net import tcnn4
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime as dt

import torch
from torch import optim, nn
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torchvision.utils import make_grid
from torchvision import transforms as T
from torchvision import models, datasets

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss, Precision, Recall
from ignite.handlers import LRScheduler, ModelCheckpoint, global_step_from_engine
from ignite.contrib.handlers import ProgressBar, TensorboardLogger
import ignite.contrib.engines.common as common

import os
from random import randint

class Net(nn.Module):
    """Thin wrapper holding the backbone trained on CIFAR-100.

    The active backbone is ``tcnn4`` from ``cifar100_net``; ``forward`` simply
    delegates to it. To experiment with another architecture, assign it to
    ``self.model`` instead (and replace its classification head).
    """

    # Alternatives tried previously (replace ``self.model`` and its head):
    #   tcnn / tcnn2 / tcnn3 (cifar100_net; not imported in this file)
    #   models.resnet18()           -> .fc = nn.Linear(512, 100)             (~11.3M)
    #   models.resnet34()           -> .fc = nn.Linear(512, 100)             (~21.4M)
    #   models.resnext50_32x4d()    -> .fc = nn.Linear(2048, 100)
    #   models.regnet_y_800mf()     -> .fc = nn.Linear(784, 100)
    #   models.mobilenet_v2()       -> .classifier[1] = nn.Linear(1280, 100) (~2.5M)
    #   models.mobilenet_v3_large() -> .classifier[3] = nn.Linear(1280, 100) (~4.5M)
    #   models.mobilenet_v3_small() -> .classifier[3] = nn.Linear(1024, 100) (~2.5M)
    #   models.efficientnet_b0/b1() -> .classifier[1] = nn.Linear(1280, 100) (~4.3M / ~6.7M)
    #   models.efficientnet_b3()    -> .classifier[1] = nn.Linear(1536, 100)
    #   models.convnext_small()     -> .classifier[2] = nn.Linear(768, 100)  (~49.6M)
    #   models.shufflenet_v2_x0_5() -> .fc = nn.Linear(1024, 100)
    #   models.squeezenet1_0()      -> .classifier[1] = nn.Conv2d(512, 100, kernel_size=1)

    def __init__(self, device):
        super().__init__()
        # ``device`` is forwarded to the backbone's constructor.
        backbone = tcnn4(device)
        self.model = backbone

    def forward(self, x):
        """Return the wrapped backbone's output for input batch ``x``."""
        return self.model(x)

def generate_dataloader(data, name, transform, batch_size=None, use_cuda=None,
                        num_workers=8):
    """Build a DataLoader over an ``ImageFolder`` directory tree.

    Args:
        data: Root directory laid out as ``root/<label>/<filename>``, or
            ``None`` to skip building a loader.
        name: Split name; only the ``"train"`` split is shuffled.
        transform: Per-image transform; ``None`` means ``T.ToTensor()`` only.
        batch_size: Samples per batch. ``None`` (default) falls back to the
            module-level ``batch_size``, matching the original behaviour.
        use_cuda: Enables ``pin_memory`` and worker processes when true.
            ``None`` (default) falls back to the module-level ``use_cuda``.
        num_workers: Number of loader worker processes used when
            ``use_cuda`` is true.

    Returns:
        A ``DataLoader``, or ``None`` when ``data`` is ``None``.
    """
    if data is None:
        return None

    # The module-level settings are defined further down the file, after this
    # function, so they can only be resolved at call time.
    module_globals = globals()
    if batch_size is None:
        batch_size = module_globals["batch_size"]
    if use_cuda is None:
        use_cuda = module_globals["use_cuda"]

    # ImageFolder reads images laid out as root/label/filename.
    # See https://pytorch.org/vision/stable/datasets.html
    dataset = datasets.ImageFolder(
        data, transform=T.ToTensor() if transform is None else transform
    )

    # Pinned memory and worker processes are only enabled for CUDA runs.
    loader_kwargs = (
        {"pin_memory": True, "num_workers": num_workers} if use_cuda else {}
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(name == "train"),
        **loader_kwargs,
    )

use_cuda = torch.cuda.is_available()
# Training targets GPU index 7 on a multi-GPU host; on machines with fewer
# devices fall back to the first GPU instead of failing later when the model
# is moved to a non-existent device.
if use_cuda:
    device = torch.device("cuda:7" if torch.cuda.device_count() > 7 else "cuda:0")
else:
    device = torch.device("cpu")

# Per-channel normalisation statistics shared by both pipelines.
# NOTE(review): these are the ImageNet mean/std, not statistics computed on
# CIFAR-100; kept unchanged here to preserve training behaviour.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize so the shorter side is 42 px, take a random
# 38x38 crop and a random horizontal flip (augmentation), then convert to a
# tensor and normalise.
preprocess_transform_pretrain = T.Compose([
    T.Resize(42),
    T.RandomCrop(38),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])
# Evaluation pipeline: same resize, but a deterministic 38x38 center crop
# and no flip, so validation inputs are not randomised.
preprocess_transform_test = T.Compose([
    T.Resize(42),
    T.CenterCrop(38),
    T.ToTensor(),
    T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Samples per mini-batch, used by both loaders below.
batch_size = 32

# CIFAR-100 is read from ../data and must already be present (download=False).
# The training split is shuffled and uses the augmenting transform; the test
# split serves as the validation set with the deterministic transform.
# (For ImageFolder-style directories, generate_dataloader can be used instead.)
trainset = datasets.CIFAR100(
    root='../data',
    train=True,
    download=False,
    transform=preprocess_transform_pretrain,
)
train_loader_pretrain = DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
testset = datasets.CIFAR100(
    root='../data',
    train=False,
    download=False,
    transform=preprocess_transform_test,
)
val_loader_pretrain = DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

# Build the network and move its parameters to the selected device
# (nn.Module.to returns the module itself).
model = Net(device).to(device)

# Training hyperparameters.
lr = 0.001          # Adam learning rate
num_epochs = 128    # passes over the training set
log_interval = 300  # iterations between batch-loss log lines

# Multi-class classification objective and its optimiser.
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)


# Ignite engine that performs the supervised training updates on `device`.
trainer = create_supervised_trainer(model, optimizer, loss_func, device=device)

# Persistent progress bar showing each iteration's output as "Batch Loss".
ProgressBar(persist=True).attach(
    trainer,
    output_transform=lambda batch_loss: {"Batch Loss": batch_loss},
)

# Metrics reported by both evaluators.
metrics = dict(
    accuracy=Accuracy(),
    loss=Loss(loss_func),
)

# Two evaluators over the same model: one for the training loader, one for
# the validation loader, each keeping results in its own engine state.
train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

# Announce the run once, when the trainer starts.
@trainer.on(Events.STARTED)
def start_message():
    """Print a banner at the start of training."""
    banner = "Begin training"
    print(banner)

# Every `log_interval` iterations, report the most recent batch loss.
@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_batch(trainer):
    """Print the epoch, the batch index within that epoch and the batch loss."""
    state = trainer.state
    # `iteration` counts across all epochs; convert it to a 1-based index
    # within the current epoch.
    batch_in_epoch = (state.iteration - 1) % state.epoch_length + 1
    print(
        f"Epoch {state.epoch} / {num_epochs}, "
        f"Batch {batch_in_epoch} / {state.epoch_length}: "
        f"Loss: {state.output:.3f}"
    )

# After every epoch, report the last batch loss and full training-set metrics.
@trainer.on(Events.EPOCH_COMPLETED)
def log_training_loss(trainer):
    """Print the epoch's final batch loss, then evaluate on the train loader.

    NOTE(review): ``train_loader_pretrain`` uses the augmenting transform
    (random crop / flip), so these metrics are measured on augmented images.
    """
    print(f"Epoch [{trainer.state.epoch}] - Loss: {trainer.state.output:.2f}")
    train_evaluator.run(train_loader_pretrain)
    # Distinct name so the module-level `metrics` dict is not shadowed.
    train_metrics = train_evaluator.state.metrics
    print(f"Train - Loss: {train_metrics['loss']:.3f}, "
          f"Accuracy: {train_metrics['accuracy']:.3f} ")

# After every epoch, evaluate on the validation loader and report its metrics.
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_loss(trainer):
    """Run the validation evaluator and print its loss and accuracy."""
    evaluator.run(val_loader_pretrain)
    # Distinct name so the module-level `metrics` dict is not shadowed.
    val_metrics = evaluator.state.metrics
    print(f"Validation - Loss: {val_metrics['loss']:.3f}, "
          f"Accuracy: {val_metrics['accuracy']:.3f}")

# Keep the single best checkpoint, ranked by the validation evaluator's
# accuracy, under best_models/; the tag prefixes the saved score name.
common.save_best_model_by_val_score(
    output_path="best_models",
    evaluator=evaluator,
    model=model,
    metric_name="accuracy",
    n_saved=1,
    trainer=trainer,
    # Must name the backbone Net actually builds (tcnn4); the previous
    # "cifar100_tcnn2" tag mislabelled these checkpoints.
    tag="cifar100_tcnn4",
)

trainer.run(train_loader_pretrain, max_epochs=num_epochs)
# Metrics from the last validation run (end of the final epoch).
print(evaluator.state.metrics)




#summary(net, (3, 64, 64), batch_size=32)
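# Approximate parameter counts of torchvision reference models
# (default 1000-class ImageNet heads, so larger than the 100-class variants above):
# ResNet-18       11M
# MobileNetV2     3.5M
# EfficientNet-B0 5M
# RegNetY-400MF   4.3M
# ConvNeXt-Tiny   28M
# VGG-16          138M
# AlexNet         61M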


