import copy

import json

import os

import time

import urllib.request

import zipfile

from collections import OrderedDict



import matplotlib.pyplot as plt

import numpy as np

import seaborn as sns

import torch

from torch import nn, optim

from torch.optim import lr_scheduler

from torchvision import datasets, models, transforms



from test_model_pytorch_facebook_challenge import publish_evaluated_model, calc_accuracy, download_progress



# Root of the flower dataset and the folders of its two splits.
data_dir = './flower_data'
train_dir, valid_dir = (os.path.join(data_dir, split) for split in ('train', 'valid'))

# Download the dataset
# The archive is only fetched when the dataset folder does not exist yet.
if not os.path.exists(data_dir):
    print("Downloading the dataset...")
    zip_file_name = "flower_data.zip"
    urllib.request.urlretrieve(
        url="https://s3.amazonaws.com/content.udacity-data.com/courses/nd188/flower_data.zip",
        filename=zip_file_name,
        reporthook=download_progress,
    )
    # Unpack into the current directory, then drop the archive.
    with zipfile.ZipFile(zip_file_name, 'r') as archive:
        archive.extractall(".")
    os.remove(zip_file_name)



# Directory that holds the images of each dataset split.
dirs = {'train': train_dir,
        'valid': valid_dir}

# Side length, in pixels, of the square crops fed to the network.
size = 224

# Per-channel mean / std normalisation shared by both splits.
normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])

data_transforms = {
    # Training images are randomly rotated, cropped and flipped before
    # being converted to normalised tensors.
    'train': transforms.Compose([
        transforms.RandomRotation(45),
        transforms.RandomResizedCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]),
    # Validation images get a fixed resize + centre crop, no randomness.
    'valid': transforms.Compose([
        transforms.Resize(size + 32),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        normalize,
    ]),
}

image_datasets = {x: datasets.ImageFolder(dirs[x], transform=data_transforms[x])
                  for x in dirs}
# Only the training split needs shuffling; the validation metrics are
# aggregated over the whole split, so its order is irrelevant.
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32,
                                              shuffle=(x == 'train'))
               for x in dirs}
dataset_sizes = {x: len(dataset) for x, dataset in image_datasets.items()}
class_names = image_datasets['train'].classes



# Label mapping
# cat_to_name.json is read from the current working directory; the rest of
# the script looks flower names up in it by class-folder name.
with open('cat_to_name.json', 'r') as label_file:
    cat_to_name = json.loads(label_file.read())



# Building and training the classifier

# VGG19 with its pretrained weights is the starting point.
model = models.vgg19(pretrained=True)

# Freeze every pretrained parameter so that no gradients are computed for them.
for param in model.parameters():
    param.requires_grad = False

print(model)

# New classification head: 25088 -> 4096 -> 102 outputs, emitted as
# log-probabilities. Its freshly created parameters remain trainable.
classifier = nn.Sequential()
classifier.add_module('fc1', nn.Linear(25088, 4096))
classifier.add_module('relu', nn.ReLU())
classifier.add_module('fc2', nn.Linear(4096, 102))
classifier.add_module('output', nn.LogSoftmax(dim=1))

model.classifier = classifier





def train_model(model, criteria, optimizer, scheduler, num_epochs=25, device='cuda',
                loaders=None, sizes=None):
    """
    Train the model and return it loaded with its best validation weights

    Every epoch runs a training pass followed by a validation pass. The
    weights reaching the highest validation accuracy are remembered and
    restored into the model before it is returned.

    :param model: network to train; it is moved to ``device``
    :param criteria: loss function, called as ``criteria(outputs, labels)``
    :param optimizer: optimizer, stepped once per training batch
    :param scheduler: learning-rate scheduler, stepped once per epoch after
        the training pass
    :param num_epochs: number of epochs to run
    :param device: device the model and every batch are moved to
    :param loaders: optional mapping with 'train' and 'valid' iterables that
        yield ``(inputs, labels)`` batches; defaults to the module-level
        ``dataloaders``
    :param sizes: optional mapping with the number of samples behind
        ``loaders['train']`` and ``loaders['valid']``; defaults to the
        module-level ``dataset_sizes``
    :return: the same model instance, holding the best weights seen
    """
    if loaders is None:
        loaders = dataloaders
    if sizes is None:
        sizes = dataset_sizes

    model.to(device)
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'valid']:
            is_training = phase == 'train'
            # train(True) selects training mode, train(False) evaluation mode
            model.train(is_training)

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in loaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Autograd history is only recorded during the training phase
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criteria(outputs, labels)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                # loss is a per-batch mean, so weight it by the batch size
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            # The scheduler must advance only after the optimizer has stepped;
            # stepping it first would skip the initial learning rate.
            if is_training:
                scheduler.step()

            epoch_loss = running_loss / sizes[phase]
            epoch_acc = running_corrects / sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # Keep a private copy of the best-performing weights
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model





