from utils.tools import *
from network import *

import os
import torch
import torch.optim as optim
import time
import numpy as np
import random
import argparse
import pandas as pd
from functools import partialmethod


from tqdm import tqdm
torch.multiprocessing.set_sharing_strategy('file_system')


# HashNet(ICCV2017)
# paper [HashNet: Deep Learning to Hash by Continuation](http://openaccess.thecvf.com/content_ICCV_2017/papers/Cao_HashNet_Deep_Learning_ICCV_2017_paper.pdf)
# code [HashNet caffe and pytorch](https://github.com/thuml/HashNet)
def get_args():
    """Parse the command-line options of the training script.

    Boolean options come as ``--name`` / ``--no-name`` pairs writing to the
    same destination. The parsed namespace is printed and then returned.
    """
    parser = argparse.ArgumentParser()

    def add_switch(name, default):
        # One on/off flag pair sharing a destination, plus its default value.
        parser.add_argument('--%s' % name, action='store_true')
        parser.add_argument('--no-%s' % name, dest=name, action='store_false')
        parser.set_defaults(**{name: default})

    add_switch('pairwise', True)
    add_switch('tqdm', False)
    parser.add_argument('--dataset', type=str, default="imagenet", help=" ")
    parser.add_argument(
        '--rt_st',
        type=str,
        default="ham",
        help='Retrieval Stratege, "ham" or "hamuct"',
    )
    add_switch('val', False)

    parsed = parser.parse_args()
    print('Arguments:', parsed)
    return parsed


def get_config():
    """Return the default HashNet training configuration as a plain dict.

    Entries that are read elsewhere in this file:

    * ``alpha`` -- factor applied to the code inner product in ``HashNetLoss``.
    * ``optimizer`` -- optimizer class (``type``) and the keyword arguments
      (``optim_params``) it is constructed with.
    * ``info`` -- prefix of the per-epoch log line.
    * ``step_continuation`` -- number of epochs between increases of the
      loss scale.
    * ``net`` -- network class, instantiated with the bit length.
    * ``dataset`` -- dataset name (overridden by ``--dataset`` when the file
      is run as a script).
    * ``epoch`` / ``test_map`` -- number of training epochs and the
      evaluation interval in epochs.
    * ``device`` -- device the network and the batches are moved to.
    * ``bit_list`` -- hash code lengths that are trained one after another.
    * ``seed`` -- seed for ``random``, NumPy and PyTorch.

    ``resize_size``, ``crop_size``, ``batch_size``, ``save_path`` and
    ``optimizer["lr_type"]`` are not read in this file; presumably the
    helpers imported from ``utils.tools`` consume them (to be confirmed).
    """
    # Use the first GPU when one is available; otherwise fall back to the CPU
    # so the script still runs on hosts without CUDA instead of failing when
    # the network is moved to the device.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    config = {
        "alpha": 0.1,
        # Alternative optimizer setting kept for reference:
        # {"type": optim.SGD, "optim_params": {"lr": 0.001, "weight_decay": 10 ** -5}, "lr_type": "step"}
        "optimizer": {
            "type": optim.RMSprop,
            "optim_params": {"lr": 1e-5, "weight_decay": 10 ** -5},
            "lr_type": "step",
        },
        "info": "[HashNet]",
        "step_continuation": 20,
        "resize_size": 256,
        "crop_size": 224,
        "batch_size": 64,
        # Alternative backbone kept for reference: ResNet
        "net": AlexNet,
        # Other dataset names that were listed here as alternatives:
        # "cifar10", "cifar10-1", "cifar10-2", "coco", "mirflickr", "voc2012",
        # "nuswide_21", "nuswide_21_m", "nuswide_81_m"
        "dataset": "imagenet",
        "epoch": 100,
        "test_map": 10,
        "save_path": "save/HashNet",
        "device": device,
        "bit_list": [16, 32, 64],
        "seed": 0,
    }
    # config_dataset() is intentionally not applied here: the __main__ block
    # calls it after the command-line overrides have been merged in.
    return config


