import os
import torch
from tqdm import tqdm
import pandas as pd
from models import VGG9
from train import train, train_mixup
from test import test
from data_utils import get_imagenette2_loaders


import torchvision
import torchvision.transforms as transforms
import torch
import os
from torch.utils.data import Dataset
from PIL import Image
from scipy.io import loadmat

from torch import nn
import numpy as np
from layers import CLOPLayer
import torchvision.models as models

class VGG9(nn.Module):
    """VGG11 backbone truncated before its last convolutional stage.

    The feature extractor is torchvision's VGG11 (optionally the batch-norm
    variant) cut right after the second-to-last max-pooling layer, followed
    by a three-layer fully connected classifier with 10 outputs.

    Args:
        regul: Regularisation variant. ``"batch_norm"`` selects the
            ``vgg11_bn`` backbone, ``"clop"`` applies ``CLOPLayer`` to the
            feature maps, ``"dropout"`` applies ``nn.Dropout2d`` to the
            feature maps; any other value (including ``None``) applies no
            extra regularisation.
        p: Value forwarded to both ``nn.Dropout2d`` and ``CLOPLayer``.
    """

    def __init__(self, regul=None, p=0.7):
        super().__init__()
        self.regul = regul
        # Calling the constructors without arguments gives randomly
        # initialised weights on every torchvision release and avoids the
        # deprecated ``pretrained`` keyword.
        if self.regul == "batch_norm":
            vgg = models.vgg11_bn()
        else:
            vgg = models.vgg11()

        # Keep everything up to and including the second-to-last pooling
        # layer.  A fixed ``[:-5]`` slice is only right for the plain VGG11
        # (conv, relu, conv, relu, pool); in the batch-norm variant every
        # conv is followed by BatchNorm + ReLU, so the same slice would leave
        # an extra conv + BatchNorm without activation in the backbone.
        pool_indices = [
            idx
            for idx, layer in enumerate(vgg.features)
            if isinstance(layer, nn.MaxPool2d)
        ]
        self.features = vgg.features[: pool_indices[-2] + 1]

        # in_features=18432 presumably corresponds to 512 x 6 x 6 feature
        # maps, i.e. 96x96 inputs such as STL10 -- TODO confirm if the input
        # resolution ever changes.
        self.classifier = nn.Sequential(
            nn.Linear(in_features=18432, out_features=4096, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=4096, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=10, bias=True),
        )
        self.dropout = nn.Dropout2d(p)
        self.clop = CLOPLayer(p)
        # Not used in ``forward``; kept so the module's attributes and
        # ``state_dict`` keys stay the same as before.
        self.batchnorm = nn.BatchNorm2d(512)

    def forward(self, x):
        """Return per-class log-probabilities for the input batch ``x``."""
        x = self.features(x)
        # The optional regulariser acts on the feature maps, before they are
        # flattened for the classifier.
        if self.regul == "clop":
            x = self.clop(x)
        if self.regul == "dropout":
            x = self.dropout(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        y = torch.log_softmax(x, 1)
        return y



# Per-channel normalisation constants shared by both pipelines (the values
# coincide with the commonly used ImageNet channel statistics).
means = (0.485, 0.456, 0.406)
stds = (0.229, 0.224, 0.225)

# Training pipeline: random horizontal flip as augmentation, then tensor
# conversion and normalisation.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(means, stds),
    ]
)

# Evaluation pipeline: same preprocessing without the random flip.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(means, stds),
    ]
)