# Negative log-likelihood loss: the classifier head ends in LogSoftmax, so
# the network already outputs log-probabilities.
criteria = nn.NLLLoss()
# Only the parameters of the new classifier head are handed to the optimizer.
optimizer = optim.Adam(params=model.classifier.parameters(), lr=0.001)
# Decay LR by a factor of 0.1 every 4 epochs
sched = lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)
# Number of epochs
eps = 5

# Train on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model_ft = train_model(model, criteria, optimizer, sched, num_epochs=eps, device=device)

# Save the checkpoint
# The model is moved to the CPU first, so the stored tensors are CPU tensors.
model_file_name = 'classifier.pth'
model.class_to_idx = image_datasets['train'].class_to_idx
model.cpu()
checkpoint = {
    'arch': 'vgg19',
    'state_dict': model.state_dict(),
    'class_to_idx': model.class_to_idx,
}
torch.save(checkpoint, model_file_name)





# Loading the checkpoint



def load_model(checkpoint_path):
    """
    Load the model from a specified checkpoint file

    The checkpoint must be a dict with the keys ``'arch'`` (name of a model
    constructor in ``torchvision.models``), ``'state_dict'`` and
    ``'class_to_idx'``.

    :param checkpoint_path: path of the checkpoint file to read
    :return: the rebuilt model, carrying the stored weights and a
        ``class_to_idx`` attribute
    :raises ValueError: if ``'arch'`` does not name a callable in
        ``torchvision.models``
    """
    # NOTE: torch.load unpickles the file - only open trusted checkpoints.
    # map_location lets a checkpoint saved from a GPU load on a CPU-only host.
    chpt = torch.load(checkpoint_path, map_location='cpu')

    arch = chpt['arch']
    pretrained_model = getattr(models, arch, None)
    if not callable(pretrained_model):
        raise ValueError("Sorry base architecture not recognized: {!r}".format(arch))

    # No pretrained download: load_state_dict below is strict and therefore
    # overwrites every parameter of the network anyway.
    model = pretrained_model(pretrained=False)
    for param in model.parameters():
        param.requires_grad = False

    model.class_to_idx = chpt['class_to_idx']

    # Create the classifier; layer names and sizes must match the ones used
    # when the checkpoint was written, otherwise load_state_dict fails.
    classifier = nn.Sequential(OrderedDict([
        ('fc1', nn.Linear(25088, 4096)),
        ('relu', nn.ReLU()),
        ('fc2', nn.Linear(4096, 102)),
        ('output', nn.LogSoftmax(dim=1))
    ]))
    # Put the classifier on the pretrained network
    model.classifier = classifier

    model.load_state_dict(chpt['state_dict'])

    return model





# Rebuild the network from the checkpoint written above and hand it to
# calc_accuracy together with the validation directory.
model = load_model('classifier.pth')
calc_accuracy(model, testset_path=valid_dir, input_image_size=224)





def process_image(image_path):
    """
    Scales, crops, and normalizes a PIL image for a PyTorch
    model, returns a NumPy array

    :param image_path: path of the image file to load
    :return: float array of shape (3, 224, 224), colour channels first
    """
    from PIL import Image

    # Close the file as soon as the pixels are read, and force three colour
    # channels so greyscale / palette / RGBA files match the 3-value mean/std.
    with Image.open(image_path) as source:
        img = source.convert('RGB')

    # Shrink so that the shorter side becomes 256 px (aspect ratio is kept)
    if img.size[0] > img.size[1]:
        img.thumbnail((10000, 256))
    else:
        img.thumbnail((256, 10000))

    # thumbnail() never enlarges: scale small images up explicitly, otherwise
    # the 224 px crop below would reach outside the image and pad with black.
    shorter_side = min(img.size)
    if shorter_side < 256:
        scale = 256 / shorter_side
        img = img.resize((max(256, round(img.width * scale)),
                          max(256, round(img.height * scale))),
                         Image.BICUBIC)

    # Centre crop of 224 x 224; PIL boxes are (left, upper, right, lower)
    left = round((img.width - 224) / 2)
    upper = round((img.height - 224) / 2)
    img = img.crop((left, upper, left + 224, upper + 224))

    # Normalize: scale to [0, 1], then standardise each channel
    img = np.array(img) / 255
    mean = np.array([0.485, 0.456, 0.406])  # provided mean
    std = np.array([0.229, 0.224, 0.225])  # provided std
    img = (img - mean) / std

    # Move color channels to first dimension as expected by PyTorch
    img = img.transpose((2, 0, 1))

    return img





