import os
import copy
import argparse
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
import random
from class_encoder.gcn import GCN
from class_encoder.gat import GAT
from class_encoder.rgcn import RGCN
from class_encoder.lstm import LSTM
from class_encoder.transformer import TransformerGCN
from class_encoder.dgp import DGP
from class_encoder.sgcn import SGCN
from class_encoder.gcnz import GCNZ

from torchvision.models import resnet101

DIR_PATH = os.path.dirname(os.path.realpath(__file__))

def mean(options):
    """Build the ``GCN`` label encoder and the path prefix for its outputs.

    Args:
        options: dict of run settings. This function reads ``'seed'``,
            ``'imagenet_graph_path'`` and ``'device'``; the whole dict is
            also forwarded to ``GCN``.

    Returns:
        Tuple ``(model, save_path)``. ``save_path`` is
        ``<DIR_PATH>/save/gcn_seed_<seed>`` and carries no file extension.
    """
    checkpoint_prefix = '{}_seed_{}'.format(
        os.path.join(DIR_PATH, 'save/gcn'), str(options['seed']))

    # collapse=False is forwarded unchanged; what it controls is decided by
    # utils.setup_graph.
    graph_adj, graph_feats = utils.setup_graph(
        options['imagenet_graph_path'], collapse=False)

    encoder = GCN(graph_feats, graph_adj, options['device'],
                  gcn=True, sample=True, options=options)
    return encoder, checkpoint_prefix

def gat(options):
    """Build the ``GAT`` label encoder and the path prefix for its outputs.

    Args:
        options: dict of run settings. This function reads ``'seed'``,
            ``'imagenet_graph_path'`` and ``'device'``; the whole dict is
            also forwarded to ``GAT``.

    Returns:
        Tuple ``(model, save_path)``. ``save_path`` is
        ``<DIR_PATH>/save/gat_seed_<seed>`` and carries no file extension.
    """
    checkpoint_prefix = '{}_seed_{}'.format(
        os.path.join(DIR_PATH, 'save/gat'), str(options['seed']))

    graph_adj, graph_feats = utils.setup_graph(
        options['imagenet_graph_path'], collapse=False)

    encoder = GAT(graph_feats, graph_adj, options['device'],
                  gcn=True, sample=True, options=options)
    return encoder, checkpoint_prefix

def rgcn(options):
    """Build the ``RGCN`` label encoder and the path prefix for its outputs.

    Args:
        options: dict of run settings. This function reads ``'seed'``,
            ``'imagenet_graph_path'`` and ``'device'``; the whole dict is
            also forwarded to ``RGCN``.

    Returns:
        Tuple ``(model, save_path)``. ``save_path`` is
        ``<DIR_PATH>/save/rgcn_seed_<seed>`` and carries no file extension.
    """
    checkpoint_prefix = '{}_seed_{}'.format(
        os.path.join(DIR_PATH, 'save/rgcn'), str(options['seed']))

    graph_adj, graph_feats = utils.setup_graph(
        options['imagenet_graph_path'], collapse=False)

    encoder = RGCN(graph_feats, graph_adj, options['device'],
                   gcn=True, sample=True, options=options)
    return encoder, checkpoint_prefix


def lstm(options):
    """Build the ``LSTM`` label encoder and the path prefix for its outputs.

    The seed is read from ``options`` (not from a module-level ``args``
    namespace) so the function works when this module is imported and stays
    consistent with the other encoder builders in this file.

    Args:
        options: dict of run settings. This function reads ``'seed'``,
            ``'imagenet_graph_path'`` and ``'device'``; the whole dict is
            also forwarded to ``LSTM``.

    Returns:
        Tuple ``(model, save_path)``. ``save_path`` is
        ``<DIR_PATH>/save/lstm_seed_<seed>`` and carries no file extension.
    """
    save_path = os.path.join(DIR_PATH, 'save/lstm')
    save_path += '_seed_' + str(options['seed'])
    print('loading from {}'.format(save_path))

    adj_lists, init_feats = utils.setup_graph(options['imagenet_graph_path'],
                                              collapse=False)
    # Unlike the other builders in this file, this one is constructed
    # with gcn=False.
    model = LSTM(init_feats, adj_lists, options['device'],
                 gcn=False, sample=True,
                 options=options)

    return model, save_path