def get_stl10_loaders(root_path='./data', **kwargs):
    """Build the STL10 training and test data loaders.

    Args:
        root_path: Root directory handed to ``torchvision.datasets.STL10``
            (the datasets are requested with ``download=True``).
        **kwargs: Forwarded unchanged to *both* ``DataLoader`` constructors,
            so options such as ``shuffle`` apply to the test loader as well.

    Returns:
        A ``(trainloader, testloader)`` tuple.
    """
    split_specs = (('train', train_transform), ('test', test_transform))
    # Instantiate both datasets first, then wrap each one in a DataLoader.
    datasets = [
        torchvision.datasets.STL10(
            root=root_path, split=split_name, download=True, transform=split_transform
        )
        for split_name, split_transform in split_specs
    ]
    train_loader, test_loader = [
        torch.utils.data.DataLoader(dataset, **kwargs) for dataset in datasets
    ]
    return train_loader, test_loader



# Experiment driver: for each run, train and evaluate every regularisation
# variant on STL10 and write the accumulated results to a CSV file.
#
# NOTE(review): everything below executes at import time.  With
# ``num_workers > 0`` on platforms that use the "spawn" start method,
# DataLoader workers re-import the main module, which would require an
# ``if __name__ == "__main__":`` guard -- confirm the target platform.
trainloader, testloader = get_stl10_loaders(batch_size=64, shuffle=True, num_workers=8)

learning_rate = 5e-4
epochs = 90
nb_runs = 5
result_columns = ['regul', 'train_accuracy', 'test_accuracy', 'run_id', "alpha"]
results = pd.DataFrame(columns=result_columns)


def _fit_and_score(train_fn, model_kwargs, **train_kwargs):
    """Train a fresh ``VGG9`` with ``train_fn`` and evaluate it.

    Args:
        train_fn: Training routine (``train`` or ``train_mixup``), called as
            ``train_fn(model, optimizer, trainloader, epochs=..., scheduler=..., **train_kwargs)``.
        model_kwargs: Keyword arguments for the ``VGG9`` constructor.
        **train_kwargs: Extra keyword arguments for ``train_fn``
            (e.g. ``alpha`` for mixup).

    Returns:
        ``(train_accuracy, test_accuracy)`` as returned by ``train_fn`` and
        ``test`` respectively.
    """
    model = VGG9(**model_kwargs).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    train_accuracy = train_fn(
        model, optimizer, trainloader, epochs=epochs, scheduler=scheduler, **train_kwargs
    )
    test_accuracy = test(model=model, testloader=testloader)
    return train_accuracy, test_accuracy


# One entry per experiment, in execution order:
# (label stored in the "regul" column, VGG9 kwargs, training routine,
#  extra kwargs for the training routine).
experiments = [
    ('No_regul', {}, train, {}),
    ('batch_norm', {'regul': 'batch_norm', 'p': 0.6}, train, {}),
    ('dropout', {'regul': 'dropout', 'p': 0.6}, train, {}),
    ('clop', {'regul': 'clop', 'p': 0.6}, train, {}),
]
experiments += [
    ('mixup', {}, train_mixup, {'alpha': alpha}) for alpha in (0.2, 0.4, 0.6, 0.8, 1)
]

# Rows are collected in a plain list: ``DataFrame.append`` was removed in
# pandas 2.0, and rebuilding the frame from a list is also cheaper than
# appending row by row.
result_rows = []

for i in tqdm(range(nb_runs), desc='run_loop', leave=False):
    for regul_name, model_kwargs, train_fn, train_kwargs in experiments:
        train_accuracy, test_accuracy = _fit_and_score(train_fn, model_kwargs, **train_kwargs)
        row = {
            "run_id": i,
            "regul": regul_name,
            "test_accuracy": test_accuracy,
            "train_accuracy": train_accuracy,
        }
        # Only mixup experiments carry an alpha; other rows leave the column empty.
        if "alpha" in train_kwargs:
            row["alpha"] = train_kwargs["alpha"]
        result_rows.append(row)

    # Persist after every run so completed runs survive an interrupted sweep.
    results = pd.DataFrame(result_rows, columns=result_columns)
    outdir = './results/supervised'
    os.makedirs(outdir, exist_ok=True)
    path = os.path.join(outdir, 'stl10.csv')
    results.to_csv(path)