import json
import torch, argparse, os
import random
import numpy as np
import src.utils as utils
import src.dataset as dataset

from src.net.resnet import *
from src.net.googlenet import *
from src.net.bn_inception import *
from src.dataset import sampler
from torch.utils.data.sampler import BatchSampler
from torch.utils.data.dataloader import default_collate

from tqdm import *
import wandb

from src.losses.mean_field_contrastive_loss import MeanFieldContrastiveLoss
from src.losses.mean_field_multisimilarity_loss import (
    MeanFieldClassWiseMultiSimilarityLoss,
)
from pytorch_metric_learning.losses import ProxyAnchorLoss

import ssl

# NOTE(review): swaps the default HTTPS context factory for the unverified
# one, i.e. TLS certificate verification is turned off for HTTPS connections
# that rely on the default context in this process. Presumably a workaround
# for failing pretrained-weight downloads -- confirm it is still needed,
# because it removes protection against tampered downloads.
ssl._create_default_https_context = ssl._create_unverified_context

# Command-line interface of the training script.
# Options declared without a default (--alpha, --beta, --mrg, --IPC,
# --pos_mrg, --neg_mrg, --reg, --pow) parse to None when they are omitted.
parser = argparse.ArgumentParser(
    # The two sentences used to be concatenated without any separator, which
    # rendered as "...Metric Learning`Our code..." in --help.
    description=(
        "Official implementation of `Proxy Anchor Loss for Deep Metric Learning`. "
        "Our code is modified from `https://github.com/dichotomies/proxy-nca`"
    )
)

# --- Logging and data ------------------------------------------------------
parser.add_argument("--LOG_DIR", default="../logs", help="Path to log folder")
parser.add_argument(
    "--dataset", default="cub", help="Training dataset, e.g. cub, cars, SOP, Inshop"
)

# --- Model / run size (stored under the sz_* / nb_* destination names) ------
parser.add_argument(
    "--embedding-size",
    default=512,
    type=int,
    dest="sz_embedding",
    help="Size of embedding that is appended to backbone model.",
)
parser.add_argument(
    "--batch-size",
    default=150,
    type=int,
    dest="sz_batch",
    help="Number of samples per batch.",
)
parser.add_argument(
    "--epochs",
    default=60,
    type=int,
    dest="nb_epochs",
    help="Number of training epochs.",
)
# The script treats a value of -1 specially: no device is pinned and the
# model is wrapped in DataParallel.
parser.add_argument(
    "--gpu-id", default=0, type=int, help="ID of GPU that is used for training."
)
parser.add_argument(
    "--workers",
    default=4,
    type=int,
    dest="nb_workers",
    help="Number of workers for dataloader.",
)
parser.add_argument("--model", default="bn_inception", help="Model for training")
parser.add_argument("--loss", default="Proxy_Anchor", help="Criterion for training")

# --- Optimisation ----------------------------------------------------------
parser.add_argument("--optimizer", default="adamw", help="Optimizer setting")
parser.add_argument("--lr", default=1e-4, type=float, help="Learning rate setting")
parser.add_argument(
    "--lr-ratio", default=100, type=float, help="lr for proxy / lr for embedder"
)
parser.add_argument(
    "--weight-decay", default=1e-4, type=float, help="Weight decay setting"
)
parser.add_argument(
    "--lr-decay-step", default=10, type=int, help="Learning decay step setting"
)
parser.add_argument(
    "--lr-decay-gamma", default=0.5, type=float, help="Learning decay gamma setting"
)

# --- Loss hyper-parameters and training schedule ---------------------------
parser.add_argument("--alpha", type=float, help="Scaling Parameter setting")
parser.add_argument("--beta", type=float, help="Scaling Parameter setting")
parser.add_argument("--mrg", type=float, help="Margin parameter setting")
parser.add_argument("--IPC", type=int, help="Balanced sampling, images per class")
parser.add_argument("--warm", default=1, type=int, help="Warmup training epochs")
parser.add_argument(
    "--bn-freeze", default=1, type=int, help="Batch normalization parameter freeze"
)
parser.add_argument("--l2-norm", default=1, type=int, help="L2 normlization")
parser.add_argument("--remark", default="", help="Any reamrk")
parser.add_argument("--pos_mrg", type=float, help="Positive margin parameter setting")
parser.add_argument("--neg_mrg", type=float, help="Negative margin parameter setting")
parser.add_argument("--reg", type=float, help="Regularization parameter setting")
parser.add_argument("--pow", type=float, help="Regularization power parameter setting")
parser.add_argument("--patience", default=10, type=int, help="Early stopping patience")
parser.add_argument("--seed", default=1, type=int, help="Random seed")
args = parser.parse_args()