def transformer(options):
    """Build the ``TransformerGCN`` label encoder and its output path prefix.

    The device is read from ``options['device']`` (not from a module-level
    ``device`` variable) so the function works when this module is imported
    and stays consistent with the other encoder builders in this file.

    Args:
        options: dict of run settings. This function reads ``'maxpool'``,
            ``'n1'``, ``'n2'``, ``'num_heads'``, ``'pd1'``, ``'pd2'``,
            ``'fh1'``, ``'fh2'``, ``'dp'``, ``'seed'``,
            ``'imagenet_graph_path'`` and ``'device'``; the whole dict is
            also forwarded to ``TransformerGCN``.

    Returns:
        Tuple ``(model, save_path)``. ``save_path`` encodes the
        hyper-parameters listed above so that runs with different settings
        do not overwrite each other; it carries no file extension.
    """
    save_path = os.path.join(DIR_PATH, 'save/transformer')
    save_path += '_gcn'

    if options['maxpool']:
        save_path += "_maxpool"

    # Append "_<tag>_<value>" for each hyper-parameter, in a fixed order so
    # the resulting file name is stable across runs.
    for tag, key in (('n1', 'n1'), ('n2', 'n2'),
                     ('heads', 'num_heads'),
                     ('pd1', 'pd1'), ('pd2', 'pd2'),
                     ('fh1', 'fh1'), ('fh2', 'fh2'),
                     ('dp', 'dp'), ('seed', 'seed')):
        save_path += '_' + tag + '_' + str(options[key])

    adj_lists, init_feats = utils.setup_graph(options['imagenet_graph_path'],
                                              collapse=False)

    model = TransformerGCN(init_feats, adj_lists, options['device'],
                           gcn=True, sample=True,
                           options=options)

    return model, save_path


def dgp(device, seed=None):
    """Build the ``DGP`` baseline encoder from ``data/dense_graph.json``.

    Args:
        device: accepted for signature parity with the other baseline
            builders; it is not used inside this function.
        seed: value embedded in the save path. When ``None`` (the default)
            the seed is read from the module-level ``options`` dict that the
            ``__main__`` block creates, which only exists when this file is
            run as a script.

    Returns:
        Tuple ``(model, save_path)``. ``save_path`` is
        ``<DIR_PATH>/save/dgp_seed_<seed>`` and carries no file extension.
    """
    # Use a context manager so the graph file is closed deterministically.
    with open(os.path.join(DIR_PATH, 'data/dense_graph.json'), 'r') as fp:
        graph = json.load(fp)
    n = len(graph['wnids'])

    edges_set = graph['edges_set']
    print('edges_set', [len(l) for l in edges_set])

    # Fold every edge list past index `lim` into the list at `lim`, then
    # drop the now-merged tail so at most lim + 1 edge lists remain.
    lim = 4
    for extra_edges in edges_set[lim + 1:]:
        edges_set[lim].extend(extra_edges)
    edges_set = edges_set[:lim + 1]
    print('edges_set', [len(l) for l in edges_set])

    # NOTE(review): 300 / 2049 look like the input and output feature sizes
    # (2049 matching the classifier weight+bias vectors) -- confirm
    # against DGP.
    hidden_layers = 'd2048,d'
    gcn = DGP(n, edges_set,
              300, 2049, hidden_layers)

    if seed is None:
        seed = options['seed']
    save_path = os.path.join(DIR_PATH, 'save/dgp_seed_{}'.format(seed))
    return gcn, save_path

