from functools import partial

import numpy as np
import torch

# import read_data
from gnn.Datasets import GraphDatasets
from networks import AutoVisualGAT


def evaluate(d_loader):
    """Compute aggregate regression metrics for the global ``net`` on a loader.

    Relies on the module-level names ``net`` and ``device`` being defined
    before the call (they are created in the ``__main__`` section).  The
    model is switched to evaluation mode for the duration of the pass and
    its previous train/eval mode is restored afterwards.

    Args:
        d_loader: Iterable yielding ``(cdist, n_view_graphs, z)`` batches.
            ``z`` is a sequence of NumPy arrays that is concatenated into
            the target tensor; ``cdist`` and ``n_view_graphs`` are passed
            to ``net`` unchanged.

    Returns:
        A ``(mse, mae, errors)`` tuple of scalar tensors, each the mean over
        every sample seen:

        * ``mse``: per-sample sum of squared differences along dim 1.
        * ``mae``: per-sample sum of absolute differences along dim 1.
        * ``errors``: ``1 - ||z_hat - z||_1 / ||z||_1`` per sample, i.e. one
          minus the relative L1 error (higher is better).
          NOTE(review): a target row with zero L1 norm yields inf/nan here.

    Raises:
        RuntimeError: If ``d_loader`` yields no batches (``torch.cat`` on an
            empty list).
    """
    # Evaluation must not run with dropout / batch-norm in training mode;
    # remember the current mode so the caller's training loop is unaffected.
    was_training = net.training
    net.eval()

    sq_errs = []
    abs_errs = []
    rel_scores = []
    try:
        with torch.no_grad():
            for cdist, n_view_graphs, z in d_loader:
                z = torch.from_numpy(np.concatenate(z)).to(torch.float32).to(device)
                z_hat = net(cdist, n_view_graphs)

                diff = z_hat - z
                l1_diff = torch.norm(diff, p=1, dim=1)
                sq_errs.append(torch.sum(diff ** 2, dim=1))
                abs_errs.append(torch.sum(torch.abs(diff), dim=1))
                rel_scores.append(1 - l1_diff / torch.norm(z, p=1, dim=1))
    finally:
        net.train(was_training)

    mse = torch.mean(torch.cat(sq_errs))
    mae = torch.mean(torch.cat(abs_errs))
    errors = torch.mean(torch.cat(rel_scores))
    return mse, mae, errors


if __name__ == '__main__':
    # Script entry point: build the train/test graph datasets, train
    # AutoVisualNet with an MSE objective, log metrics every 50 epochs, and
    # save the model together with the metric history.
    import os

    # Fall back to CPU so the script still runs on machines without CUDA.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Alternative benchmark split kept for reference:
    # train_names = ['arrhythmia', 'wine', 'lympho', 'glass', 'vertebral',
    #                'wbc', 'ecoli', 'ionosphere', 'breastw', 'pima',
    #                'vowels', 'letter']
    # test_names = ['cardio', 'seismic', 'musk', 'speech', 'abalone']

    train_names = ['mnist_group2']
    test_names = ['mnist_group1']

    train_ds = GraphDatasets.DatasetGraphDataset(data_names=train_names, cdist_path='./prepare_data/clip/features',
                                                 visual_path='./prepare_data/bo/res-2', normalize_z=True,
                                                 z_mu=None, z_std=None)

    # The test set reuses the z_mu / z_std attributes of the training dataset.
    test_ds = GraphDatasets.DatasetGraphDataset(data_names=test_names, cdist_path='./prepare_data/clip/features',
                                                visual_path='./prepare_data/bo/res-2', normalize_z=True,
                                                z_mu=train_ds.z_mu, z_std=train_ds.z_std)

    # collate_fn transposes the list of samples into one tuple per field
    # instead of stacking them into tensors.
    get_loader = partial(torch.utils.data.DataLoader, batch_size=1,
                         shuffle=True,
                         num_workers=0,
                         collate_fn=lambda x: list(zip(*x)))

    train_loader = get_loader(train_ds)
    test_loader = get_loader(test_ds)

    # Input width is read from the first field of the first training sample.
    input_dim = train_ds[0][0].shape[0]

    net = AutoVisualGAT.AutoVisualNet(input_dim=input_dim, gnn_hidden=512, gnn_out=64, out_dim=2,
                                      n_graph_view=5, device=device)
    net = net.to(device)

    n_epoch = 5000

    # Create the output directory up front so a missing folder cannot
    # discard the results after the whole training run has finished.
    save_path = f'./res/GNN-gin-run4-{n_epoch}epoch.tar'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-3, amsgrad=True)
    optimizer = torch.optim.AdamW(params=net.parameters(), lr=1e-5)
    # Stepped once per epoch, so the cosine schedule spans the full run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epoch)
    loss_fn = torch.nn.MSELoss()

    records = {}
    for epoch in range(n_epoch):
        running_loss = 0.0
        for cdist, n_view_graphs, z in train_loader:
            optimizer.zero_grad()
            z = torch.from_numpy(np.concatenate(z)).to(torch.float32).to(device)

            z_hat = net(cdist, n_view_graphs)
            loss = loss_fn(z_hat, z)
            loss.backward()
            optimizer.step()
            # Accumulate a plain float: summing the loss tensor itself would
            # keep autograd history alive across the whole epoch.
            running_loss += loss.item()

        scheduler.step()

        if (epoch + 1) % 50 == 0:
            train_mse, train_mae, train_error = evaluate(train_loader)
            test_mse, test_mae, test_error = evaluate(test_loader)

            latest = {
                'train_mse': train_mse.item(),
                'train_mae': train_mae.item(),
                'train_error': train_error.item(),
                'test_mse': test_mse.item(),
                'test_mae': test_mae.item(),
                'test_error': test_error.item(),
                'running_loss': running_loss / len(train_loader),
            }
            for key, value in latest.items():
                records.setdefault(key, []).append(value)

            print([(k, v[-1]) for k, v in records.items()])

    # The whole module object is pickled, so loading this file later needs
    # the AutoVisualGAT class definitions to be importable.
    torch.save((net, records), save_path)
