import os

import numpy as np
import torch
from torch import optim

from . import data, utils, models


def update(model, device, loader, loss_func, optimizer):
    """Run one optimisation pass over every batch produced by ``loader``.

    The model is switched to training mode. Each input batch is moved to
    ``device`` and flattened to ``(batch, -1)`` before the forward pass; the
    targets are moved to ``device`` unchanged. One optimizer step is taken
    per batch.

    Args:
        model: Module exposing ``train()`` and ``forward()``.
        device: Device the batches are moved to.
        loader: Iterable yielding ``(inputs, targets)`` pairs.
        loss_func: Callable taking ``(predictions, targets)`` and returning a
            tensor that supports ``backward()``.
        optimizer: Optimizer holding the model parameters.
    """
    model.train()
    for inputs, targets in loader:
        inputs = inputs.to(device)
        flat_inputs = inputs.view(inputs.size(0), -1)
        targets = targets.to(device)

        # Clear gradients left over from the previous batch before the
        # forward/backward pass.
        optimizer.zero_grad()
        predictions = model.forward(flat_inputs)
        batch_loss = loss_func(predictions, targets)
        batch_loss.backward()
        optimizer.step()


def evaluate(model, device, loader, loss_func):
    """Compute the average loss and accuracy of ``model`` over ``loader``.

    The model is switched to evaluation mode and the whole pass runs under
    ``torch.no_grad()``, so no autograd graph is built for the forward passes.
    Each input batch is moved to ``device`` and flattened to ``(batch, -1)``.

    Args:
        model: Module exposing ``eval()`` and ``forward()``.
        device: Device the batches are moved to.
        loader: Iterable yielding ``(inputs, targets)`` pairs; targets are
            compared against ``argmax`` over dimension 1 of the predictions.
        loss_func: Callable taking ``(predictions, targets)`` and returning a
            single-element tensor.

    Returns:
        A ``(loss, accuracy)`` tuple of floats, each divided by the total
        number of samples seen.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    loss_sum, acc_sum, num_data = 0, 0, 0
    # Evaluation never calls backward(), so gradient tracking would only
    # cost memory and time.
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device).view(x.size(0), -1)
            y = y.to(device)
            y_pred = model.forward(x)
            # Weight the batch loss by the batch size so that a smaller final
            # batch does not skew the average. This assumes loss_func returns
            # a per-batch mean -- TODO confirm for every configured loss.
            loss_sum += loss_func(y_pred, y).item() * x.size(0)
            acc_sum += torch.eq(torch.argmax(y_pred, dim=1), y).sum().item()
            num_data += x.size(0)
    return loss_sum / num_data, acc_sum / num_data


def visualize(model, device, loader):
    """Save PCA scatter plots of each batch, coloured by label and by decision.

    For every batch the flattened inputs are projected with PCA and the first
    two components are plotted several times under ``../out/figures``:

    * ``out-y.png``: points coloured by ``batch_y``.
    * ``out-<i>.png``: one plot per entry of ``pi_list`` except the last,
      coloured by that entry's ``argmax`` over dimension 1, shifted by 3 in
      the colour table.
    * ``out-last.png``: coloured by the ``argmax`` of the last entry,
      unshifted.

    ``pi_list`` is the first value returned by
    ``model.model.explain_decisions``.

    NOTE: the file names do not include a batch index, so every batch
    overwrites the figures of the previous one; only the plots of the last
    batch remain on disk.

    Args:
        model: Wrapper exposing ``eval()`` and ``model.explain_decisions``.
        device: Device the input batches are moved to.
        loader: Iterable yielding ``(inputs, labels)`` pairs.
    """
    from sklearn.decomposition import PCA
    from matplotlib import pyplot as plt

    def _to_index(values):
        """Return ``values`` as a host-side NumPy array usable as an index."""
        # Tensors on an accelerator cannot be converted implicitly by NumPy.
        if torch.is_tensor(values):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    # savefig does not create missing directories.
    os.makedirs('../out/figures', exist_ok=True)
    colors = np.array(['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6'])

    model.eval()
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device).view(batch_x.size(0), -1)
        pi_list, _ = model.model.explain_decisions(batch_x)
        # Tensor.numpy() only works for CPU tensors, so copy back first.
        out = PCA().fit_transform(batch_x.detach().cpu().numpy())

        plt.rc('figure', figsize=(3, 3))

        plt.scatter(out[:, 0], out[:, 1], c=colors[_to_index(batch_y)])
        plt.savefig('../out/figures/out-y.png', bbox_inches='tight', dpi=300)
        plt.close()

        for i, pi in enumerate(pi_list[:-1]):
            # The +3 offset selects a different part of the colour table than
            # the one used for the label and last-entry plots.
            decision = _to_index(pi.argmax(1)) + 3
            plt.scatter(out[:, 0], out[:, 1], c=colors[decision])
            plt.savefig('../out/figures/out-{}.png'.format(i), bbox_inches='tight', dpi=300)
            plt.close()

        last_decision = _to_index(pi_list[-1].argmax(1))
        plt.scatter(out[:, 0], out[:, 1], c=colors[last_decision])
        plt.savefig('../out/figures/out-last.png', bbox_inches='tight', dpi=300)
        plt.close()


def explain(model, device, loader):
    """Plot, print and collect per-example decision explanations.

    Every example in ``loader`` is passed on its own (batch size 1) to
    ``model.model.explain_decisions``, which returns ``(pi_list,
    explanations)``. For each example this function:

    * saves ``pi-<n>.png``, an image of all ``pi_list`` entries except the
      last, stacked row-wise;
    * saves ``fi-<n>.png``, an image of all ``explanations`` entries stacked
      row-wise;
    * prints the input, every ``pi_list`` entry and every explanation;
    * records the example's label and its feature importances.

    ``<n>`` is a 1-based running index over the whole loader, so figures from
    different batches do not overwrite each other. Figures are written to
    ``../out/figures``.

    Args:
        model: Wrapper exposing ``eval()`` and ``model.explain_decisions``.
        device: Device the input batches are moved to.
        loader: Iterable yielding ``(inputs, labels)`` pairs.

    Returns:
        ``(labels, exp_list)``: two lists with one entry per example.
        ``labels`` holds the label as a Python scalar. ``exp_list`` holds a
        1-D CPU tensor with the stacked explanations flattened row-major.
    """
    from matplotlib import pyplot as plt

    # savefig does not create missing directories.
    os.makedirs('../out/figures', exist_ok=True)

    model.eval()
    labels, exp_list = [], []
    example = 0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device).view(batch_x.size(0), -1)
        for i in range(batch_x.size(0)):
            example += 1
            pi_list, explanations = model.model.explain_decisions(batch_x[i].unsqueeze(0))

            # detach()/cpu() so matplotlib can convert the tensors even when
            # they live on an accelerator or are attached to a graph.
            pi_curr = torch.stack([pi[0] for pi in pi_list[:-1]]).detach().cpu()
            plt.imshow(pi_curr.numpy(), vmin=0, vmax=1)
            plt.savefig('../out/figures/pi-{}.png'.format(example), dpi=300, bbox_inches='tight')
            # Close the figure so images do not pile up on the same axes.
            plt.close()

            fi_curr = torch.stack([e[0] for e in explanations]).detach().cpu()
            plt.imshow(fi_curr.numpy(), vmin=0, vmax=1)
            plt.savefig('../out/figures/fi-{}.png'.format(example), dpi=300, bbox_inches='tight')
            plt.close()

            # NOTE(review): the importances of every explanation step are
            # flattened into one row per example; confirm this is the layout
            # wanted in the explanations file written by the caller.
            labels.append(batch_y[i].item())
            exp_list.append(fi_curr.reshape(-1))

            print('Data:')
            print('\t'.join('{:.4f}'.format(e) for e in batch_x[i]))
            print()
            print('Decision probabilities:')
            for pi in pi_list:
                print('\t'.join('{:.4f}'.format(e) for e in pi[0]))
            print()
            print('Feature importances:')
            for exp in explanations:
                print('\t'.join('{:.4f}'.format(e) for e in exp[0]))
            print()
    return labels, exp_list


def train(args):
    """Train a model, log per-epoch training metrics and report accuracies.

    Reads from ``args``: ``data``, ``seed``, ``gpu``, ``out_path``,
    ``data_path``, ``batch_size``, ``model``, ``lr``, ``epochs``, ``explain``
    and ``save``.

    Files written under ``args.out_path``:

    * ``logs-<seed>/<data>.tsv``: one row per epoch with the columns
      ``epoch``, ``trn_loss``, ``trn_acc`` and ``is_best``. ``is_best`` is
      ``BEST`` when the epoch has the lowest training loss seen so far and
      empty otherwise.
    * ``explanations-<seed>/<data>.tsv``: only when ``args.explain`` is set.
    * ``models-<seed>/<data>.pth``: only when ``args.save`` is set.

    ``visualize`` is always called on the training loader after training.

    Args:
        args: Namespace-like object carrying the attributes listed above.

    Returns:
        A dict with ``args``, ``n_data`` (``data_dict['nd']``) and ``values``,
        a tuple ``(best_epoch, trn_acc, test_acc)``.
    """
    dataset = args.data
    seed = args.seed

    device = utils.to_device(args.gpu)
    log_out = os.path.join(args.out_path, 'logs-{}/{}.tsv'.format(seed, dataset))
    os.makedirs(os.path.dirname(log_out), exist_ok=True)

    data_dict = data.read_as_dict(args.data_path, dataset, args.batch_size)
    n_data = data_dict['nd']
    in_features = data_dict['nx']
    num_classes = data_dict['ny']
    trn_loader = data_dict['trn_loader']
    test_loader = data_dict['test_loader']

    # Seed after the data has been read and before the model is created.
    utils.set_seed(seed)

    model = models.to_model(args, dataset, in_features, num_classes, device)
    loss_func = models.to_loss(args.model)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    with open(log_out, 'w') as f:
        f.write('epoch\ttrn_loss\ttrn_acc\tis_best\n')

    best_loss, best_epoch = np.inf, 0
    # NOTE: the loop runs args.epochs + 1 times and updates the model every
    # time, so epoch 0 is already a trained epoch.
    for epoch in range(args.epochs + 1):
        update(model, device, trn_loader, loss_func, optimizer)
        trn_loss, trn_acc = evaluate(model, device, trn_loader, loss_func)

        if trn_loss < best_loss:
            best_loss = trn_loss
            best_epoch = epoch

        # Re-open in append mode every epoch so that the log is on disk even
        # if a later epoch fails.
        is_best = 'BEST' if epoch == best_epoch else ''
        with open(log_out, 'a') as f:
            # Exactly four tab-separated fields, matching the header.
            f.write(f'{epoch:5d}\t{trn_loss:.4f}\t{trn_acc:.4f}\t{is_best}\n')

    _, trn_acc = evaluate(model, device, trn_loader, loss_func)
    _, test_acc = evaluate(model, device, test_loader, loss_func)

    visualize(model, device, trn_loader)

    if args.explain:
        exp_out = os.path.join(args.out_path, 'explanations-{}/{}.tsv'.format(seed, dataset))
        os.makedirs(os.path.dirname(exp_out), exist_ok=True)
        labels, exp_list = explain(model, device, test_loader)
        with open(exp_out, 'w') as f:
            f.write('example\tlabel\timportances\n')
            for i, (label, exp) in enumerate(zip(labels, exp_list)):
                importances = '\t'.join(str(e.item()) for e in exp)
                f.write('{}\t{}\t{}\n'.format(i, label, importances))

    if args.save:
        model_out = os.path.join(args.out_path, 'models-{}/{}.pth'.format(seed, dataset))
        os.makedirs(os.path.dirname(model_out), exist_ok=True)
        model.save(model_out)

    return dict(args=args, n_data=n_data, values=(best_epoch, trn_acc, test_acc))