def sgcn(device, seed=None):
    """Build the ``SGCN`` baseline encoder from ``data/induced_graph.json``.

    Args:
        device: forwarded to the ``SGCN`` constructor.
        seed: value embedded in the save path. When ``None`` (the default)
            the seed is read from the module-level ``options`` dict that the
            ``__main__`` block creates, which only exists when this file is
            run as a script.

    Returns:
        Tuple ``(model, save_path)``. ``save_path`` is
        ``<DIR_PATH>/save/sgcn_seed_<seed>`` and carries no file extension.
    """
    # Use a context manager so the graph file is closed deterministically.
    with open(os.path.join(DIR_PATH, 'data/induced_graph.json'), 'r') as fp:
        graph = json.load(fp)
    n = len(graph['wnids'])
    edges = graph['edges']

    # Add the reverse of every edge, then a self-loop for every node.
    edges = edges + [(v, u) for (u, v) in edges]
    edges = edges + [(u, u) for u in range(n)]

    # NOTE(review): 300 / 2049 look like the input and output feature sizes
    # -- confirm against SGCN.
    hidden_layers = 'd2048,d'
    gcn = SGCN(n, edges, 300, 2049, hidden_layers, device)

    if seed is None:
        seed = options['seed']
    save_path = os.path.join(DIR_PATH, 'save/sgcn_seed_{}'.format(seed))

    return gcn, save_path

def gcnz(device, seed=None):
    """Build the ``GCNZ`` baseline encoder from ``data/induced_graph.json``.

    Args:
        device: forwarded to the ``GCNZ`` constructor.
        seed: value embedded in the save path. When ``None`` (the default)
            the seed is read from the module-level ``options`` dict that the
            ``__main__`` block creates, which only exists when this file is
            run as a script.

    Returns:
        Tuple ``(model, save_path)``. ``save_path`` is
        ``<DIR_PATH>/save/resnet_gcnz_seed_<seed>`` and carries no file
        extension.
    """
    # Use a context manager so the graph file is closed deterministically.
    with open(os.path.join(DIR_PATH, 'data/induced_graph.json'), 'r') as fp:
        graph = json.load(fp)
    n = len(graph['wnids'])
    edges = graph['edges']

    # Add the reverse of every edge, then a self-loop for every node.
    edges = edges + [(v, u) for (u, v) in edges]
    edges = edges + [(u, u) for u in range(n)]

    # NOTE(review): 300 / 2049 look like the input and output feature sizes
    # -- confirm against GCNZ.
    hidden_layers = 'd2048,d'
    gcn = GCNZ(n, edges, 300, 2049, hidden_layers, device)

    if seed is None:
        seed = options['seed']
    save_path = os.path.join(DIR_PATH, 'save/resnet_gcnz_seed_{}'.format(seed))

    return gcn, save_path


def predict(model, graph_path, dataset, device=None):
    """Run a trained graph encoder on a different graph and return its outputs.

    The model's adjacency lists and input feature table are swapped in place
    for the ones loaded from ``graph_path``; the model is then put in eval
    mode and queried for every concept listed in ``mapping.json``.

    Args:
        model: encoder exposing ``enc1``, ``enc2`` and ``agg1`` attributes
            (``enc1``/``enc2`` with ``adj_lists``, ``enc1``/``agg1`` with
            ``features``). It is mutated by this call.
        graph_path: directory passed to ``utils.setup_graph`` and containing
            ``mapping.json``.
        dataset: name used only in the progress message.
        device: where the index tensor and feature table are placed. When
            ``None`` (the default) it is read from the module-level
            ``options`` dict that the ``__main__`` block creates.

    Returns:
        Whatever ``model(concept_idx)`` returns. No ``torch.no_grad()`` is
        applied here, so callers that do not need gradients should wrap the
        call themselves.
    """
    if device is None:
        device = options['device']

    print('generating graph embeddings for {}'.format(dataset))

    adj_lists, init_feats = utils.setup_graph(graph_path, collapse=False)

    # Use a context manager so the mapping file is closed deterministically.
    with open(os.path.join(graph_path, 'mapping.json')) as fp:
        mapping = json.load(fp)
    # mapping is keyed by the string form of 0..len(mapping)-1.
    concept_idx = torch.tensor(
        [mapping[str(idx)] for idx in range(len(mapping))]).to(device)

    # Replace the adjacency lists and the (frozen) initial feature table so
    # the model aggregates over the new graph.
    init_feat = nn.Embedding.from_pretrained(init_feats, freeze=True).to(device)
    model.enc2.adj_lists = adj_lists
    model.enc1.adj_lists = adj_lists

    model.agg1.features = init_feat
    model.enc1.features = init_feat

    model.eval()
    pred_vectors = model(concept_idx)

    return pred_vectors


