import argparse
import pickle
from torch.utils.data import TensorDataset, DataLoader
import torch
import torch.nn as nn

import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))

from deepnorms import *
from models import FICNN, BregmanDivergence, GSBregmanDivergence
from models import MaxAffineNet, MaxAffineDivergence
from utils.logging import TrainLogger

# Device that models and batches are moved to. Prefer the GPU, but fall back
# to the CPU so the script still runs on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Distances are divided by this constant in load_dataset and multiplied back
# in eval_metric so that the reported error is on the original scale.
RENORM_SCALE = 50


# Command-line interface.
# NOTE: arguments are parsed at module import time, so importing this file
# from another program will parse that program's sys.argv.
parser = argparse.ArgumentParser(
    description='Train a metric model on a graph-distance dataset.')
parser.add_argument(
    '--dataset',
    help='dataset name; selects ./graph_distance/data/<dataset>_150k.pickle')
parser.add_argument(
    '--metric',
    help='metric head to train: mahalanobis, widenorm, deepnorm, maxaffine, '
         'bregman, sqrtbreg or gsb')
parser.add_argument(
    '--train_size', type=int, default=50000,
    help='number of training examples to sample (default: %(default)s)')
parser.add_argument(
    '--bs', type=int, default=512,
    help='mini-batch size (default: %(default)s)')
parser.add_argument(
    '--seed', type=int, default=0,
    help='NumPy seed used when sampling the training subset '
         '(default: %(default)s)')
args = parser.parse_args()


def convert_XYints_to_XYembs(X, Y, emb_dict):
    """Replace every key in ``X`` and ``Y`` by its entry in ``emb_dict``.

    Args:
        X: iterable of keys into ``emb_dict``.
        Y: iterable of keys into ``emb_dict``.
        emb_dict: mapping from key to embedding.

    Returns:
        A tuple ``(X_emb, Y_emb)`` of NumPy arrays holding the looked-up
        embeddings in the same order as the inputs.
    """
    def lookup(keys):
        # One array per input, rows ordered like the incoming keys.
        return np.array([emb_dict[key] for key in keys])

    return lookup(X), lookup(Y)


def load_dataset(data_pickle='sf_1M.pickle',
                 emb_dict_pickle='sf_emb_dict.pickle',
                 train_size=50000,
                 test_size=30000,
                 seed=0):
    """Load pickled pairs with distances and split them into train/test sets.

    ``data_pickle`` must unpickle to a triple ``(X, Y, D)``; the entries of
    ``X`` and ``Y`` are mapped to embeddings through the dictionary stored in
    ``emb_dict_pickle``. ``D`` is indexed with an index array and has
    ``.mean()``/``.var()`` called on it, so it is expected to be a NumPy
    array. ``X``, ``Y`` and ``D`` are assumed to be aligned and of equal
    length.

    The last ``test_size`` examples form the test set. ``train_size`` training
    examples are sampled without replacement from the remaining examples, so
    the two sets never overlap. Distances are divided by ``RENORM_SCALE``.

    Args:
        data_pickle: path to the pickled ``(X, Y, D)`` triple.
        emb_dict_pickle: path to the pickled embedding dictionary.
        train_size: number of training examples to sample.
        test_size: number of trailing examples held out for testing.
        seed: seed passed to ``np.random.seed`` before sampling (this reseeds
            NumPy's global random state).

    Returns:
        ``(train_dataset, test_dataset)``, each a ``TensorDataset`` of float
        tensors ``(x, y, d)``.

    Raises:
        AssertionError: if ``train_size + test_size`` exceeds the number of
            examples in the data file.
    """
    # SECURITY: pickle.load can execute arbitrary code contained in the file.
    # Only point these paths at trusted data.
    with open(data_pickle, 'rb') as f:
        X, Y, D = pickle.load(f)
    with open(emb_dict_pickle, 'rb') as f:
        embs = pickle.load(f)
    X, Y = convert_XYints_to_XYembs(X, Y, embs)

    num_examples = len(X)
    assert train_size + test_size <= num_examples, (
        f'train_size + test_size = {train_size + test_size} exceeds the '
        f'{num_examples} available examples')

    # Index of the first held-out example. Deriving it from the data length
    # (rather than a fixed dataset size) keeps the training pool disjoint
    # from the test set for any file, and an explicit start index also makes
    # test_size == 0 yield an empty test set (X[-0:] would be everything).
    test_start = num_examples - test_size

    np.random.seed(seed)
    training_indices = np.random.choice(test_start, train_size, replace=False)

    xtr = X[training_indices]
    ytr = Y[training_indices]
    dtr = D[training_indices]

    xte = X[test_start:]
    yte = Y[test_start:]
    dte = D[test_start:]

    print('DATASET STATS: ')
    print(dte.mean(), dte.var(), dte.min(), dte.max())
    # Rescale the regression targets; eval_metric multiplies back by the
    # same constant before scoring.
    dtr = dtr / RENORM_SCALE
    dte = dte / RENORM_SCALE
    print('renormalized: ')
    print(dte.mean(), dte.var(), dte.min(), dte.max())

    xtr, ytr, dtr, xte, yte, dte = (
        torch.tensor(arr).float()
        for arr in (xtr, ytr, dtr, xte, yte, dte)
    )

    return TensorDataset(xtr, ytr, dtr), TensorDataset(xte, yte, dte)