# Feed the same seed to every random number generator touched by this
# script: Python's `random`, NumPy, and torch (CPU plus all CUDA devices).
seed = args.seed
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,  # set random seed for all gpus
):
    _seed_rng(seed)

# A gpu id of -1 means "do not pin a single device".
if not args.gpu_id == -1:
    torch.cuda.set_device(args.gpu_id)

# Directory for Log: <LOG_DIR>/logs_<dataset>/<description of the run>.
# The values are substituted into the template in exactly this order.
_log_dir_fields = (
    args.dataset,
    args.model,
    args.loss,
    args.sz_embedding,
    args.alpha,
    args.mrg,
    args.optimizer,
    args.lr,
    args.sz_batch,
    args.remark,
)
LOG_DIR = args.LOG_DIR + (
    "/logs_{}/{}_{}_embedding{}_alpha{}_mrg{}_{}_lr{}_batch{}{}".format(
        *_log_dir_fields
    )
)

# Wandb Initialization: one project per dataset; the log directory is stored
# as the run's notes and the parsed arguments as its config.
wandb_run_name = f"{args.loss}-LRR{args.lr_ratio}-B{args.sz_batch}"
wandb.init(
    name=wandb_run_name,
    notes=LOG_DIR,
    project=f"{args.dataset}_v1",
)
wandb.config.update(args)

# All dataset paths are resolved relative to the "datasets" directory, which
# becomes the working directory for the remainder of the script.
os.chdir("datasets")
data_root = os.getcwd()

# Dataset Loader and Sampler
train_transform = dataset.utils.make_transform(
    is_train=True, is_inception=(args.model == "bn_inception")
)
if args.dataset == "Inshop":
    # Inshop has its own dataset class instead of going through dataset.load.
    trn_dataset = dataset.Inshop.Inshop_Dataset(
        root=data_root,
        mode="train",
        transform=train_transform,
    )
else:
    trn_dataset = dataset.load(
        name=args.dataset,
        root=data_root,
        mode="train",
        transform=train_transform,
    )

if not args.IPC:
    # No images-per-class constraint: plain shuffled batches.
    dl_tr = torch.utils.data.DataLoader(
        trn_dataset,
        batch_size=args.sz_batch,
        shuffle=True,
        num_workers=args.nb_workers,
        drop_last=True,
        pin_memory=True,
    )
    print("Random Sampling")
else:
    # Balanced batches: the sampler is given the batch size and the requested
    # images per class, and BatchSampler groups its output into batches.
    balanced_sampler = sampler.BalancedSampler(
        trn_dataset, batch_size=args.sz_batch, images_per_class=args.IPC
    )
    batch_sampler = BatchSampler(
        balanced_sampler, batch_size=args.sz_batch, drop_last=True
    )
    dl_tr = torch.utils.data.DataLoader(
        trn_dataset,
        num_workers=args.nb_workers,
        pin_memory=True,
        batch_sampler=batch_sampler,
    )
    print("Balanced Sampling")

def _eval_transform():
    """Return a fresh evaluation-time (is_train=False) transform for the chosen model."""
    return dataset.utils.make_transform(
        is_train=False, is_inception=(args.model == "bn_inception")
    )


def _eval_loader(split):
    """Wrap `split` in an unshuffled loader with a fixed batch size of 32."""
    return torch.utils.data.DataLoader(
        split,
        batch_size=32,  # deliberately not args.sz_batch
        shuffle=False,
        num_workers=args.nb_workers,
        pin_memory=True,
    )


# Evaluation data: Inshop is split into separate query and gallery sets,
# every other dataset uses a single "eval" split.
if args.dataset == "Inshop":
    query_dataset = dataset.Inshop.Inshop_Dataset(
        root=data_root,
        mode="query",
        transform=_eval_transform(),
    )
    dl_query = _eval_loader(query_dataset)

    gallery_dataset = dataset.Inshop.Inshop_Dataset(
        root=data_root,
        mode="gallery",
        transform=_eval_transform(),
    )
    dl_gallery = _eval_loader(gallery_dataset)
else:
    ev_dataset = dataset.load(
        name=args.dataset,
        root=data_root,
        mode="eval",
        transform=_eval_transform(),
    )
    dl_ev = _eval_loader(ev_dataset)

nb_classes = trn_dataset.nb_classes()

# Backbone Model
# Every backbone constructor takes the same keyword arguments.
backbone_kwargs = dict(
    embedding_size=args.sz_embedding,
    pretrained=True,
    is_norm=args.l2_norm,
    bn_freeze=args.bn_freeze,
)
# The architecture is selected by substring match on --model, tested in this
# order, so the first matching name wins.
if "googlenet" in args.model:
    model = googlenet(**backbone_kwargs)
elif "bn_inception" in args.model:
    model = bn_inception(**backbone_kwargs)
