import os
from pathlib import Path
from typing import List, Optional, Tuple

import torch
from torch import nn, optim, Tensor
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

import datasets
from datasets import get_dataset
from models_zoo import resnet34, resnet18, resnet50, resnet101
from utils import MODELS_DIR


class NNModelManager:
    """Train, evaluate, save and load a classification model.

    The manager owns a model together with its optimizer and the bookkeeping
    needed to store checkpoints under
    ``MODELS_DIR/<dataset_name>/<model_name>/<train_mode>/model_<save_name>.pth``.

    Attributes are plain copies of the constructor arguments. ``save_name`` may
    additionally be set to an integer index by :meth:`model_path` when it was
    left as ``None``.
    """

    def __init__(self, dataset_name: str, model,
                 model_name: str,
                 save_name: Optional[str] = None, device: Optional[str] = None,
                 num_epochs: int = 50, batch: int = 512, opt=None,
                 ):
        """Store the configuration and create a default optimizer if needed.

        Args:
            dataset_name: Name of the dataset; used as a checkpoint sub-directory.
            model: The model to manage.
            model_name: Name of the architecture; used as a checkpoint sub-directory.
            save_name: Suffix of the checkpoint file. When ``None`` the first
                unused integer index is chosen lazily by :meth:`model_path`.
            device: Device passed to ``.to(...)`` for the model and the batches.
            num_epochs: Number of epochs run by the training methods.
            batch: Batch size of the data loaders built by the training methods.
            opt: Optimizer. Defaults to SGD with ``lr=0.1`` and
                ``weight_decay=0.0001`` over the model parameters.
        """
        self.model = model
        self.num_epochs = num_epochs
        self.device = device
        self.batch = batch
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.save_name = save_name
        if opt is None:
            opt = optim.SGD(self.model.parameters(), lr=0.1, weight_decay=0.0001)

        self.opt = opt

    def mimic_evaluate(self, dataloader: DataLoader, origin_model) -> float:
        """Return the fraction of samples on which both models predict the same class.

        Both models are switched to evaluation mode. Only the inputs are moved
        to ``self.device``; ``origin_model`` is used as given, so the caller is
        expected to have placed it on a compatible device.

        Args:
            dataloader: Loader yielding ``(inputs, labels)`` batches. The labels
                are only used to count samples.
            origin_model: Reference model whose predictions are compared against.

        Returns:
            Agreement rate in ``[0, 1]``.

        Raises:
            ZeroDivisionError: If the loader yields no samples.
        """
        self.model.eval()
        origin_model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in tqdm(dataloader):
                inputs = inputs.to(self.device)
                predicted = self.model(inputs).argmax(dim=1)
                predicted_origin = origin_model(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == predicted_origin).sum().item()
        return correct / total

    def create_labels(self, adversarial_images_for_train: Tensor) -> Tensor:
        """Return the managed model's predicted class index for every image.

        The model is switched to evaluation mode and the batch is moved to
        ``self.device`` before the forward pass.
        """
        self.model.eval()
        with torch.no_grad():
            adversarial_images_for_train = adversarial_images_for_train.to(self.device)
            outputs = self.model(adversarial_images_for_train)
        return outputs.argmax(dim=1)

    def evaluate_model(self, dataloader: DataLoader) -> float:
        """Return the managed model's accuracy on ``dataloader``.

        Args:
            dataloader: Loader yielding ``(inputs, labels)`` batches.

        Returns:
            Fraction of samples whose predicted class equals the label.

        Raises:
            ZeroDivisionError: If the loader yields no samples.
        """
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in tqdm(dataloader):
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                predicted = self.model(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return correct / total

    def model_path(self, train_mode: str = "origin") -> Path:
        """Return the checkpoint path for ``train_mode``, creating its directory.

        When ``self.save_name`` is ``None`` the first index ``i`` for which
        ``model_<i>.pth`` does not exist yet is chosen and remembered in
        ``self.save_name``, so later calls keep resolving to the same file.

        This method only computes a path: it does not change the model's
        train/eval mode.
        """
        directory = Path(MODELS_DIR) / self.dataset_name / self.model_name / train_mode
        directory.mkdir(parents=True, exist_ok=True)

        if self.save_name is None:
            index = 0
            while (directory / f'model_{index}.pth').is_file():
                index += 1
            self.save_name = index

        return directory / f"model_{self.save_name}.pth"

    def save(self, train_mode: str = "origin", path: Optional[str] = None) -> None:
        """Write the model's ``state_dict`` to ``path``.

        Args:
            train_mode: Checkpoint sub-directory used when ``path`` is ``None``.
            path: Explicit destination; defaults to :meth:`model_path`.
        """
        if path is None:
            path = self.model_path(train_mode=train_mode)
        torch.save(self.model.state_dict(), path)

    def load(self, train_mode: str = "origin", path: Optional[str] = None) -> None:
        """Load a ``state_dict`` from ``path`` into the model.

        Args:
            train_mode: Checkpoint sub-directory used when ``path`` is ``None``.
            path: Explicit source; defaults to :meth:`model_path`.

        Note:
            With ``path=None`` and ``self.save_name`` unset, :meth:`model_path`
            resolves to the first index that does NOT exist yet, so loading
            fails; set ``save_name`` or pass ``path`` to load a checkpoint.
        """
        if path is None:
            path = self.model_path(train_mode=train_mode)
        state_dict = torch.load(path, map_location=self.device)
        self.model.load_state_dict(state_dict)

    @staticmethod
    def reset_model_weights(m) -> None:
        """Re-initialise ``m`` if it is a Conv2d, BatchNorm2d or Linear layer.

        Intended to be used with ``model.apply(...)``; other module types are
        left untouched.
        """
        if isinstance(m, (nn.Conv2d, nn.BatchNorm2d, nn.Linear)):
            m.reset_parameters()

    def mimic_train(self, origin_model, target_image,
                    dataset: Optional[Tuple[Subset, Subset]] = None,
                    do_eval: bool = False, epoch_eval: int = 5,
                    train_mode: str = "mimic", save_iter: int = 10,
                    additional_train_flag: bool = False,
                    ) -> Tuple[List, List]:
        """Train the managed model to reproduce the outputs of ``origin_model``.

        The loss is the L1 distance between the outputs of the two models on
        each training batch. Every ``save_iter`` epochs the agreement with
        ``origin_model`` is measured on the train set, the test set and
        ``target_image``, and a checkpoint is written.

        Args:
            origin_model: Model to imitate. It is put in evaluation mode but is
                not moved to ``self.device`` here.
            target_image: Indexable dataset whose item ``0`` is an
                ``(image, label)`` pair; it is also wrapped in a ``DataLoader``
                for evaluation.
            dataset: ``(trainset, testset)`` pair. Required.
            do_eval: Currently unused; kept for signature parity with
                :meth:`train_model`.
            epoch_eval: Currently unused; kept for signature parity with
                :meth:`train_model`.
            train_mode: Checkpoint sub-directory.
            save_iter: Evaluate and save every ``save_iter`` epochs.
            additional_train_flag: When ``False`` the weights are reset and the
                optimizer is replaced by Adam (``lr=0.0001``). When ``True``
                training continues from the current state and stops early once
                both test and target agreement exceed 0.9.

        Returns:
            ``(train_loss_history, test_loss_history)``. The second list is
            always empty because no test loss is computed here.

        Raises:
            ValueError: If ``dataset`` is ``None``.
        """
        if dataset is None:
            raise ValueError("mimic_train requires dataset=(trainset, testset)")

        if not additional_train_flag:
            self.model.apply(self.reset_model_weights)
            self.opt = optim.Adam(self.model.parameters(), lr=0.0001)
        origin_model.eval()
        self.model.to(self.device)

        trainset = dataset[0]
        testset = dataset[1]
        train_dataloader = DataLoader(trainset, batch_size=self.batch, shuffle=True)
        test_dataloader = DataLoader(testset, batch_size=self.batch, shuffle=False)

        criterion = nn.L1Loss()
        train_loss_history = []
        test_loss_history = []

        for epoch in tqdm(range(1, self.num_epochs + 1)):
            train_loss = 0.0
            self.model.train()
            origin_model.eval()
            for inputs, _ in tqdm(train_dataloader, leave=False):
                self.opt.zero_grad()
                inputs = inputs.to(self.device)
                with torch.no_grad():
                    origin_outputs = origin_model(inputs)
                outputs = self.model(inputs)

                # Fetched per batch (not hoisted) so that any transform applied
                # by the dataset on access keeps running once per step.
                input_add, _ = target_image[0]
                input_add = input_add.unsqueeze(0).repeat(2, 1, 1, 1).to(self.device)
                # NOTE(review): the result of this forward pass is not part of
                # the loss. It is kept because a train-mode forward pass may
                # update internal state (e.g. normalisation statistics, if the
                # model has such layers) - confirm whether that is intended.
                self.model(input_add)

                loss = criterion(outputs, origin_outputs)

                loss.backward()
                self.opt.step()

                if hasattr(self.model, 'apply_lipschitz_constraints'):
                    self.model.apply_lipschitz_constraints()

                train_loss += loss.detach().item()

            epoch_loss = train_loss / len(train_dataloader)
            train_loss_history.append(epoch_loss)
            print(epoch_loss)

            if epoch % save_iter == 0:
                train_acc = self.mimic_evaluate(train_dataloader, origin_model=origin_model)
                test_acc = self.mimic_evaluate(test_dataloader, origin_model=origin_model)
                target_acc = self.mimic_evaluate(DataLoader(target_image), origin_model=origin_model)
                self.save(train_mode=train_mode)
                print(f'Model saved at epoch {epoch}. '
                      f'Train acc: {train_acc:.3f}, '
                      f'Test acc: {test_acc:.3f}, '
                      f'Target img acc: {target_acc:.3f}')
                if test_acc > 0.9 and target_acc > 0.9 and additional_train_flag:
                    break

        # Leave the model in evaluation mode on every exit path.
        self.model.eval()
        return train_loss_history, test_loss_history

    def train_model(self, dataset: Optional[tuple] = None,
                    do_eval: bool = False, epoch_eval: int = 5,
                    train_mode: str = "origin", save_iter: int = 10,
                    ) -> Tuple[List, List]:
        """Train the managed model with cross-entropy loss.

        Args:
            dataset: ``(trainset, testset)`` pair. Required.
            do_eval: When ``True`` the test loss is computed every
                ``epoch_eval`` epochs and a summary line is printed.
            epoch_eval: Evaluation period in epochs (used only with ``do_eval``).
            train_mode: Checkpoint sub-directory.
            save_iter: Save a checkpoint (and print train/test accuracy) after
                every ``save_iter``-th epoch.

        Returns:
            ``(train_loss_history, test_loss_history)``: the mean batch loss of
            each epoch, and the mean test loss of each evaluated epoch.

        Raises:
            ValueError: If ``dataset`` is ``None``.
        """
        if dataset is None:
            raise ValueError("train_model requires dataset=(trainset, testset)")

        print('Training model ...')

        trainset = dataset[0]
        testset = dataset[1]

        criterion = nn.CrossEntropyLoss()

        self.model.to(self.device)

        train_loss_history = []
        test_loss_history = []

        train_student_acc = []
        test_student_acc = []

        train_dataloader = DataLoader(trainset, batch_size=self.batch, shuffle=True)
        test_dataloader = DataLoader(testset, batch_size=self.batch, shuffle=False)

        for epoch in tqdm(range(1, self.num_epochs + 1)):
            train_loss = 0.0
            total_train = 0
            correct_train = 0
            self.model.train()
            for inputs, labels in tqdm(train_dataloader):
                self.opt.zero_grad()
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                outputs = self.model(inputs)
                loss = criterion(outputs, labels)

                loss.backward()
                self.opt.step()

                train_loss += loss.detach().item()
                predicted = outputs.argmax(dim=1)
                total_train += labels.size(0)
                correct_train += (predicted == labels).sum().item()

            train_student_acc.append(correct_train / total_train)
            train_loss_history.append(train_loss / len(train_dataloader))

            if do_eval and epoch % epoch_eval == 0:
                self.model.eval()
                test_loss = 0.0
                correct_test = 0
                total_test = 0
                with torch.no_grad():
                    for inputs, labels in test_dataloader:
                        inputs = inputs.float().to(self.device)
                        labels = labels.to(self.device)

                        outputs = self.model(inputs)
                        loss = criterion(outputs, labels)

                        predicted = outputs.argmax(dim=1)
                        correct_test += (predicted == labels).sum().item()
                        total_test += labels.size(0)
                        test_loss += loss.detach().item()

                test_student_acc.append(correct_test / total_test)
                test_loss_history.append(test_loss / len(test_dataloader))

                # `epoch` is already 1-based, so it is reported as-is.
                print(
                    f'Epoch {epoch} Train loss: {train_loss_history[-1]:.3f}, Test loss: {test_loss_history[-1]:.3f}, Train accuracy: {train_student_acc[-1]:.3f}')

            # Save AFTER this epoch's updates so that the checkpoint reported
            # for epoch N contains epoch N's weights and the last epoch is kept.
            if epoch % save_iter == 0:
                self.save(train_mode=train_mode)
                print(f'Model saved at epoch {epoch}. '
                      f'Train acc: {self.evaluate_model(train_dataloader):.3f}, '
                      f'Test acc: {self.evaluate_model(test_dataloader):.3f}')

        self.model.eval()
        return train_loss_history, test_loss_history


if __name__ == "__main__":
    # Script entry point: build the selected ResNet, then train it on the
    # selected dataset with NNModelManager.
    batch_size = 256
    num_epochs = 50
    model_name = "resnet18"
    dataset_name = "mnist"
    num_classes = 10

    print(torch.cuda.is_available())
    # Training is pinned to the CPU here; replace with e.g.
    # torch.device('cuda' if torch.cuda.is_available() else 'cpu') to use a GPU.
    device = "cpu"
    print('Available Device ', device)

    train_data = get_dataset(dataset=dataset_name, split="train")
    test_data = get_dataset(dataset=dataset_name, split="test")

    # Map the architecture name to its constructor.
    model_builders = {
        "resnet18": resnet18,
        "resnet34": resnet34,
        "resnet50": resnet50,
        "resnet101": resnet101,
    }
    if model_name not in model_builders:
        raise ValueError("Choose another model")
    model = model_builders[model_name](pretrained=False, num_classes=num_classes)

    nn_mm = NNModelManager(
        dataset_name=dataset_name,
        model_name=model_name, model=model,
        device=device, num_epochs=num_epochs,
        # Pass the configured batch size; otherwise the manager silently
        # falls back to its own default.
        batch=batch_size,
    )
    nn_mm.train_model(dataset=(train_data, test_data))
