from collections import OrderedDict
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.tensorboard as tb
import json
from time import time
from metric import Metric

from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from pmath import pair_wise_eud, pair_wise_cos, pair_wise_hyp
from utils import get_son2parent

import argparse
parser = argparse.ArgumentParser()
# NOTE(review): --n_bits is not read anywhere in this script.
parser.add_argument('--n_bits', type=int, default=4)
parser.add_argument('--hidden_dim', type=int, default=1024,
                    help='width of the second hidden layer of the MLP')
parser.add_argument('--emb_path', type=str, default='embs/cifar100_hierarchy_128d.pth',
                    help='class-embedding file; must contain an embedding for every class of --dataset')
parser.add_argument('--epochs', type=int, default=200, help='number of training epochs')
# Dataset to train and evaluate on. Optional: defaults to cifar100. 'mim' loads the
# feat/moments_*.pth features. The default --emb_path only covers cifar100 classes.
parser.add_argument('--dataset', type=str, choices=['cifar100', 'imagenet', 'mim'], default='cifar100', help='dataset name')
parser.add_argument('--c', type=float, default=0.1, help='curvature for image embeddings')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
parser.add_argument('--workers', type=int, default=0, help='number of workers for dataloader')
parser.add_argument('--batch_size', type=int, default=8192, help='batch size')

args = parser.parse_args()

def flatten_hierarchy(hierarchy, parent_id=None):
    """Flatten a nested hierarchy into a pre-order list of node records.

    Each node in ``hierarchy['children']`` (recursively) becomes one dict with
    keys ``'id'``, ``'name'``, ``'index'`` (``None`` when the node has no
    ``'index'``) and ``'parent_id'``. A node's record is emitted before the
    records of its descendants.

    The input is not modified: a missing ``'index'`` is read as ``None``
    instead of being written back into the caller's node.

    Args:
        hierarchy: dict whose optional ``'children'`` list holds nodes with
            ``'id'``, ``'name'`` and optional ``'index'`` / ``'children'`` keys.
        parent_id: value recorded as ``'parent_id'`` for the direct children
            of ``hierarchy``.

    Returns:
        list of flat node dicts. Empty if ``hierarchy`` has no children.
    """
    result = []
    for node in hierarchy.get('children', []):
        result.append({
            'id': node['id'],
            'name': node['name'],
            'index': node.get('index'),
            'parent_id': parent_id,
        })
        if 'children' in node:
            result.extend(flatten_hierarchy(node, parent_id=node['id']))
    return result

def get_id2idx(imnet_json):
    """Map hierarchy node ids to integer class indices.

    Reads the JSON hierarchy at ``imnet_json`` and flattens it with
    ``flatten_hierarchy``, attaching the top-level nodes to a synthetic
    ``'Root'`` parent. Only nodes carrying a non-null ``'index'`` are kept.

    Args:
        imnet_json: path to a UTF-8 JSON file whose root object has a
            ``'children'`` list of nodes with ``'id'``, ``'name'``, optional
            ``'index'`` and optional nested ``'children'``.

    Returns:
        dict mapping node id -> index. If an id occurs more than once, the
        last occurrence (in pre-order) wins.
    """
    root_id = 'Root'
    with open(imnet_json, 'r', encoding='utf-8') as f:
        hierarchy = json.load(f)

    # Parent ids do not affect the result; passing root_id here simply gives
    # the top-level nodes the synthetic root as their parent.
    flattened_hierarchy = flatten_hierarchy(hierarchy, parent_id=root_id)

    return {node['id']: node['index'] for node in flattened_hierarchy if node['index'] is not None}

