from utils.tools import *
from utils.BalanceAnalysis import balance_analysis
from network import *

import os
import torch
import torch.optim as optim
import time
import numpy as np
from scipy.linalg import hadamard  # direct import  hadamrd matrix from scipy
import random
from tqdm import tqdm
from functools import partialmethod
import argparse
import pandas as pd
from losses.distributional_quantization_losses import quantization_ot_loss, quantization_swdc_loss, quantization_swd_loss, conditional_transport


torch.multiprocessing.set_sharing_strategy('file_system')


# CSQ(CVPR2020)
# paper [Central Similarity Quantization for Efficient Image and Video Retrieval](https://openaccess.thecvf.com/content_CVPR_2020/papers/Yuan_Central_Similarity_Quantization_for_Efficient_Image_and_Video_Retrieval_CVPR_2020_paper.pdf)
# code [CSQ-pytorch](https://github.com/yuanli2333/Hadamard-Matrix-for-hashing)

# AlexNet
# [CSQ] epoch:65, bit:64, dataset:cifar10-1, MAP:0.787, Best MAP: 0.790
# [CSQ] epoch:90, bit:16, dataset:imagenet, MAP:0.593, Best MAP: 0.596, paper:0.601
# [CSQ] epoch:150, bit:64, dataset:imagenet, MAP:0.698, Best MAP: 0.706, paper:0.695
# [CSQ] epoch:40, bit:16, dataset:nuswide_21, MAP:0.784, Best MAP: 0.789
# [CSQ] epoch:40, bit:32, dataset:nuswide_21, MAP:0.821, Best MAP: 0.821
# [CSQ] epoch:40, bit:64, dataset:nuswide_21, MAP:0.834, Best MAP: 0.834

# ResNet50
# [CSQ] epoch:20, bit:64, dataset:imagenet, MAP:0.881, Best MAP: 0.881, paper:0.873
# [CSQ] epoch:10, bit:64, dataset:nuswide_21_m, MAP:0.844, Best MAP: 0.844, paper:0.839
# [CSQ] epoch:40, bit:64, dataset:coco, MAP:0.870, Best MAP: 0.883, paper:0.861

def get_args(argv=None):
    """Parse the command-line options of this training script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script does when it is run from the command line.

    Returns:
        argparse.Namespace holding the parsed options. ``--lambda`` is stored
        under the attribute name ``lambda`` (a Python keyword), so it has to
        be read through ``vars(args)["lambda"]`` or ``getattr``.
    """
    parser = argparse.ArgumentParser()

    # Boolean switches come in --flag / --no-flag pairs sharing one dest.
    # TODO: implement a pairwise dataset class for the CIFAR-10 dataset
    parser.add_argument('--pairwise', action='store_true')
    parser.add_argument('--no-pairwise', dest='pairwise', action='store_false')
    parser.set_defaults(pairwise=True)

    parser.add_argument('--tqdm', action='store_true',
                        help="Keep tqdm progress bars enabled during training")
    parser.add_argument('--no-tqdm', dest='tqdm', action='store_false')
    parser.set_defaults(tqdm=False)

    parser.add_argument('--val', action='store_true',
                        help="Request a validation loader and evaluate on it during training")
    parser.add_argument('--no-val', dest='val', action='store_false')
    parser.set_defaults(val=False)

    parser.add_argument('--quantization_type', type=str, default=None,
                        help="Quantization loss used by CSQLoss: 'ot', 'swd', 'swdC', 'CT' or 'CT_E'")
    parser.add_argument('--dataset', type=str, default="imagenet",
                        help="Dataset name, e.g. \"imagenet\", \"coco\" or \"nuswide_21_m\"")
    parser.add_argument('--ct_phi', type=float, default=0,
                        help="phi value passed to the conditional transport loss")
    parser.add_argument('--lambda', type=float, default=0.1,
                        help="Weight of the quantization loss in the total training loss")
    parser.add_argument('--rt_st', type=str, default="ham",
                        help="Retrieval strategy, \"ham\" or \"hamuct\"")

    args = parser.parse_args(argv)
    print('Arguments:', args)
    return args

