import os
import torch
from tqdm import tqdm
import pandas as pd
from train import train, train_mixup
from test import test
from data_utils import get_mnist_loaders, load_usps

import torch
from torch import nn
import numpy as np
from layers import CLOPLayer
import torchvision.models as models

class MNISTClassifier(nn.Module):
    """Small CNN digit classifier with a single ``CLOPLayer`` at a chosen depth.

    The convolutional trunk is three conv+ReLU stages (32, 64 and 128 output
    channels) with a 2x2 max-pool after each of the first two stages.  One
    ``CLOPLayer(p)`` is inserted according to ``layer``:

    * ``layer == 1``: right after the first conv+ReLU (before the first pool);
    * ``layer == 2``: right after the second conv+ReLU (before the second pool);
    * any other value (default ``3``): after the third conv+ReLU.

    Apart from the CLOP position, all variants share exactly the same
    architecture, so they can be compared against each other.

    Args:
        img_size: Input shape without the batch dimension, e.g. ``(1, 32, 32)``.
            It is used to infer the flattened feature size of the conv trunk.
            The first conv expects 1 input channel.
        layer: CLOP insertion position (see above).
        p: Forwarded unchanged to ``CLOPLayer``. Its meaning is defined in
            ``layers.CLOPLayer``; presumably it is an application
            probability — confirm there.

    ``forward`` returns log-probabilities over 10 classes (``LogSoftmax`` over
    dim 1).
    """

    def __init__(self, img_size, layer=3, p=0):
        super().__init__()

        self.layer = layer

        # Build the trunk in order, appending the CLOP layer at the requested
        # position. Module order (and therefore state_dict keys) for layers 1
        # and 2 matches the previous hand-written Sequential definitions.
        modules = [
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2),
            nn.ReLU(True),
        ]
        if layer == 1:
            modules.append(CLOPLayer(p))
        modules += [
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            # This ReLU was missing from the default (layer 3) variant, making
            # it architecturally different from the layer 1/2 variants.
            nn.ReLU(True),
        ]
        if layer == 2:
            modules.append(CLOPLayer(p))
        modules += [
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(True),
        ]
        if layer not in (1, 2):
            modules.append(CLOPLayer(p))
        self.conv_feat = nn.Sequential(*modules)

        # Infer the trunk's output shape with a dummy forward pass; no_grad
        # avoids building an autograd graph for this one-off probe.
        with torch.no_grad():
            self.conv_feat_size = self.conv_feat(torch.zeros(1, *img_size)).shape[1:]
        # Plain int (not numpy.int64) so nn.Linear gets the type it documents.
        self.dense_feature_size = int(np.prod(self.conv_feat_size))

        self.classifier = nn.Sequential(
            nn.Linear(in_features=self.dense_feature_size, out_features=512),
            nn.ReLU(True),
            nn.Linear(in_features=512, out_features=100),
            nn.ReLU(True),
            nn.Linear(in_features=100, out_features=10),
            # Explicit dim: the implicit-dim form is deprecated; for 2-D
            # (batch, classes) input the implicit choice was dim=1 anyway.
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        """Return class log-probabilities for the batch of images ``x``."""
        x = self.conv_feat(x)
        # reshape has the same (-1, D) semantics as view but tolerates
        # non-contiguous tensors.
        x = x.reshape(-1, self.dense_feature_size)
        return self.classifier(x)




def _train_and_evaluate(layer, train_loader, eval_loader, *, p=0.8,
                        learning_rate=5e-4, epochs=30):
    """Train a fresh MNISTClassifier with CLOP at ``layer`` and evaluate it.

    Uses Adam with cosine-annealed learning rate over ``epochs``. Returns
    ``(train_accuracy, test_accuracy)`` exactly as reported by the project's
    ``train`` and ``test`` helpers.
    """
    model = MNISTClassifier(img_size=(1, 32, 32), layer=layer, p=p).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    train_accuracy = train(model, optimizer, train_loader, epochs=epochs,
                           scheduler=scheduler)
    test_accuracy = test(model=model, testloader=eval_loader, m_eval=True)
    return train_accuracy, test_accuracy


# The guard matters: DataLoader workers (num_workers > 0) re-import the main
# module on spawn-based platforms, which would otherwise re-run the experiment.
if __name__ == "__main__":
    mnist_loader, _ = get_mnist_loaders(batch_size=128, shuffle=True, num_workers=8)
    uspsloader = load_usps(batch_size=128, shuffle=True, num_workers=8)

    learning_rate = 5e-4
    epochs = 30
    nb_runs = 10
    clop_p = 0.8
    # (CLOP insertion layer, label written to the 'regul' column).
    positions = [(1, 'layer1'), (2, 'layer2')]
    columns = ['regul', 'train_accuracy', 'test_accuracy', 'run_id']

    outdir = './results/supervised'
    os.makedirs(outdir, exist_ok=True)
    path = os.path.join(outdir, 'mnist_position.csv')

    # Collect plain dict rows: DataFrame.append was removed in pandas 2.0.
    rows = []
    for i in tqdm(range(nb_runs), desc='run_loop', leave=False):
        for layer, regul in positions:
            train_accuracy, test_accuracy = _train_and_evaluate(
                layer, mnist_loader, uspsloader,
                p=clop_p, learning_rate=learning_rate, epochs=epochs,
            )
            rows.append({"run_id": i,
                         "regul": regul,
                         "test_accuracy": test_accuracy,
                         "train_accuracy": train_accuracy})

        # Rewrite the CSV after every run so partial results survive a crash.
        results = pd.DataFrame(rows, columns=columns)
        results.to_csv(path)