import torch
import torch.nn as nn
from src.meta_learning.model import MLP
from src.verify.trainer import ForwardTrainer
from src.verify.trainer.utils import DatasetShell
from src.verify.trainer.callbacks import EarlyStopping, LogWriter
from collections import defaultdict
import numpy as np
import sys

from src.verify.trainer.metrics import F1
from sklearn.metrics import recall_score, precision_score
from src.dataloaders import SelectedDataset


def precision(output, target):
    """Macro-averaged precision of arg-max class predictions.

    Both arguments are detached and copied to host memory before scoring
    (they are expected to be torch tensors). The predicted class for each row
    is the index of the largest value along axis 1 of ``output`` — presumably
    ``output`` is (batch, n_classes) logits/scores; confirm with caller.

    Returns sklearn's ``precision_score(..., average='macro')``.
    """
    true_labels = target.detach().cpu().numpy()
    class_scores = output.detach().cpu().numpy()
    predicted_labels = np.argmax(class_scores, axis=1)
    return precision_score(true_labels, predicted_labels, average='macro')


def recall(output, target):
    """Macro-averaged recall of arg-max class predictions.

    Both arguments are detached and copied to host memory before scoring
    (they are expected to be torch tensors). The predicted class for each row
    is the index of the largest value along axis 1 of ``output`` — presumably
    ``output`` is (batch, n_classes) logits/scores; confirm with caller.

    Returns sklearn's ``recall_score(..., average='macro')``.
    """
    true_labels = target.detach().cpu().numpy()
    class_scores = output.detach().cpu().numpy()
    predicted_labels = np.argmax(class_scores, axis=1)
    return recall_score(true_labels, predicted_labels, average='macro')


def verification(train_dataset, valid_dataset, test_dataset, hidden_layers, batch_size=128, max_epoch=2000, verbose=0,
                 device=None, lr=3 * 1e-3, weight_decay=1e-4, patience=20):
    """Train one MLP classifier and return its evaluation metrics.

    An ``MLP(*hidden_layers)`` with LeakyReLU activations and no dropout is
    trained with cross-entropy loss and Adam on ``train_dataset``. Training is
    validated on ``valid_dataset`` with an ``EarlyStopping`` callback
    (``restore_best_weights=True``). Macro F1, precision and recall are
    tracked per epoch.

    Args:
        train_dataset: training data; its ``str()`` is used as the log name.
        valid_dataset: validation data used during training and for the
            ``val_*`` metrics.
        test_dataset: held-out data, or ``None`` to report the validation
            metrics in place of test metrics.
        hidden_layers: sequence unpacked positionally into ``MLP`` —
            presumably layer widths; confirm against ``MLP``'s signature.
        batch_size: batch size for training and evaluation.
        max_epoch: upper bound on training epochs.
        verbose: verbosity passed to ``LogWriter``.
        device: device passed to ``ForwardTrainer``.
        lr: Adam learning rate.
        weight_decay: Adam weight decay (default keeps the previous 1e-4).
        patience: early-stopping patience (default keeps the previous 20).

    Returns:
        The metric mapping from ``trainer.evaluate`` on the test set (or on
        the validation set when ``test_dataset`` is None). The same validation
        metrics are merged in under ``'val_' + key``.
    """
    net = MLP(*hidden_layers, act_class=nn.LeakyReLU, dropout=0.)
    trainer = ForwardTrainer(
        net,
        nn.CrossEntropyLoss(),
        optimizer=torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay),
        callbacks=[
            EarlyStopping(patience=patience, restore_best_weights=True),
            LogWriter(str(train_dataset), verbose=verbose),
        ],
        epoch_metrics=[F1(average='macro'), precision, recall],
        device=device,
    )

    trainer.train(max_epoch, train_dataset, valid_dataset, batch_size=batch_size)

    val_res = trainer.evaluate(valid_dataset, batch_size=batch_size)
    if test_dataset is None:
        # No separate test split: reuse the validation metrics rather than
        # evaluating the identical dataset a second time. Copy so the
        # 'val_' keys added below do not leak into ``val_res``.
        res = dict(val_res)
    else:
        res = trainer.evaluate(test_dataset, batch_size=batch_size)
    res.update({'val_' + k: v for k, v in val_res.items()})
    return res


def pack_results(result_list):
    """Aggregate a list of metric dicts into summary statistics.

    Values are grouped per key across all dicts (in list order). A key that is
    missing from some dicts is aggregated over only the dicts that contain it.
    For each key the result holds:

    * ``key``          -- ``np.mean`` of the values
    * ``key + '_std'`` -- ``np.std`` of the values (population std, ddof=0)
    * ``key + '_max'`` -- ``np.max`` of the values

    Output key order is: all means, then all ``_std`` entries, then all
    ``_max`` entries. An empty ``result_list`` yields ``{}``.
    """
    grouped = defaultdict(list)
    for result in result_list:
        for key, value in result.items():
            grouped[key].append(value)

    packed = {key: np.mean(values) for key, values in grouped.items()}
    packed.update({key + '_std': np.std(values) for key, values in grouped.items()})
    packed.update({key + '_max': np.max(values) for key, values in grouped.items()})
    return packed


def ntimes_verification(n_iter, *args, **kwargs):
    """Repeat ``verification`` ``n_iter`` times and aggregate the outcomes.

    Every run receives the same ``*args`` / ``**kwargs``. Each run's metric
    dict is printed as ``"<run>/<n_iter>: {...}"`` and stdout is flushed
    immediately, so progress shows up even when output is buffered or
    redirected. Returns ``pack_results`` over the list of per-run dicts.
    """
    collected = []
    for run_no in range(1, n_iter + 1):
        run_metrics = verification(*args, **kwargs)
        print(f'{run_no}/{n_iter}:', run_metrics, flush=True)
        collected.append(run_metrics)
    return pack_results(collected)


def test_subset(subset, n_iter, train_dataset, valid_dataset, test_dataset, hidden_layers, batch_size=128,
                max_epoch=2000, verbose=0, device=None, lr=3 * 1e-3):
    """Run ``ntimes_verification`` on datasets restricted to ``subset``.

    ``subset`` is sorted and turned into a numpy array. That array is passed
    as ``select_ids`` to ``SelectedDataset`` for the train, validation and
    (when not None) test datasets. Presumably the ids are feature/column
    indices; confirm against ``SelectedDataset``. The remaining arguments are
    forwarded by keyword to the underlying verification runs.

    Returns the aggregated metrics from ``ntimes_verification``.

    NOTE(review): the ``test_`` prefix matches pytest's default function
    collection pattern. If this module is ever named ``test_*.py``, pytest
    will try to run this function as a test.
    """
    select_ids = np.array(sorted(subset))

    def _select(dataset):
        # Restrict a dataset to the chosen ids.
        return SelectedDataset(dataset, select_ids=select_ids)

    sub_train = _select(train_dataset)
    sub_valid = _select(valid_dataset)
    sub_test = _select(test_dataset) if test_dataset is not None else None

    # Forward by keyword so a change to verification's positional order
    # cannot silently shift batch_size/max_epoch/verbose/device.
    return ntimes_verification(
        n_iter,
        sub_train,
        sub_valid,
        sub_test,
        hidden_layers=hidden_layers,
        batch_size=batch_size,
        max_epoch=max_epoch,
        verbose=verbose,
        device=device,
        lr=lr,
    )