def get_label_encoder(label_encoder_type, options):
    """Dispatch to the builder that matches ``label_encoder_type``.

    Args:
        label_encoder_type: one of ``'gcn'``, ``'gat'``, ``'rgcn'``,
            ``'lstm'``, ``'transformer'``, ``'sgcn'``, ``'gcnz'``, ``'dgp'``.
        options: dict of run settings. Graph-sampling encoders receive the
            whole dict; the ``sgcn``/``gcnz``/``dgp`` builders receive only
            ``options['device']``.

    Returns:
        The ``(model, save_path)`` tuple produced by the selected builder,
        or ``None`` when ``label_encoder_type`` matches none of the names.
    """
    # Each entry defers its call (and any lookup into `options`) until the
    # matching name is actually selected.
    builders = (
        ('gcn', lambda: mean(options)),
        ('gat', lambda: gat(options)),
        ('rgcn', lambda: rgcn(options)),
        ('lstm', lambda: lstm(options)),
        ('transformer', lambda: transformer(options)),
        ('sgcn', lambda: sgcn(options['device'])),
        ('gcnz', lambda: gcnz(options['device'])),
        ('dgp', lambda: dgp(options['device'])),
    )
    for encoder_name, build in builders:
        if label_encoder_type == encoder_name:
            return build()
    return None


def get_apy_preds(pred_obj):
    """Select the prediction vectors for the wnids in ``apy_wnid.json``.

    Args:
        pred_obj: dict with ``'wnids'`` (sequence of ids) and ``'pred'``
            (tensor whose rows are paired with ``'wnids'`` by position).

    Returns:
        The result of ``utils.pick_vectors`` for 20 ``'0'`` entries followed
        by the ids under ``'wnid'`` in ``materials/apy_wnid.json``.
    """
    # Predictions are moved to CPU before being paired with their wnids.
    wnid_to_vector = dict(zip(pred_obj['wnids'], pred_obj['pred'].cpu()))

    with open(os.path.join(DIR_PATH, 'materials/apy_wnid.json')) as fp:
        test_wnids = json.load(fp)['wnid']

    # NOTE(review): the 20 '0' ids look like placeholders for training
    # classes; how pick_vectors treats an id missing from the dict should be
    # confirmed against utils.pick_vectors.
    train_wnids = ['0'] * 20

    return utils.pick_vectors(wnid_to_vector, train_wnids + test_wnids,
                              is_tensor=True)

def get_awa_preds(pred_obj):
    """Select the prediction vectors for the AwA2 train and test wnids.

    Args:
        pred_obj: dict with ``'wnids'`` (sequence of ids) and ``'pred'``
            (tensor whose rows are paired with ``'wnids'`` by position).

    Returns:
        The result of ``utils.pick_vectors`` for the ``'train'`` ids followed
        by the ``'test'`` ids from ``materials/awa2-split.json``.
    """
    # Use a context manager so the split file is closed deterministically.
    with open(os.path.join(DIR_PATH, 'materials/awa2-split.json'), 'r') as fp:
        awa2_split = json.load(fp)
    train_wnids = awa2_split['train']
    test_wnids = awa2_split['test']

    pred_wnids = pred_obj['wnids']
    # Predictions are moved to CPU before being paired with their wnids.
    pred_vectors = pred_obj['pred'].cpu()
    pred_dic = dict(zip(pred_wnids, pred_vectors))

    pred_vectors = utils.pick_vectors(pred_dic, train_wnids + test_wnids, is_tensor=True)
    return pred_vectors