class MetricModel(nn.Module):
    """Shared MLP encoder followed by a pairwise metric head.

    Both inputs go through the same ``embedding`` network; ``metric`` is then
    called on the two embedded batches.

    Args:
        metric: callable taking the two embedded batches and returning the
            model output.
        input_dim: width of the raw input features.
        pre_embed: number of ``nn.Linear`` layers in the encoder; consecutive
            layers are separated by ``nn.ReLU``. All layers output 128
            features.
    """

    def __init__(self, metric, input_dim=160, pre_embed=4):
        super().__init__()
        self.metric = metric

        hidden = 128
        layers = [nn.Linear(input_dim, hidden)]
        # Every further layer is preceded by a ReLU, so the encoder ends on a
        # plain linear layer.
        for _ in range(1, pre_embed):
            layers += [nn.ReLU(), nn.Linear(hidden, hidden)]
        self.embedding = nn.Sequential(*layers)

    def forward(self, x, y):
        """Embed ``x`` and ``y`` with the shared encoder and apply the metric."""
        return self.metric(self.embedding(x), self.embedding(y))


def train_metric(metric, train_loader, optimizer):
    """Train ``metric`` for one pass over ``train_loader`` with MSE loss.

    Args:
        metric: model called as ``metric(x1, x2)``; it is put in training
            mode.
        train_loader: iterable yielding ``(x1, x2, target)`` batches.
        optimizer: optimizer that updates ``metric``'s parameters; it is
            stepped once per batch.
    """
    metric.train()
    loss_fn = nn.MSELoss()

    for x1, x2, target in train_loader:
        x1, x2, target = x1.to(DEVICE), x2.to(DEVICE), target.to(DEVICE)
        out = metric(x1, x2)
        loss = loss_fn(out, target)

        # Drive the optimizer that was passed in instead of depending on a
        # module-level name, so the function works for any caller.
        optimizer.zero_grad()
        loss.backward()

        optimizer.step()