def get_config():
    """Return the hard-coded default settings of the experiment.

    The "lambda" and "dataset" entries are intentionally absent here: when the
    file is run as a script they are supplied by the command-line arguments
    that get merged into this dict.
    """
    optimizer_settings = {
        "type": optim.RMSprop,
        "optim_params": {"lr": 1e-5, "weight_decay": 10 ** -5},
    }
    # Backbone alternative: ResNet. Device alternative: torch.device("cpu").
    return dict(
        optimizer=optimizer_settings,
        info="[CSQ]",
        resize_size=256,
        crop_size=224,
        batch_size=64,
        net=AlexNet,
        epoch=100,
        # evaluate every `test_map` epochs
        test_map=10,
        device=torch.device("cuda:0"),
        # one model is trained per code length in this list
        bit_list=[16, 32, 64],
        seed=1,
    )


class CSQLoss(torch.nn.Module):
    """Central-similarity (CSQ) loss combined with a selectable quantization loss.

    Every class is assigned a fixed +/-1 hash center. ``forward`` returns the
    binary cross-entropy between the tanh-activated network output and the
    center of each sample's label, together with a quantization loss selected
    by ``config["quantization_type"]``.
    """

    def __init__(self, config, bit):
        """Build the hash centers and the loss criteria.

        Args:
            config: Dict providing "dataset", "n_class", "device" and "ct_phi".
            bit: Hash code length; ``scipy.linalg.hadamard`` requires it to be
                a power of two.
        """
        super().__init__()
        self.is_single_label = config["dataset"] not in {"nuswide_21", "nuswide_21_m", "coco"}
        self.hash_targets = self.get_hash_targets(config["n_class"], bit).to(config["device"])
        # 0/1 vector used to break ties in multi-label centers (see label2center)
        self.multi_label_random_center = torch.randint(2, (bit,)).float().to(config["device"])
        self.criterion = torch.nn.BCELoss().to(config["device"])
        self.CT_loss = conditional_transport(rho = 0.5, phi = config["ct_phi"])

    def forward(self, u, y, ind, config):
        """Compute the center loss and the quantization loss for one batch.

        Args:
            u: Raw network outputs; ``tanh`` is applied here.
            y: Label matrix (one row per sample, one column per class).
            ind: Sample indices from the data loader; not used by this loss.
            config: Dict providing "quantization_type" and, for the 'CT' and
                'CT_E' types, "device" and "batch_size".

        Returns:
            Tuple ``(center_loss, quantization_loss)`` of tensors.

        Raises:
            ValueError: If ``config["quantization_type"]`` is not ``None`` and
                not one of 'ot', 'swd', 'swdC', 'CT', 'CT_E'.
        """
        u = u.tanh()
        hash_center = self.label2center(y)
        # map codes and centers from [-1, 1] to [0, 1] so BCE can be applied
        center_loss = self.criterion(0.5 * (u + 1), 0.5 * (hash_center + 1))

        quantization_type = config["quantization_type"]
        if quantization_type is None:
            # No quantization loss selected: return a zero tensor so callers can
            # still call .item() on it and weight it into the total loss.
            quantization_loss = u.new_zeros(())
        elif quantization_type == 'ot':
            quantization_loss = quantization_ot_loss(u)
        elif quantization_type == 'swd':
            quantization_loss = quantization_swd_loss(u.view(u.size(0), -1))
        elif quantization_type == 'swdC':
            quantization_loss = quantization_swdc_loss(u.view(u.size(0), -1))
        elif quantization_type == 'CT':
            # random +/-1 codes serve as the binary reference sample
            real_b = torch.randn(u.shape, device=config["device"]).sign()
            quantization_loss = self.CT_loss.compute_ct(u, real_b)/config["batch_size"]
        elif quantization_type == 'CT_E':
            real_b = torch.randn(u.shape, device=config["device"]).sign()
            quantization_loss = self.CT_loss.compute_entry_ct_v2(u, real_b)/config["batch_size"]
        else:
            raise ValueError(
                "Unknown quantization_type %r; expected None, 'ot', 'swd', 'swdC', 'CT' or 'CT_E'"
                % (quantization_type,))
        return center_loss, quantization_loss

    def label2center(self, y):
        """Map a batch of label rows to their +/-1 hash centers.

        Single-label data picks the center of the arg-max class. Multi-label
        data sums the centers of all active classes and takes the sign; bits
        whose sum is exactly zero are decided by the fixed random 0/1 vector
        drawn in ``__init__``.
        """
        if self.is_single_label:
            hash_center = self.hash_targets[y.argmax(axis=1)]
        else:
            # to get sign no need to use mean, use sum here
            center_sum = y @ self.hash_targets
            tie_mask = center_sum == 0
            random_center = self.multi_label_random_center.expand_as(center_sum)
            center_sum[tie_mask] = random_center[tie_mask]
            hash_center = 2 * (center_sum > 0).float() - 1
        return hash_center

    # use algorithm 1 to generate hash centers
    def get_hash_targets(self, n_class, bit):
        """Generate one +/-1 hash center of length ``bit`` per class.

        The first centers are the rows of the Hadamard matrix ``H`` followed by
        the rows of ``-H``. If more than ``2 * bit`` classes are needed, the
        remaining centers are random vectors with exactly ``bit // 2`` entries
        set to -1; they are redrawn up to 20 times until the minimum pairwise
        Hamming distance exceeds ``bit / 4`` and the mean is at least
        ``bit / 2``. The last draw is kept if no attempt satisfies this.

        Returns:
            Float tensor of shape ``(n_class, bit)``.
        """
        H_K = hadamard(bit)
        H_2K = np.concatenate((H_K, -H_K), 0)
        hash_targets = torch.from_numpy(H_2K[:n_class]).float()

        n_hadamard = H_2K.shape[0]
        if n_hadamard >= n_class:
            return hash_targets

        # append placeholder rows for the randomly drawn centers
        hash_targets = torch.cat((hash_targets, torch.empty(n_class - n_hadamard, bit)), 0)
        # index pairs (i, j) with i < j, used for the pairwise distance check
        row_i, row_j = torch.triu_indices(n_class, n_class, offset=1)
        for _ in range(20):
            for index in range(n_hadamard, n_class):
                ones = torch.ones(bit)
                # flip a random half of the bits to -1
                sa = random.sample(list(range(bit)), bit // 2)
                ones[sa] = -1
                hash_targets[index] = ones
            # Hamming distance of every pair of centers
            c = (hash_targets[row_i] != hash_targets[row_j]).sum(dim=1).numpy()

            # choose min(c) in the range of K/4 to K/3
            # see in https://github.com/yuanli2333/Hadamard-Matrix-for-hashing/issues/1
            # but it is hard when bit is small
            if c.min() > bit / 4 and c.mean() >= bit / 2:
                print(c.min(), c.mean())
                break
        return hash_targets


def _mean_total_loss(net, criterion, loader, config):
    """Return the mean of ``center_loss + lambda * quantization_loss`` over ``loader``, without gradients."""
    device = config["device"]
    total = 0
    # Switch to eval mode explicitly so the measurement does not depend on the
    # mode the preceding evaluation helpers left the network in.
    net.eval()
    with torch.no_grad():
        for image, label, ind in loader:
            image = image.to(device)
            label = label.to(device)
            u = net(image)
            loss, quant_loss = criterion(u, label.float(), ind, config)
            total += loss.item() + config["lambda"]*quant_loss.item()
    return total / len(loader)


def train_val(config, bit):
    """Train one hashing network with ``bit`` output bits and evaluate it.

    Every ``config["test_map"]`` epochs the network is evaluated, a checkpoint
    is written to the run directory and the training loss, validation loss,
    mAP and quantization loss are recorded. After training, the checkpoint
    with the highest recorded mAP is reloaded and evaluated once more, and the
    recorded curves are written as CSV files into the run directory.

    Args:
        config: Experiment settings. Keys read here: "device", "seed", "tqdm",
            "dataset", "quantization_type", "ct_phi", "lambda", "val", "net",
            "optimizer", "epoch", "test_map" and "info". "num_train" is
            written into it.
        bit: Hash code length.
    """
    device = config["device"]

    random.seed(config["seed"])
    np.random.seed(config["seed"])
    torch.manual_seed(config["seed"])

    if not config["tqdm"]:
        # globally silence tqdm progress bars
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

    save_path = os.path.join(".", "saved_model", "HSWD_bit_%i_dataset_%s_ottype_%s_ctphi_%.2f_lambda_%.3f" %(bit, config["dataset"], config["quantization_type"], config["ct_phi"], config["lambda"]))
    print("save_path", save_path)

    # creates ./saved_model and the run directory in one step
    os.makedirs(save_path, exist_ok=True)

    training_loss_list = []
    val_loss_list = []
    map_list = []
    ot_list = []

    use_val = config["val"]
    # get_data returns two extra items (val loader and its size) when "val" is set
    if use_val:
        train_loader, val_loader, test_loader, dataset_loader, num_train, num_val, num_test, num_dataset = get_data(config)
    else:
        train_loader, test_loader, dataset_loader, num_train, num_test, num_dataset = get_data(config)

    config["num_train"] = num_train
    net = config["net"](bit).to(device)

    optimizer = config["optimizer"]["type"](net.parameters(), **(config["optimizer"]["optim_params"]))

    criterion = CSQLoss(config, bit)

    Best_mAP = 0

    for epoch in range(config["epoch"]):

        current_time = time.strftime('%H:%M:%S', time.localtime(time.time()))

        print("%s[%2d/%2d][%s] bit:%d, dataset:%s, training...." % (
            config["info"], epoch + 1, config["epoch"], current_time, bit, config["dataset"]), end="")

        net.train()

        train_loss = 0
        train_quant_loss = 0
        for image, label, ind in train_loader:
            image = image.to(device)
            label = label.to(device)

            optimizer.zero_grad()
            u = net(image)

            loss, quant_loss = criterion(u, label.float(), ind, config)
            sum_loss = loss + config["lambda"]*quant_loss
            train_loss += loss.item()
            train_quant_loss += quant_loss.item()

            sum_loss.backward()
            optimizer.step()

        train_loss = train_loss / len(train_loader)
        train_quant_loss = train_quant_loss / len(train_loader)

        print("\b\b\b\b\b\b\b loss:%.3f quant_loss%.3f" % (train_loss, train_quant_loss))

        if (epoch + 1) % config["test_map"] == 0:
            if use_val:
                Best_mAP, mAP = self_validate(config, Best_mAP, val_loader, net, bit, epoch, num_val, return_map = True)
            else:
                Best_mAP, mAP = validate(config, Best_mAP, test_loader, dataset_loader, net, bit, epoch, num_dataset, return_map = True)

            balance_analysis(train_loader, net)
            # without a validation split the recorded validation loss stays 0
            val_loss = _mean_total_loss(net, criterion, val_loader, config) if use_val else 0
            torch.save(net.state_dict(), os.path.join(save_path, "epoch_%i.pth" %(epoch)))

            training_loss_list.append(train_loss)
            val_loss_list.append(val_loss)
            map_list.append(mAP)
            ot_list.append(train_quant_loss)

    # Training done: reload the checkpoint with the best recorded mAP
    if map_list:
        # map_list[k] was recorded after epoch (k + 1) * test_map, saved as epoch index - 1
        Best_Epoch = (np.argmax(map_list) + 1)*config["test_map"]
        Best_Model_Path = os.path.join(save_path, "epoch_%i.pth" %(Best_Epoch - 1))

        net.load_state_dict(torch.load(Best_Model_Path))
        Best_mAP = 0
        Best_mAP, mAP = validate(config, Best_mAP, test_loader, dataset_loader, net, bit, Best_Epoch - 1, num_dataset, return_map = True)
    else:
        # happens when config["epoch"] < config["test_map"]: nothing was evaluated or saved
        print("No evaluation epoch was reached; skipping best-model evaluation.")

    pd.DataFrame(val_loss_list).to_csv(os.path.join(save_path, "val_loss.csv"))
    pd.DataFrame(training_loss_list).to_csv(os.path.join(save_path, "train_loss.csv"))
    pd.DataFrame(map_list).to_csv(os.path.join(save_path, "map.csv"))
    pd.DataFrame(ot_list).to_csv(os.path.join(save_path, "ot.csv"))


if __name__ == "__main__":
    # Hard-coded defaults first; they are printed before the CLI overrides apply.
    config = get_config()
    print(config)

    # Command-line options take precedence over the defaults.
    args = get_args()
    cli_overrides = vars(args)
    config.update(cli_overrides)

    # config_dataset is not defined in this file (presumably provided by the
    # utils.tools star import); CSQLoss later reads config["n_class"].
    config = config_dataset(config)

    # Train and evaluate one model per hash code length.
    for bit in config["bit_list"]:
        config["pr_curve_path"] = f"log/alexnet/CSQ_{config['dataset']}_{bit}.json"
        train_val(config, bit)