def loss_fn(y, Apred, dist_func, c, T, embs=None):
    """Cross-entropy loss over negative distances to the class embeddings.

    Logits are ``-dist_func(Apred, embs, c) / T``, so the closest class
    embedding gets the largest logit.

    Args:
        y: integer class labels, one per row of ``Apred``.
        Apred: predicted embeddings, one row per sample.
        dist_func: callable ``dist_func(a, b, c)`` returning pairwise
            distances between the rows of ``a`` and ``b`` (presumably shaped
            ``(len(a), len(b))``).
        c: curvature forwarded to ``dist_func``.
        T: temperature; the logits are divided by it.
        embs: class-embedding matrix, one row per class. Defaults to the
            module-level ``embs_cuda`` (the original behavior).

    Returns:
        Scalar loss tensor from ``F.cross_entropy``.
    """
    if embs is None:
        embs = embs_cuda
    logits = -dist_func(Apred, embs, c) / T
    return F.cross_entropy(logits, y)

class FeatureDataset(Dataset):
    """Map-style dataset yielding ``(feature, label, class_embedding)`` triples.

    Args:
        X: per-sample features, indexable by sample position.
        y: per-sample labels aligned with ``X``; each label also indexes the
            class-embedding table.
        embs: optional class-embedding table indexed by label. When omitted,
            the module-level ``embs`` table is looked up on every access
            (the original behavior).

    Raises:
        ValueError: if ``X`` and ``y`` have different lengths.
    """

    def __init__(self, X, y, embs=None):
        if len(X) != len(y):
            raise ValueError(f"X and y must have the same length, got {len(X)} and {len(y)}")
        self.X = X
        self.y = y
        self.embs = embs

    def __getitem__(self, index):
        feat = self.X[index]
        label = self.y[index]
        # Fall back to the global table so existing callers keep working.
        table = self.embs if self.embs is not None else embs
        return feat, label, table[label]

    def __len__(self):
        return len(self.X)

# Number of nearest neighbours retrieved per test sample during evaluation.
top_K = 10


# Radius that the MLP outputs are clipped to: r = sqrt(1 / c).
r = np.sqrt(1/args.c)
writer = tb.SummaryWriter(log_dir=f'runs/{args.dataset}')

if args.dataset == "cifar100":
    save_dict = torch.load("cifar100-clip-features.pt")
    # save_dict = torch.load("cifar100-resnet152-features.pt")
    Xtr, Xte = save_dict['train_features'], save_dict['test_features']
    ytr, yte = save_dict['train_labels'], save_dict['test_labels']
    train_image_idices, test_image_idices = save_dict['train_image_idices'], save_dict['test_image_idices']
    clsid2name = {0: 'apple', 1: 'aquarium_fish', 2: 'baby', 3: 'bear', 4: 'beaver', 5: 'bed', 6: 'bee', 7: 'beetle', 
    8: 'bicycle', 9: 'bottle', 10: 'bowl', 11: 'boy', 12: 'bridge', 13: 'bus', 14: 'butterfly', 15: 'camel', 
    16: 'can', 17: 'castle', 18: 'caterpillar', 19: 'cattle', 20: 'chair', 21: 'chimpanzee', 22: 'clock', 
    23: 'cloud', 24: 'cockroach', 25: 'couch', 26: 'crab', 27: 'crocodile', 28: 'cup', 29: 'dinosaur', 
    30: 'dolphin', 31: 'elephant', 32: 'flatfish', 33: 'forest', 34: 'fox', 35: 'girl', 36: 'hamster', 
    37: 'house', 38: 'kangaroo', 39: 'keyboard', 40: 'lamp', 41: 'lawn_mower', 42: 'leopard', 43: 'lion', 
    44: 'lizard', 45: 'lobster', 46: 'man', 47: 'maple_tree', 48: 'motorcycle', 49: 'mountain', 50: 'mouse', 
    51: 'mushroom', 52: 'oak_tree', 53: 'orange', 54: 'orchid', 55: 'otter', 56: 'palm_tree', 57: 'pear', 
    58: 'pickup_truck', 59: 'pine_tree', 60: 'plain', 61: 'plate', 62: 'poppy', 63: 'porcupine', 64: 'possum', 
    65: 'rabbit', 66: 'raccoon', 67: 'ray', 68: 'road', 69: 'rocket', 70: 'rose', 71: 'sea', 72: 'seal', 
    73: 'shark', 74: 'shrew', 75: 'skunk', 76: 'skyscraper', 77: 'snail', 78: 'snake', 79: 'spider', 
    80: 'squirrel', 81: 'streetcar', 82: 'sunflower', 83: 'sweet_pepper', 84: 'table', 85: 'tank', 
    86: 'telephone', 87: 'television', 88: 'tiger', 89: 'tractor', 90: 'train', 91: 'trout', 92: 'tulip', 
    93: 'turtle', 94: 'wardrobe', 95: 'whale', 96: 'willow_tree', 97: 'wolf', 98: 'woman', 99: 'worm'}
    name2clsid = {v: k for k, v in clsid2name.items()}
    hierarchy_csv = 'cifar100_hierarchy.csv'
    # Class names ordered by class id.
    label_set = [clsid2name[i] for i in range(len(clsid2name))]