def train_baseline_model(model, fc_vectors, device, max_epoch=None):
    """Train a full-graph baseline encoder against ``fc_vectors``.

    Each epoch runs one optimisation step on the training split (the whole
    word-vector matrix is fed through the model at once), then evaluates on
    the training and validation splits. The weights with the lowest
    validation loss are restored before returning, and the returned
    predictions are computed from those restored weights so that the model
    and ``pred_obj`` always agree.

    Args:
        model: ``nn.Module`` that is called with the normalised word-vector
            matrix loaded from ``data/dense_graph.json``.
        fc_vectors: regression targets; its length fixes the size of the
            train/validation index pool.
        device: device the model and word vectors are moved to.
        max_epoch: number of epochs. When ``None`` (the default) it is read
            from the module-level ``args`` namespace that the ``__main__``
            block creates.

    Returns:
        Tuple ``(model, pred_obj, trlog)``. ``pred_obj`` is a dict with
        ``'wnids'`` and ``'pred'``; ``trlog`` holds the per-epoch
        ``'train_loss'`` and ``'val_loss'`` lists.
    """
    if max_epoch is None:
        max_epoch = args.max_epoch

    model.to(device)

    # Use a context manager so the graph file is closed deterministically.
    with open(os.path.join(DIR_PATH, 'data/dense_graph.json'), 'r') as fp:
        graph = json.load(fp)
    wnids = graph['wnids']
    word_vectors = torch.tensor(graph['vectors']).to(device)
    word_vectors = F.normalize(word_vectors)

    print('word vectors:', word_vectors.shape)
    print('fc vectors:', fc_vectors.shape)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001,
                                 weight_decay=0.0005)

    # 95% / 5% split of the target indices, shuffled once up front.
    v_train, v_val = 0.95, 0.05
    n_trainval = len(fc_vectors)
    n_train = round(n_trainval * (v_train / (v_train + v_val)))
    print('num train: {}, num val: {}'.format(n_train, n_trainval - n_train))

    tlist = list(range(n_trainval))
    random.shuffle(tlist)
    train_ids, val_ids = tlist[:n_train], tlist[n_train:]

    # 'min_loss' is never updated in this function; the key is kept so the
    # saved log keeps the same set of keys.
    trlog = {'train_loss': [], 'val_loss': [], 'min_loss': 0}
    best_model = None

    for epoch in range(1, max_epoch + 1):
        model.train()
        output_vectors = model(word_vectors)
        loss = utils.mask_l2_loss(output_vectors, fc_vectors, train_ids)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Evaluation needs no gradients; skip building the autograd graph.
        model.eval()
        with torch.no_grad():
            output_vectors = model(word_vectors)
            train_loss = utils.mask_l2_loss(output_vectors, fc_vectors, train_ids).item()
            if v_val > 0:
                val_loss = utils.mask_l2_loss(output_vectors, fc_vectors, val_ids).item()
            else:
                val_loss = 0
        print('epoch {}, train_loss={:.4f}, val_loss={:.4f}'
              .format(epoch, train_loss, val_loss))

        # Snapshot on the first epoch and whenever validation loss strictly
        # improves on every earlier epoch.
        if not trlog['val_loss'] or val_loss < min(trlog['val_loss']):
            best_model = copy.deepcopy(model.state_dict())

        trlog['train_loss'].append(train_loss)
        trlog['val_loss'].append(val_loss)

    # best_model stays None when no epoch ran; keep the current weights then.
    if best_model is not None:
        model.load_state_dict(best_model)

    # Recompute the predictions from the restored weights; the last epoch's
    # outputs may come from a worse checkpoint than the one being returned.
    model.eval()
    with torch.no_grad():
        pred_obj = {
            'wnids': wnids,
            'pred': model(word_vectors)
        }
    return model, pred_obj, trlog