def eval_metric(metric, data_loader):
    """Return the mean squared error of ``metric`` over ``data_loader``.

    The model is switched to eval mode, predictions and targets from all
    batches are concatenated, both are multiplied by ``RENORM_SCALE``, and the
    mean squared difference is returned as a Python float.

    Args:
        metric: model called as ``metric(x1, x2)``.
        data_loader: iterable yielding ``(x1, x2, target)`` batches.
    """
    metric.eval()

    batch_preds, batch_targets = [], []
    # NOTE(review): gradient tracking stays enabled here and outputs are only
    # detached afterwards; confirm no metric head needs autograd in its
    # forward pass before wrapping this loop in torch.no_grad().
    for batch in data_loader:
        left, right, target = (tensor.to(DEVICE) for tensor in batch)
        batch_preds.append(metric(left, right).detach())
        batch_targets.append(target)

    all_preds = torch.cat(batch_preds)
    all_targets = torch.cat(batch_targets)

    # Score on the original distance scale rather than the normalized one.
    all_preds = all_preds * RENORM_SCALE
    all_targets = all_targets * RENORM_SCALE

    squared_error = (all_preds - all_targets).pow(2)
    return squared_error.mean().item()


if __name__ == '__main__':
    # Script entry point: build the requested metric model, train it for 100
    # epochs and log train/test MSE after every epoch.

    # Zero-argument factories for every supported --metric head. The lambdas
    # keep construction lazy, so only the selected head is instantiated.
    metric_builders = {
        'mahalanobis': lambda: MahalanobisMetric(128, 128),
        'widenorm': lambda: WideNormMetric(128, 32, 32, 5, mode='maxavg'),
        'deepnorm': lambda: DeepNormMetric(
            128, [128, 128, 128],
            activation=lambda: MaxReLUPairwiseActivation(128),
            concave_activation_size=5,
            mode='maxavg'),
        'maxaffine': lambda: MaxAffineDivergence(MaxAffineNet(128, K=50)),
        'bregman': lambda: BregmanDivergence(
            FICNN(128, 1, 2, 128, activation='softplus')),
        'sqrtbreg': lambda: BregmanDivergence(
            FICNN(128, 1, 2, 128, activation='softplus'), take_sqrt=True),
        'gsb': lambda: GSBregmanDivergence(
            FICNN(128, 1, 2, 128, activation='softplus')),
    }

    # Fail fast on a missing or misspelled metric name, before any results
    # directory is touched or data is loaded.
    if args.metric not in metric_builders:
        parser.error(
            f'unknown --metric {args.metric!r}; choose from: '
            + ', '.join(metric_builders)
        )

    log = TrainLogger(
        log_dir='./results/graph_dist',
        name=f'graph_{args.dataset}_metric_{args.metric}_{args.train_size}_{args.seed}'
    )

    train_ds, test_ds = load_dataset(
        f'./graph_distance/data/{args.dataset}_150k.pickle',
        f'./graph_distance/data/{args.dataset}_lm_32n0.2-96_emb_dict.pickle',
        train_size=args.train_size,
        test_size=30000,
        seed=args.seed
    )

    train_loader = DataLoader(train_ds, batch_size=args.bs, num_workers=4, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=args.bs, num_workers=4, shuffle=False)

    # These three datasets use 160-dimensional inputs; all others use 128.
    input_dim = 160 if args.dataset in ('3dd', 'octagon', 'traffic') else 128

    metric = MetricModel(
        metric_builders[args.metric](),
        input_dim=input_dim
    ).to(DEVICE)

    # NOTE(review): args.seed is only passed to load_dataset; no
    # torch.manual_seed call is visible in this script, so weight
    # initialisation and batch shuffling are presumably not seeded.
    opt = torch.optim.Adam(metric.parameters(), lr=5e-4)

    for epoch in range(100):

        # Single step decay: divide the learning rate by 10 at epoch 50.
        if epoch == 50:
            opt.param_groups[0]['lr'] /= 10

        train_metric(metric, train_loader, opt)

        tr_loss = eval_metric(metric, train_loader)
        te_loss = eval_metric(metric, test_loader)

        if epoch % 5 == 0:
            print('epoch ', epoch)
            print(tr_loss, te_loss)

        log.add('epoch', epoch)
        log.add('train_loss', tr_loss)
        log.add('test_loss', te_loss)
        log.export()