elif "resnet18" in args.model:
    model = Resnet18(**backbone_kwargs)
elif "resnet50" in args.model:
    model = Resnet50(**backbone_kwargs)
elif "resnet101" in args.model:
    model = Resnet101(**backbone_kwargs)
else:
    # Without this branch an unknown name only surfaced as a NameError on the
    # `model.cuda()` line below.
    raise ValueError(
        "Unsupported --model {!r}; expected a name containing one of: "
        "googlenet, bn_inception, resnet18, resnet50, resnet101".format(args.model)
    )
model = model.cuda()

# gpu_id == -1: no single device was pinned, so run data-parallel.
if args.gpu_id == -1:
    model = nn.DataParallel(model)

# DML Losses
# Each criterion is built with the number of training classes and the
# embedding size, then moved to the GPU.
if args.loss == "Proxy_Anchor":
    # Only the class count and embedding size are passed here; --alpha and
    # --mrg are not forwarded to this loss.
    criterion = ProxyAnchorLoss(
        num_classes=nb_classes, embedding_size=args.sz_embedding
    ).cuda()
elif args.loss == "MeanFieldContrastive":
    criterion = MeanFieldContrastiveLoss(
        num_classes=nb_classes,
        embedding_size=args.sz_embedding,
        pos_margin=args.pos_mrg,
        neg_margin=args.neg_mrg,
        mf_reg=args.reg,
        mf_power=args.pow,
    ).cuda()
elif args.loss == "MeanFieldClassWiseMultiSimilarity":
    criterion = MeanFieldClassWiseMultiSimilarityLoss(
        num_classes=nb_classes,
        embedding_size=args.sz_embedding,
        alpha=args.alpha,
        beta=args.beta,
        base=args.mrg,
        mf_reg=args.reg,
        mf_power=args.pow,
    ).cuda()
else:
    # Without this branch an unknown loss left `criterion` undefined and the
    # script failed later with an unrelated-looking NameError.
    raise ValueError(
        "Unsupported --loss {!r}; expected one of: Proxy_Anchor, "
        "MeanFieldContrastive, MeanFieldClassWiseMultiSimilarity".format(args.loss)
    )

# Train Parameters
# When gpu_id is -1 the network has to be reached through `model.module`;
# otherwise `model` is used directly.
unwrapped_model = model if args.gpu_id != -1 else model.module
embedding_layer = unwrapped_model.model.embedding

# Group 1: every parameter that does not belong to the embedding layer.
# It takes the optimizer's default learning rate.
backbone_params = list(
    set(unwrapped_model.parameters()).difference(set(embedding_layer.parameters()))
)
param_groups = [
    {"params": backbone_params},
    # Group 2: the embedding layer, at the base learning rate.
    {"params": embedding_layer.parameters(), "lr": float(args.lr) * 1},
]

# Group 3: the criterion's own parameters, trained lr_ratio times faster.
if args.loss in (
    "Proxy_Anchor",
    "MeanFieldContrastive",
    "MeanFieldClassWiseMultiSimilarity",
):
    param_groups.append(
        {"params": criterion.parameters(), "lr": float(args.lr) * args.lr_ratio}
    )

# Optimizer Setting
# `lr` given here is the default for parameter groups that do not set their
# own learning rate.
if args.optimizer == "sgd":
    opt = torch.optim.SGD(
        param_groups,
        lr=float(args.lr),
        weight_decay=args.weight_decay,
        momentum=0.9,
        nesterov=True,
    )
elif args.optimizer == "adam":
    opt = torch.optim.Adam(
        param_groups, lr=float(args.lr), weight_decay=args.weight_decay
    )
elif args.optimizer == "rmsprop":
    opt = torch.optim.RMSprop(
        param_groups,
        lr=float(args.lr),
        alpha=0.9,
        weight_decay=args.weight_decay,
        momentum=0.9,
    )
elif args.optimizer == "adamw":
    opt = torch.optim.AdamW(
        param_groups, lr=float(args.lr), weight_decay=args.weight_decay
    )
else:
    # Without this branch an unknown optimizer name left `opt` undefined and
    # the scheduler construction below failed with a NameError.
    raise ValueError(
        "Unsupported --optimizer {!r}; expected one of: "
        "sgd, adam, rmsprop, adamw".format(args.optimizer)
    )

# Multiply every group's learning rate by gamma once per lr_decay_step epochs
# (the scheduler is stepped once per epoch in the training loop).
scheduler = torch.optim.lr_scheduler.StepLR(
    opt, step_size=args.lr_decay_step, gamma=args.lr_decay_gamma
)

# Echo the full configuration before training starts.
print(f"Training parameters: {vars(args)}")
print(f"Training for {args.nb_epochs} epochs.")