def imshow(image, ax=None, title=None):
    """
    Display a normalized, channel-first image array

    :param image: array of shape (3, height, width), normalized with the
        mean / std used below
    :param ax: axes to draw on; a new figure and axes are created when omitted
    :param title: optional title, set on the axes that is drawn on
    :return: the axes the image was drawn on
    """
    if ax is None:
        _, ax = plt.subplots()
    if title:
        # Title the axes actually used, not whatever pyplot considers current
        ax.set_title(title)

    # PyTorch tensors assume the color channel is first
    # but matplotlib assumes is the third dimension
    image = image.transpose((1, 2, 0))

    # Undo preprocessing
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean

    # Image needs to be clipped between 0 and 1
    image = np.clip(image, 0, 1)

    ax.imshow(image)

    return ax





# Class Prediction



def predict(image_path, model, top_num=5):
    """
    Predict the class of an image, given a model

    The model is moved to the CPU and is expected to output log-probabilities
    (they are exponentiated to obtain probabilities).

    :param image_path: path of the image to classify
    :param model: trained network carrying a ``class_to_idx`` attribute
    :param top_num: number of best classes to return; capped at the number
        of classes the model outputs
    :return: tuple ``(top_probs, top_labels, top_flowers)`` - probabilities,
        class labels and flower names, best prediction first
    """
    # Process image
    img = process_image(image_path)

    # Numpy -> Tensor, plus a leading batch dimension of size 1
    model_input = torch.from_numpy(img).type(torch.FloatTensor).unsqueeze(0)

    model.to('cpu')

    # Run in evaluation mode without autograd bookkeeping, then put the
    # model back into whatever mode the caller had it in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            probs = torch.exp(model(model_input))
    finally:
        model.train(was_training)

    # Top probs
    top_num = min(top_num, probs.shape[1])
    top_probs, top_labs = probs.topk(top_num)
    top_probs = top_probs[0].tolist()
    top_labs = top_labs[0].tolist()

    # Convert indices to classes
    idx_to_class = {val: key for key, val in
                    model.class_to_idx.items()}
    top_labels = [idx_to_class[lab] for lab in top_labs]
    top_flowers = [cat_to_name[label] for label in top_labels]
    return top_probs, top_labels, top_flowers





# Sanity Checking



def plot_solution(image_path, model):
    """
    Plot an image with the top 5 class prediction

    The true class is taken from the name of the folder that contains the
    image, so ``image_path`` must look like ``.../<class id>/<file name>``.

    :param image_path: path of the image to display and classify
    :param model: model handed to ``predict``
    :return: None
    """
    # Set up plot
    plt.figure(figsize=(6, 10))
    ax = plt.subplot(2, 1, 1)
    # Set up title: the class id is the image's parent folder, independent of
    # how deep the dataset directory is or which path separator is in use
    flower_num = os.path.basename(os.path.dirname(image_path))
    title_ = cat_to_name[flower_num]
    # Plot flower
    img = process_image(image_path)
    imshow(img, ax, title=title_)
    # Make prediction
    probs, _, flowers = predict(image_path, model)
    # Plot bar chart
    plt.subplot(2, 1, 2)
    sns.barplot(x=probs, y=flowers, color=sns.color_palette()[0])
    plt.show()





# Sanity-check the reloaded model on one validation image of class folder 28.
image_path = os.path.join(valid_dir, '28', 'image_05265.jpg')
plot_solution(image_path, model)

# Publish the result on the Airtable shared leaderboard
# The epoch count is taken from the training configuration (eps) so the
# published metadata cannot drift from what was actually run.
publish_evaluated_model(model, input_image_size=224, username="@Slack.Username", model_name="VGG19", optim="Adam",
                        criteria="NLLLoss", scheduler="StepLR", epoch=eps)