import os
import datetime
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
import sys
import time

if __name__=="__main__":
    # Script mode: extend sys.path so project modules can be imported without
    # installing the package. `path` is presumably the third-party `path`
    # library (it provides path.Path(...).abspath()) — not the stdlib os.path.
    import path
    # Parent of the directory that contains this file.
    folder_path= (path.Path(__file__).abspath()).parent.parent
    sys.path.append(folder_path)
    print(folder_path)
    # Also add the directory one level above that.
    folder_path = folder_path.parent
    sys.path.append(folder_path)
    # NOTE(review): presumably resolved from this file's own directory, which
    # Python places on sys.path when a file is run as a script — verify layout.
    from classifier_base import Classifier
    # NOTE(review): the self-test at the bottom of this file uses MNIST and
    # torchvision without importing them explicitly; presumably this star
    # import supplies them — confirm.
    from data.pytorch_datasets import *
else: 
    # Imported as a library: resolve the base class through the `models` package.
    from models.classifier_base import Classifier
    # from classifier_base import Classifier

# Module-wide default device: the first CUDA device when one is visible,
# otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class NN_MNIST(nn.Module, Classifier):
    """Convolutional MNIST classifier implementing the project ``Classifier`` API.

    Architecture: two blocks of two 5x5 'same'-padded convolutions (32 then 64
    channels, ReLU after each) each followed by 2x2 max-pooling, then a
    512-unit hidden layer and a 10-way linear output producing raw logits.
    ``nn.Linear(3136, 512)`` fixes the expected input to 1 x 28 x 28 images
    (64 * 7 * 7 = 3136 after two poolings).
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            
            nn.Conv2d(1, 32, (5,5), padding='same'),        # 28 x 28 x 32
            nn.ReLU(),
            nn.Conv2d(32, 32, (5,5), padding='same'),       # 28 x 28 x 32
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2,2), stride=2),      # 14 x 14 x 32
            nn.Conv2d(32, 64, (5,5), padding='same'),       # 14 x 14 x 64
            nn.ReLU(),
            nn.Conv2d(64, 64, (5,5), padding='same'),       # 14 x 14 x 64
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2,2), stride=2),      #  7 x  7 x 64  
            nn.Flatten(),                                   # 3136
            nn.Linear(3136, 512),                           # 512
            nn.ReLU(),
            nn.Linear(512, 10)                              # 10             
        )
        self.loss_fn = nn.CrossEntropyLoss()
        self.logs = {}

    def forward(self, x):
        """Return the raw (unnormalised) class logits for the batch ``x``."""
        return self.model(x)

    def zero_grad(self, *args, **kwargs):
        """Clear the gradients of the wrapped network.

        Positional/keyword arguments (e.g. ``set_to_none``) are forwarded to
        ``nn.Module.zero_grad`` so this override stays call-compatible with
        the base-class signature.
        """
        return self.model.zero_grad(*args, **kwargs)
    
    def train_model(self, dataset, args=None, val_ds=None):
        """Train the network on ``dataset``, checkpointing after every epoch.

        Args:
            dataset: training dataset yielding ``(images, labels)`` pairs.
            args: namespace providing ``lr``, ``base_batch_size``,
                ``base_weight_decay``, ``base_epochs``, ``base_momentum``,
                ``device``, ``base_optimizer`` (``"SGD"`` or ``"Adam"``) and
                ``save_dir_base_model``. Required despite the ``None`` default.
            val_ds: optional validation dataset. When given and non-empty,
                validation accuracy (in percent) is printed after each
                logged epoch.

        Raises:
            ValueError: if ``args`` is ``None`` or ``args.base_optimizer`` is
                not ``"SGD"`` or ``"Adam"``.
        """
        if args is None:
            raise ValueError("train_model requires an args namespace with training hyper-parameters")

        self.model.train()
        lr = args.lr
        batch_size = args.base_batch_size
        weight_decay = args.base_weight_decay
        epochs = args.base_epochs
        momentum = args.base_momentum
        log_interval = 1
        save_freq = 1
        device = args.device

        print(f"Batch size: {batch_size}, lr:{lr}, momentum:{momentum}, epochs:{epochs}, wt. decay: {weight_decay}")

        os.makedirs(args.save_dir_base_model, exist_ok=True)

        # The samplers draw their permutation / seed on the *default* tensor
        # device, and PyTorch requires the generator to live on that same
        # device. Building it from the default device works both for plain
        # CPU defaults and when a CUDA default tensor type/device is set.
        generator = torch.Generator(torch.empty(0).device)

        dl = DataLoader(dataset, batch_size=batch_size, shuffle=True, generator=generator)
        val_dl = None
        if val_ds is not None and len(val_ds) > 0:
            val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, generator=generator)

        self.optimizer = self._build_optimizer(args.base_optimizer, lr, momentum, weight_decay)
        # Recorded in every checkpoint. SGD's step schedule overwrites it each
        # epoch; for other optimizers the rate stays at ``lr``.
        self.curr_lr = lr

        for epoch in range(1, epochs + 1):
            start = time.perf_counter()
            if args.base_optimizer == "SGD":
                self.adjust_learning_rate(self.optimizer, epoch, lr)
            print(f"Epoch {epoch} of {epochs}...")

            track_loss = 0.0

            for idx, (images, labels) in enumerate(dl):
                if (idx + 1) % 50 == 0:
                    print(f"\t Batch {idx+1} of {len(dl)} done")
                images = images.to(device)
                labels = labels.to(device)

                self.optimizer.zero_grad()
                output = self.model(images)
                loss = F.cross_entropy(output, labels)
                loss.backward()
                self.optimizer.step()

                # Weight by batch size so the epoch figure is a per-example
                # mean even when the last batch is smaller.
                track_loss += loss.item() * images.size(0)

            track_loss /= len(dataset)
            end = time.perf_counter()
            if epoch % log_interval == 0:
                print(f"\tEpoch {epoch}: Loss is {track_loss}, time taken is {round(end-start, 3)}")
                if val_dl is not None:
                    accuracy = self._validation_accuracy(val_dl, len(val_ds), device)
                    print(f"\tAccuracy on validation set is {accuracy}")

            if epoch % save_freq == 0:
                self._save_epoch_checkpoint(args.save_dir_base_model, epoch)

    def _build_optimizer(self, name, lr, momentum, weight_decay):
        """Create the optimizer selected by ``name`` for the wrapped network."""
        if name == "SGD":
            return optim.SGD(self.model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
        if name == "Adam":
            # Adam is configured with the learning rate only; momentum and
            # weight decay are not applied.
            return optim.Adam(self.model.parameters(), lr=lr)
        raise ValueError(f"Unsupported base_optimizer {name!r}; expected 'SGD' or 'Adam'")

    def _validation_accuracy(self, val_dl, n_examples, device):
        """Return exact top-1 accuracy (percent) over ``val_dl``, restoring the train/eval mode."""
        was_training = self.model.training
        self.model.eval()
        n_correct = 0
        with torch.no_grad():
            for x, y in val_dl:
                x = x.to(device)
                y = y.to(device)
                n_correct += (self.model(x).argmax(dim=1) == y).sum().item()
        self.model.train(was_training)
        return 100.0 * n_correct / n_examples

    def _save_epoch_checkpoint(self, save_dir, epoch):
        """Save epoch number, optimizer/model state and current lr to ``save_dir``."""
        checkpoint = {
            'epoch' : epoch,
            'optimizer' : self.optimizer.state_dict(),
            'state_dict' : self.model.state_dict(),
            'lr' : self.curr_lr
        }
        save_path = os.path.join(save_dir, f"model-{epoch}-checkpoint")
        print(f"\tSaving model checkpoint to {save_path}")
        torch.save(checkpoint, save_path)
    
    def adjust_learning_rate(self, optimizer, epoch, base_lr):
        """Apply a step-decay schedule to every param group of ``optimizer``.

        The rate is ``base_lr`` before epoch 55, then ``base_lr`` x 0.1,
        x 0.01 and x 0.001 from epochs 55, 75 and 90 respectively. The chosen
        value is also stored in ``self.curr_lr``.
        """
        lr = base_lr
        for milestone, factor in ((55, 0.1), (75, 0.01), (90, 0.001)):
            if epoch >= milestone:
                lr = base_lr * factor
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        self.curr_lr = lr

    def get_current_lr(self):
        """Return the learning rate last recorded by training.

        Raises ``AttributeError`` if neither ``train_model`` nor
        ``adjust_learning_rate`` has run yet.
        """
        return self.curr_lr
        
    def evaluate(self, images, labels):
        """Score one batch without building an autograd graph.

        Inputs are moved to the device holding the model's parameters. The
        model is left in eval mode.

        Returns:
            dict with ``"accuracy"`` (top-1, percent, rounded to 2 decimals)
            and ``"loss"`` (mean cross-entropy as a Python float).
        """
        self.model.eval()
        model_device = next(self.model.parameters()).device
        images = images.to(model_device)
        labels = labels.to(model_device)
        with torch.no_grad():
            outputs = self.model(images)
            val_loss = self.loss_fn(outputs, labels)
        _, predicted = torch.max(outputs, 1)
        
        n_examples = labels.size(0)
        n_correct = (predicted==labels).sum().item()
        
        return {"accuracy":round((n_correct*100.0/n_examples),2), "loss":val_loss.item()}

    def train_step(self, train_subset, rem_subset, base_classifier, adversary, rho, lr, weight_decay, ts_batch_size):
        """Not supported by this model: prints a warning and returns -1."""
        print("Should not be called!!!!!!!!!!!!!!!!!")
        return -1

    
    def get_loss(self, images, labels, requires_mean=True):
        """Return the cross-entropy of the model's predictions on ``images``.

        Gradients are tracked. With ``requires_mean`` the batch mean is
        returned; otherwise one loss value per example.
        """
        output = self.model(images)
        reduction = "mean" if requires_mean else "none"
        return F.cross_entropy(output, labels, reduction=reduction)

    def train_with_adv(self, dataset, train_args=None):
        """Adversarial training is not implemented for this model."""
        raise NotImplementedError()

    def get_regularizer(self):
        """Return the regularisation term; currently disabled and always 1."""
        return 1
    
    def load_model(self, path):
        """Replace the wrapped network with the object pickled at ``path``.

        NOTE(review): ``torch.load`` unpickles arbitrary objects — only load
        trusted files. Since this expects a whole pickled module (see
        ``save_model``), PyTorch versions whose ``torch.load`` defaults to
        ``weights_only=True`` will need ``weights_only=False`` here.
        """
        self.model = torch.load(path)
    
    def save_model(self, path):
        """Pickle the whole wrapped network (not just its state dict) to ``path``."""
        torch.save(self.model, path)
        print("Saving model at:", path)

if __name__=="__main__":
    # Smoke test: train briefly on MNIST and report validation accuracy.
    import types

    import torchvision

    print("Testing MNIST!")
    to_tensor = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    # NOTE(review): MNIST presumably comes from the data.pytorch_datasets
    # star import at the top of the file — confirm.
    ds_train = MNIST('data', download=True, train=True, transform=to_tensor)
    ds_test = MNIST('data', download=True, train=False, transform=to_tensor)
    net = NN_MNIST()
    net.to(device)
    # train_model reads every attribute below, so passing None crashes.
    # These are smoke-test settings, not tuned hyper-parameters.
    smoke_args = types.SimpleNamespace(
        lr=0.01,
        base_batch_size=128,
        base_weight_decay=5e-4,
        base_epochs=1,
        base_momentum=0.9,
        base_optimizer="SGD",
        device=device,
        save_dir_base_model=os.path.join("checkpoints", "nn_mnist_smoke_test"),
    )
    net.train_model(ds_train, smoke_args, val_ds=ds_test)