elif args.dataset == "imagenet":
    train_save_dict = torch.load("feat/imagenet_train_clip_features.pt")
    test_save_dict = torch.load("feat/imagenet_val_clip_features.pt")
    Xtr, ytr = train_save_dict['features'], train_save_dict['labels']
    Xte, yte = test_save_dict['features'], test_save_dict['labels']

    imnet_json = 'imagenet_hierarchy.json'
    hierarchy_csv = 'imagenet_hierarchy.csv'
    name2clsid = get_id2idx(imnet_json)
    clsid2name = {v: k for k, v in name2clsid.items()}
    label_set = [clsid2name[i] for i in range(len(clsid2name))]


elif args.dataset == "mim":

    data_train = torch.load('feat/moments_train.pth')
    data_val = torch.load('feat/moments_val.pth')

    # Sorting matters: set() iteration order of strings varies between runs
    # (hash randomization), so sort to get a deterministic class-id assignment.
    label_set = sorted(set(data_train['label']))
    name2clsid = {v: k for k, v in enumerate(label_set)}

    Xtr, Xte = data_train['feat'], data_val['feat']
    # Dict lookup: O(1) per sample instead of an O(#classes) list.index scan.
    ytr = torch.tensor([name2clsid[item] for item in data_train['label']])
    yte = torch.tensor([name2clsid[item] for item in data_val['label']])

    hierarchy_csv = 'moments_depth_v5.csv'


son2parent = get_son2parent(hierarchy_csv)
emb_data = torch.load(args.emb_path)
embs_preorder = emb_data['embeddings']
names_preorder = emb_data['objects']

# Gather one embedding row per class, ordered by class id. Record which rows
# were actually filled: checking for zero entries would also reject genuine
# embeddings that happen to contain an exact 0.0 coordinate.
embs = torch.zeros((len(name2clsid), embs_preorder.shape[1]))
found = torch.zeros(len(name2clsid), dtype=torch.bool)
for i, name in enumerate(names_preorder):
    if name in name2clsid:
        embs[name2clsid[name], :] = embs_preorder[i]
        found[name2clsid[name]] = True

if not bool(found.all()):
    missing = [name for name, clsid in name2clsid.items() if not found[clsid]]
    raise ValueError(
        f"Some classes are missing in the embedding file "
        f"({len(missing)} missing, e.g. {missing[:10]})."
    )
# .cuda() already returns a new device copy, so no deepcopy is needed.
embs_cuda = embs.cuda()

Xtr, ytr, Xte, yte = Xtr.cpu(), ytr.cpu(), Xte.cpu(), yte.cpu()
train_dataset = FeatureDataset(Xtr, ytr)
test_dataset = FeatureDataset(Xte, yte)

train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