def train_gnn_model(model, fc_vectors, device, options):
    """Train a graph-sampling label encoder against ``fc_vectors``.

    Training runs in mini-batches of 100 target indices per step. After each
    epoch the model is evaluated on all targets (again in chunks of 100,
    without gradients) and the weights with the lowest validation loss are
    restored before returning.

    Args:
        model: ``nn.Module`` called with a tensor of graph node ids.
        fc_vectors: regression targets, indexed by class position; its second
            dimension fixes the width of the evaluation buffer.
        device: device the model, index tensor and buffers live on.
        options: dict of run settings. This function reads
            ``'imagenet_graph_path'`` (directory holding ``mapping.json``)
            and ``'num_epochs'``.

    Returns:
        Tuple ``(model, trlog)``; ``trlog`` holds the per-epoch
        ``'train_loss'`` and ``'val_loss'`` lists.
    """
    # Use a context manager so the mapping file is closed deterministically.
    mapping_path = os.path.join(options['imagenet_graph_path'], 'mapping.json')
    with open(mapping_path) as fp:
        mapping = json.load(fp)
    # 1000 because we are training on imagenet 1000
    imagenet_idx = torch.tensor([mapping[str(idx)] for idx in range(1000)]).to(device)

    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001,
                                 weight_decay=0.0005)

    # 95% / 5% split of the target indices, shuffled once up front.
    v_train, v_val = 0.95, 0.05
    n_trainval = len(fc_vectors)
    n_train = round(n_trainval * (v_train / (v_train + v_val)))
    print('num train: {}, num val: {}'.format(n_train, n_trainval - n_train))

    tlist = list(range(n_trainval))
    random.shuffle(tlist)
    train_ids, val_ids = tlist[:n_train], tlist[n_train:]

    # 'min_loss' is never updated in this function; the key is kept so the
    # saved log keeps the same set of keys.
    trlog = {'train_loss': [], 'val_loss': [], 'min_loss': 0}
    # Size the evaluation buffer from the targets instead of hard-coding it.
    num_w, out_dim = fc_vectors.shape[0], fc_vectors.shape[1]
    batch_size = 100
    best_model = None

    for epoch in range(1, options['num_epochs'] + 1):
        model.train()
        for start in range(0, n_train, batch_size):
            indices = train_ids[start:start + batch_size]
            output_vectors = model(imagenet_idx[indices])
            loss = utils.l2_loss(output_vectors, fc_vectors[indices])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        output_vectors = torch.empty(num_w, out_dim, device=device)
        with torch.no_grad():
            for start in range(0, num_w, batch_size):
                end = min(start + batch_size, num_w)
                output_vectors[start: end] = model(imagenet_idx[start: end])

        train_loss = utils.mask_l2_loss(output_vectors, fc_vectors, train_ids).item()
        if v_val > 0:
            val_loss = utils.mask_l2_loss(output_vectors, fc_vectors, val_ids).item()
        else:
            val_loss = 0

        print('epoch {}, train_loss={:.4f}, val_loss={:.4f}'
            .format(epoch, train_loss, val_loss))

        # Snapshot on the first epoch and whenever validation loss strictly
        # improves on every earlier epoch.
        if not trlog['val_loss'] or val_loss < min(trlog['val_loss']):
            best_model = copy.deepcopy(model.state_dict())

        trlog['train_loss'].append(train_loss)
        trlog['val_loss'].append(val_loss)

    # best_model stays None when no epoch ran; keep the current weights then.
    if best_model is not None:
        model.load_state_dict(best_model)
    return model, trlog


def get_fc():
    """Return the normalised final-layer parameters of a pretrained ResNet-101.

    The ``fc`` layer's bias is appended as one extra column to its weight
    matrix, and the result is passed through ``F.normalize`` with its
    default arguments.

    Returns:
        Tensor with one row per ``fc`` output unit and ``in_features + 1``
        columns.
    """
    classifier = resnet101(pretrained=True).fc
    with torch.no_grad():
        bias = classifier.bias.detach()
        weight = classifier.weight.detach()
        # bias[:, None] turns the 1-D bias into a column so it can be
        # concatenated along dim=1.
        weight_and_bias = torch.cat((weight, bias[:, None]), dim=1)
    return F.normalize(weight_and_bias)