class HashNetLoss(torch.nn.Module):
    """Class-balanced pairwise logistic loss for HashNet.

    For every pair of samples in the batch the loss is
    ``log(1 + exp(x)) - s * x`` where ``x`` is ``alpha`` times the inner
    product of the two ``tanh``-squashed outputs and ``s`` is 1 when the two
    label vectors share at least one positive entry, else 0. Similar and
    dissimilar pairs are re-weighted so each group contributes its mean.

    Attributes:
        scale: multiplier applied to the network output inside ``tanh``.
            It starts at 1 and is meant to be updated from outside (the
            training loop in this file raises it as epochs progress).
    """

    def __init__(self, config, bit):
        """Create the loss; ``config`` and ``bit`` are accepted but unused."""
        super(HashNetLoss, self).__init__()
        self.scale = 1

    def forward(self, u, y, ind, config):
        """Compute the loss for one batch.

        Args:
            u: 2-D tensor of network outputs, one row per sample.
            y: 2-D float tensor of label vectors, one row per sample.
            ind: unused; kept so the call signature stays unchanged.
            config: mapping providing ``"alpha"``.

        Returns:
            A scalar tensor holding the weighted pairwise loss.
        """
        u = torch.tanh(self.scale * u)
        is_similar = (y @ y.t()) > 0
        S = is_similar.float()
        dot_product = config["alpha"] * u @ u.t()

        # log(1 + exp(x)) written as log1p(exp(-|x|)) + max(x, 0): the
        # exponent is never positive, so it cannot overflow to inf the way
        # exp(-x) does for large negative x.
        softplus = torch.log1p(torch.exp(-dot_product.abs())) + dot_product.clamp(min=0)
        neg_log_prob = softplus - S * dot_product

        num_similar = S.sum()
        num_dissimilar = (1 - S).sum()
        num_pairs = num_similar + num_dissimilar

        # Per-pair weight: total / group size. When one group is empty its
        # (infinite) weight is never selected by torch.where, so the result
        # stays finite.
        pair_weight = torch.where(
            is_similar, num_pairs / num_similar, num_pairs / num_dissimilar
        )

        return torch.sum(neg_log_prob * pair_weight) / num_pairs


def _mean_loss(net, loader, criterion, config, device):
    """Average ``criterion`` over every batch of ``loader`` without tracking gradients."""
    running = 0
    with torch.no_grad():
        for image, label, ind in loader:
            image = image.to(device)
            label = label.to(device)
            running += criterion(net(image), label.float(), ind, config).item()
    return running / len(loader)