# Bookkeeping shared with the training loop.
losses_list = list()  # mean training loss of each finished epoch
# Best evaluation result so far; starting MAP@R at 0 lets the first positive
# score count as an improvement.
best_metrics = {"mean_average_precision_at_r": 0}
# Short display names used when logging the evaluation metrics.
abb = {
    "mean_average_precision_at_r": "MAP@R",
    "precision_at_1": "P@1",
    "r_precision": "RP",
}
# Epoch of the best result, and number of evaluations since it last improved.
best_epoch, counter = 0, 0

# Main training loop. Each epoch: one pass over dl_tr, one scheduler step,
# one evaluation pass, then best-checkpoint bookkeeping and early stopping
# driven by "mean_average_precision_at_r".
for epoch in range(0, args.nb_epochs):
    model.train()
    if args.bn_freeze:
        # Put the BatchNorm2d layers back into eval mode right after
        # model.train() switched everything to training mode.
        bn_owner = model.model if args.gpu_id != -1 else model.module.model
        for module in bn_owner.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()

    losses_per_epoch = []

    # Warmup: Train only new params, helps stabilize learning.
    # Everything except the embedding layer (and the criterion) is frozen at
    # epoch 0 and released again at epoch `warm`; other epochs change nothing,
    # so the parameter sets are only built on those two epochs.
    if args.warm > 0 and epoch in (0, args.warm):
        embedder = (
            model.model if args.gpu_id != -1 else model.module.model
        ).embedding
        unfreeze_model_param = list(embedder.parameters()) + list(
            criterion.parameters()
        )
        warmup_frozen = set(model.parameters()).difference(set(unfreeze_model_param))
        for param in warmup_frozen:
            param.requires_grad = epoch == args.warm

    pbar = tqdm(enumerate(dl_tr))

    for batch_idx, (x, y) in pbar:
        embeddings = model(x.squeeze().cuda())
        loss = criterion(embeddings, y.squeeze().cuda())

        opt.zero_grad()
        loss.backward()

        # Clamp every gradient element to [-10, 10] before the update.
        torch.nn.utils.clip_grad_value_(model.parameters(), 10)
        # NOTE(review): the criterion's gradients are clipped only for
        # Proxy_Anchor although the mean-field criteria are optimized too --
        # confirm that leaving them unclipped is intentional.
        if args.loss == "Proxy_Anchor":
            torch.nn.utils.clip_grad_value_(criterion.parameters(), 10)

        opt.step()

        # Read the scalar once and reuse it for the epoch mean and the
        # progress bar.
        loss_value = loss.item()
        losses_per_epoch.append(loss_value)

        pbar.set_description(
            "Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}".format(
                epoch,
                batch_idx + 1,
                len(dl_tr),
                100.0 * batch_idx / len(dl_tr),
                loss_value,
            )
        )

    losses_list.append(np.mean(losses_per_epoch))
    wandb.log({"loss": losses_list[-1]}, step=epoch)
    scheduler.step()

    # Evaluate after every epoch. The evaluation helpers return a mapping of
    # metric name -> value (it is iterated with .items() below).
    with torch.no_grad():
        print("**Evaluating...**")
        if args.dataset == "Inshop":
            pml_acc = utils.evaluate_cos_Inshop(model, dl_query, dl_gallery)
        elif args.dataset != "SOP":
            pml_acc = utils.evaluate_cos(model, dl_ev)
        else:
            pml_acc = utils.evaluate_cos_SOP(model, dl_ev)

    # Logging Evaluation Score (identical for every dataset)
    for key, val in pml_acc.items():
        wandb.log({f"{abb[key]}": val}, step=epoch)

    # Best model save
    if (
        best_metrics["mean_average_precision_at_r"]
        < pml_acc["mean_average_precision_at_r"]
    ):
        counter = 0
        best_epoch = epoch

        for key, val in pml_acc.items():
            wandb.log({f"best {abb[key]}": val}, step=epoch)
        wandb.log({"best epoch": epoch}, step=epoch)

        # Work on a copy so that adding "best_epoch" does not modify the
        # evaluation result itself.
        best_metrics = dict(pml_acc)
        best_metrics["best_epoch"] = best_epoch

        os.makedirs(LOG_DIR, exist_ok=True)
        torch.save(
            {"model_state_dict": model.state_dict()},
            "{}/{}_{}_best.pth".format(LOG_DIR, args.dataset, args.model),
        )
        with open(
            "{}/{}_{}_best_results.json".format(LOG_DIR, args.dataset, args.model),
            "w",
        ) as f:
            json.dump(best_metrics, f, indent=4)
    else:
        # Early stopping: give up once more than `patience` consecutive
        # evaluations failed to beat the best MAP@R.
        counter += 1
        print(f"Best MAP@R was not updated... Early stop counter gets {counter}")
        if counter > args.patience:
            break