if __name__ == "__main__":
    # Script entry point: parse the CLI, build the requested label encoder,
    # train it against the ResNet-101 classifier vectors and write the
    # training log, weights and AwA/aPY predictions next to `save_path`.
    #
    # NOTE: `args`, `device` and `options` are intentionally module-level
    # names; several functions defined above read them as globals.
    parser = argparse.ArgumentParser()
    parser.add_argument('--label_encoder', help='label encoder')
    parser.add_argument('--max-epoch', default=1000, type=int)
    # --gpu is parsed but not read anywhere in this block.
    parser.add_argument('--gpu', default='0')
    parser.add_argument('--n1', type=int, default=50)
    parser.add_argument('--n2', type=int, default=100)
    parser.add_argument('--gcn', action='store_true')
    parser.add_argument('--no-of-heads', type=int, default=1)
    parser.add_argument('--num_layers', type=int, default=1)
    parser.add_argument('--pd1', type=int, default=150)
    parser.add_argument('--pd2', type=int, default=1024)
    parser.add_argument('--hd1', type=int, default=300)
    parser.add_argument('--hd2', type=int, default=2048)
    parser.add_argument('--fh1', type=int, default=150)
    parser.add_argument('--fh2', type=int, default=1024)
    parser.add_argument('--maxpool', action='store_true')
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--seed', type=int, default=0)

    args = parser.parse_args()

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('device : ', device)

    # Seed before any model is constructed.
    utils.set_seed(int(args.seed))

    imagenet_graph_path = os.path.join(DIR_PATH, 'data/subgraphs/imagenet_graph')
    apy_graph = os.path.join(DIR_PATH, 'data/subgraphs/apy_graph')
    awa2_graph = os.path.join(DIR_PATH, 'data/subgraphs/awa2_graph')

    options = {
        'label_encoder_type': args.label_encoder,
        'num_epochs': args.max_epoch,
        'device': device,
    }
    # Hyper-parameters whose option key equals the argparse attribute name.
    for key in ('n1', 'n2', 'pd1', 'pd2', 'hd1', 'hd2', 'fh1', 'fh2'):
        options[key] = getattr(args, key)
    options.update({
        'num_heads': args.no_of_heads,
        'num_layers': args.num_layers,
        'maxpool': args.maxpool,
        'dp': args.dropout,
        'gcn': args.gcn,
        'seed': args.seed,
        'imagenet_graph_path': imagenet_graph_path,
        'apy_graph_path': apy_graph,
        'awa2_graph_path': awa2_graph,
    })

    model, save_path = get_label_encoder(options['label_encoder_type'], options)
    model = model.to(device)

    fc_vectors = get_fc().to(device)

    all_preds = {}
    if args.label_encoder in ('sgcn', 'dgp', 'gcnz'):
        # Full-graph baselines: predictions for every wnid come out of
        # training and are then sliced per dataset.
        model, pred_obj, tr_log = train_baseline_model(model, fc_vectors, device)
        with torch.no_grad():
            all_preds['awa'] = get_awa_preds(pred_obj)
            all_preds['apy'] = get_apy_preds(pred_obj)
    else:
        # Graph-sampling encoders: train on the ImageNet graph, then run the
        # trained model on each dataset's own graph.
        model, tr_log = train_gnn_model(model, fc_vectors, device, options)
        with torch.no_grad():
            for dataset, dataset_graph in (('awa', awa2_graph), ('apy', apy_graph)):
                all_preds[dataset] = predict(model, dataset_graph, dataset)

    # All three outputs are written with torch.save, so despite its suffix
    # the "_loss.json" file is a torch-serialized object, not JSON text.
    torch.save(tr_log, save_path + "_loss.json")
    torch.save(model.state_dict(), save_path + '.pt')
    torch.save(all_preds, save_path + '.pred')