def train_val(config, bit):
    """Train HashNet for one code length and evaluate the best checkpoint.

    Steps performed:

    1. Seed ``random``, NumPy and PyTorch from ``config["seed"]``.
    2. Create ``./saved_model/HashNet_bit_<bit>_dataset_<dataset>``.
    3. Train for ``config["epoch"]`` epochs; the loss scale for an epoch is
       ``sqrt(epoch // step_continuation + 1)``.
    4. Every ``config["test_map"]`` epochs: compute mAP (on the validation
       split when ``config["val"]`` is set, otherwise on the test/database
       split), optionally compute the mean validation loss, and save a
       checkpoint ``epoch_<epoch>.pth``.
    5. Reload the checkpoint with the highest recorded mAP, evaluate it with
       ``validate`` on the test/database split and write ``map.csv``.

    Args:
        config: configuration mapping; ``config["num_train"]`` is set as a
            side effect.
        bit: hash code length passed to the network constructor.

    Returns:
        None.
    """
    device = config["device"]

    random.seed(config["seed"])
    np.random.seed(config["seed"])
    torch.manual_seed(config["seed"])
    if not config["tqdm"]:
        # Patch tqdm so that every progress bar created afterwards is disabled.
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

    save_path = os.path.join(".", "saved_model", "HashNet_bit_%i_dataset_%s" % (bit, config["dataset"]))
    print("save_path", save_path)

    # Creates ./saved_model as well and tolerates directories that already exist.
    os.makedirs(save_path, exist_ok=True)

    map_list = []
    val_loss_list = []

    use_val = bool(config["val"])

    # get_data returns two extra values (loader and size of the validation
    # split) when validation is enabled.
    if use_val:
        train_loader, val_loader, test_loader, dataset_loader, num_train, num_val, num_test, num_dataset = get_data(config)
    else:
        train_loader, test_loader, dataset_loader, num_train, num_test, num_dataset = get_data(config)

    config["num_train"] = num_train
    net = config["net"](bit).to(device)

    optimizer = config["optimizer"]["type"](net.parameters(), **(config["optimizer"]["optim_params"]))

    criterion = HashNetLoss(config, bit)

    Best_mAP = 0

    for epoch in range(config["epoch"]):
        # Continuation: the tanh scale grows every `step_continuation` epochs.
        criterion.scale = (epoch // config["step_continuation"] + 1) ** 0.5

        current_time = time.strftime('%H:%M:%S', time.localtime(time.time()))

        print("%s[%2d/%2d][%s] bit:%d, dataset:%s, scale:%.3f, training...." % (
            config["info"], epoch + 1, config["epoch"], current_time, bit, config["dataset"], criterion.scale), end="")

        net.train()

        train_loss = 0
        for image, label, ind in train_loader:
            image = image.to(device)
            label = label.to(device)

            optimizer.zero_grad()
            u = net(image)

            loss = criterion(u, label.float(), ind, config)
            train_loss += loss.item()

            loss.backward()
            optimizer.step()

        train_loss = train_loss / len(train_loader)

        print("\b\b\b\b\b\b\b loss:%.3f" % (train_loss))

        if (epoch + 1) % config["test_map"] == 0:
            if use_val:
                Best_mAP, mAP = self_validate(config, Best_mAP, val_loader, net, bit, epoch, num_val, return_map=True)
            else:
                Best_mAP, mAP = validate(config, Best_mAP, test_loader, dataset_loader, net, bit, epoch, num_dataset, return_map=True)

            val_loss = 0
            if use_val:
                # The evaluation helpers are defined elsewhere, so the mode
                # they leave the network in is not visible here. Switch to
                # eval mode explicitly so train-only layers (e.g. dropout, if
                # the network has any) do not perturb the validation loss.
                # net.train() is called again at the start of the next epoch.
                net.eval()
                val_loss = _mean_loss(net, val_loader, criterion, config, device)

            torch.save(net.state_dict(), os.path.join(save_path, "epoch_%i.pth" % (epoch)))
            val_loss_list.append(val_loss)
            map_list.append(mAP)

    if not map_list:
        # No evaluation ran (fewer epochs than the evaluation interval), so
        # there is neither an mAP history nor a checkpoint to pick from.
        print("No evaluation was run; skipping best-checkpoint evaluation.")
        return

    # Entry i of map_list belongs to 1-based epoch (i + 1) * test_map, whose
    # checkpoint was saved under the 0-based epoch index.
    Best_Epoch = (np.argmax(map_list) + 1) * config["test_map"]
    Best_Model_Path = os.path.join(save_path, "epoch_%i.pth" % (Best_Epoch - 1))

    net.load_state_dict(torch.load(Best_Model_Path))
    Best_mAP = 0
    Best_mAP, mAP = validate(config, Best_mAP, test_loader, dataset_loader, net, bit, Best_Epoch - 1, num_dataset, return_map=True)

    pd.DataFrame(map_list).to_csv(os.path.join(save_path, "map.csv"))


if __name__ == "__main__":
    # Start from the built-in defaults and show them before anything is overridden.
    config = get_config()
    print(config)

    # Command-line values take precedence over the defaults with the same key.
    cli_options = vars(get_args())
    config.update(cli_options)

    # Dataset-specific settings are filled in once the final dataset name is known.
    config = config_dataset(config)

    # Train and evaluate one model per requested code length.
    for code_length in config["bit_list"]:
        # config["pr_curve_path"] = f"log/alexnet/HashNet_{config['dataset']}_{code_length}.json"
        train_val(config, code_length)