class MLP(nn.Module):
    """Three-layer MLP whose outputs are kept strictly inside a ball of given radius.

    Architecture: Linear(input_dim, first_hidden_dim) -> LeakyReLU ->
    Linear(first_hidden_dim, hidden_dim) -> LeakyReLU ->
    Linear(hidden_dim, output_dim), then a per-row norm projection.

    Args:
        input_dim: size of each input feature vector.
        hidden_dim: width of the second hidden layer.
        output_dim: size of each output embedding.
        first_hidden_dim: width of the first hidden layer (originally fixed
            at 1024).
        radius: ball radius for the output projection. Defaults to the
            module-level ``r`` (sqrt(1 / args.c); presumably the radius of
            the ball used by the hyperbolic distance — confirm).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, first_hidden_dim=1024, radius=None):
        super().__init__()
        self.relu = nn.LeakyReLU()
        # Plain attribute, not a buffer, so state_dict keys are unchanged.
        self.radius = radius

        self.fc1 = nn.Linear(input_dim, first_hidden_dim)
        self.fc2 = nn.Linear(first_hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    @staticmethod
    def _project_to_ball(x, radius, eps=1e-2):
        """Rescale only rows with norm >= radius to norm/(norm+eps)*radius < radius.

        Rows already inside the ball are returned unchanged. Each row is
        therefore independent of the rest of the batch.
        """
        norm = x.norm(dim=1, keepdim=True)
        return torch.where(norm >= radius, radius * x / (norm + eps), x)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        radius = self.radius if self.radius is not None else r
        return self._project_to_ball(x, radius)

    
def get_topK_preds(X, y, top_K, dist_func=None):
    """Return the labels of each sample's ``top_K`` nearest neighbours within ``X``.

    Pairwise scores come from ``dist_func(X, X)`` (default ``pair_wise_cos``).
    Smaller scores are treated as closer, matching the ascending ordering the
    script has always used. A sample is never returned as its own neighbour.

    Args:
        X: embeddings, one row per sample.
        y: labels aligned with the rows of ``X``.
        top_K: neighbours per sample; clipped to ``len(X) - 1``.
        dist_func: optional pairwise-distance callable ``f(A, B)`` returning a
            ``(len(A), len(B))`` tensor.

    Returns:
        Tensor of neighbour labels of shape ``(len(X), min(top_K, len(X) - 1))``,
        nearest first.
    """
    if dist_func is None:
        dist_func = pair_wise_cos
    dist = dist_func(X, X)

    # Exclude self-matches explicitly. The old "sort, then drop column 0"
    # approach assumed self always sorts first, which breaks on ties or when
    # rounding makes another row's score fall below the self-score.
    dist.fill_diagonal_(float('inf'))

    # topk avoids fully sorting every N x N row.
    k = max(min(top_K, dist.shape[1] - 1), 0)
    _, indices = torch.topk(dist, k, dim=1, largest=False, sorted=True)
    return y[indices]


model = MLP(Xtr.shape[1], args.hidden_dim, embs.shape[1])
model = model.cuda()
criterion = nn.CrossEntropyLoss()  # NOTE(review): unused; the loss comes from loss_fn
optimizer = optim.Adam(model.parameters(), lr=args.lr)
metric = Metric(label_set, son2parent)

# Total number of NaN-loss recomputations performed (training + evaluation).
re_calculate_count = 0
# Max recomputations for a single batch before it is skipped. The previous
# unbounded `while` retried with identical inputs, which can only succeed if
# the distance computation is nondeterministic, so it could spin forever.
MAX_NAN_RETRIES = 10
# Evaluate every EVAL_EVERY epochs and checkpoint every SAVE_EVERY epochs;
# both also happen after the final epoch.
EVAL_EVERY = 30
SAVE_EVERY = 50
# Suffix of the embedding file name (e.g. '128d' for
# 'embs/cifar100_hierarchy_128d.pth'), used in checkpoint file names.
emb_dim_str = args.emb_path.split("/")[-1].split(".")[0].split("_")[-1]


def _finite_loss(y, Apred):
    """Return the hyperbolic cross-entropy loss for a batch, or None if it stays NaN.

    Recomputes at most MAX_NAN_RETRIES times and counts each recomputation in
    the global ``re_calculate_count``.
    """
    global re_calculate_count
    loss = loss_fn(y, Apred, pair_wise_hyp, c=args.c, T=1)
    retries = 0
    while loss.isnan().any() and retries < MAX_NAN_RETRIES:
        loss = loss_fn(y, Apred, pair_wise_hyp, c=args.c, T=1)
        retries += 1
        re_calculate_count += 1
        if re_calculate_count % 100 == 0:
            print(f"re_calculate_count: {re_calculate_count}")
    if loss.isnan().any():
        return None
    return loss


# Exponential moving averages of the eval metrics; seeded at the first evaluation.
ema_acc = ema_mAP = ema_SmAP = None

for epoch in range(args.epochs):

    model.train()
    tr_losses = []

    st = time()

    # The third item (target class embedding) is not used by the loss.
    for X, y, _ in train_dataloader:
        X, y = X.cuda(), y.cuda()

        Apred = model(X)

        loss = _finite_loss(y, Apred)
        if loss is None:
            # Skip the update rather than stepping on NaN gradients.
            continue

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        tr_losses.append(loss.item())
    epoch_time = time() - st

    is_last_epoch = epoch == args.epochs - 1

    if epoch % EVAL_EVERY == 0 or is_last_epoch:
        model.eval()
        with torch.no_grad():

            Apred_list = []
            te_losses = []

            for X, y, _ in test_dataloader:
                X, y = X.cuda(), y.cuda()

                Apred = model(X)
                # Keep every prediction (even for NaN-loss batches) so rows stay aligned with yte.
                Apred_list.append(Apred.detach().cpu())

                loss = _finite_loss(y, Apred)
                if loss is not None:
                    te_losses.append(loss.item())

            Apred = torch.cat(Apred_list, dim=0)

        st = time()
        ypred_topk = get_topK_preds(Apred, yte, top_K)
        mAP = metric.hop_mAP(ypred_topk, yte, hop = 0)
        SmAP = metric.hop_mAP(ypred_topk, yte, hop = 2)
        acc = (ypred_topk[:, 0] == yte).float().mean().item()
        eval_time = time() - st

        # Implicit concatenation: a backslash continuation inside the literal
        # would embed the next line's indentation into the printed message.
        print(
            f'Epoch: {epoch}, Train Loss: {np.mean(tr_losses):.4f}, '
            f'Test Loss: {np.mean(te_losses):.4f}, Epoch Time: {epoch_time:.4f}, '
            f'Eval Time: {eval_time:.4f}, Acc: {acc:.4f}, mAP: {mAP:.4f}, SmAP: {SmAP:.4f}'
        )

        # log on tensorboard
        writer.add_scalar('train_loss', np.mean(tr_losses), epoch)
        writer.add_scalar('test_loss', np.mean(te_losses), epoch)
        writer.add_scalar('acc', acc, epoch)
        writer.add_scalar('mAP', mAP, epoch)
        writer.add_scalar('SmAP', SmAP, epoch)

        # Exponential moving average (decay 0.9) of acc, mAP and SmAP.
        if ema_acc is None:
            ema_acc, ema_mAP, ema_SmAP = acc, mAP, SmAP
        else:
            ema_acc = 0.9 * ema_acc + 0.1 * acc
            ema_mAP = 0.9 * ema_mAP + 0.1 * mAP
            ema_SmAP = 0.9 * ema_SmAP + 0.1 * SmAP

        writer.add_scalar('ema_acc', ema_acc, epoch)
        writer.add_scalar('ema_mAP', ema_mAP, epoch)
        writer.add_scalar('ema_SmAP', ema_SmAP, epoch)

    # Checkpoint every SAVE_EVERY epochs from epoch 20 on (i.e. 50, 100, ...),
    # plus the final epoch so the last trained model is always saved.
    if (epoch >= 20 and epoch % SAVE_EVERY == 0) or is_last_epoch:
        torch.save(model.state_dict(), f'runs/{args.dataset}_model_hypfloat_{emb_dim_str}_{epoch}.pth')

# Flush and release the TensorBoard event file.
writer.close()