import sys

sys.path.append("..")
import os

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler, MultiStepLR, StepLR
import torcheeg.transforms as transforms
from torcheeg.transforms import Compose, ToTensor, Resize, CWTSpectrum, BandSignal

from sklearn.metrics import roc_auc_score as ras
import numpy as np
from utils.save_results import save_output
from utils.utils import (get_loss,
                         get_dataset_statistics,
                         get_param_groups,
                         get_dataset_max_min,
                         get_transforms,
                         get_params_groups,
                         clip_gradients_value)


from hyperbolic_lib.lib.geoopt.optim import RiemannianSGD, RiemannianAdam, RiemannianAdamW

# Module-level boolean toggles, both disabled. Nothing in the training/testing
# functions of this section reads them — presumably consumed elsewhere; TODO confirm.
APPLY_MEAN = APPLY_MAX = False


def trainNetwork(net, trainloader, validloader, testloader, model_path=None, model=None, dataset=None, bs=64,
                 iterations=500, lr=5 * 1e-4, wd=None, repeat=None, sub=None, epochs=None, subject_weights=None,
                 verbose=True, save_model=False, save_results=True, early_stopping=False, grace_period=20,
                 hyperbolic=False, clip_grad=0, loss_1=None, loss_2=None, num_classes=0, **kwargs):
    """Train ``net`` for up to ``iterations`` epochs, evaluating on validation and test data every epoch.

    Each epoch does one optimisation pass over ``trainloader`` with cross-entropy
    loss. It then measures validation accuracy and the summed per-batch
    cross-entropy on ``validloader``. Finally it computes a test score on
    ``testloader``:

    * ROC-AUC via ``testNetwork_auc`` when ``dataset`` starts with ``'bcicha'``;
    * plain accuracy via ``testNetwork`` otherwise.

    Args:
        net: model to train; called as ``net(x)``. Its outputs go straight into
            ``nn.CrossEntropyLoss``, so they are expected to be class logits.
        trainloader, validloader, testloader: iterables yielding ``(x, y)`` batches.
            ``validloader`` must also support ``len()``.
        model_path: directory for the checkpoint written when early stopping fires.
            When None, no checkpoint is written.
        model, bs, sub, epochs, subject_weights: forwarded to ``save_output``.
        dataset: dataset name (a str); selects the test metric and is forwarded
            to ``save_output``.
        iterations: maximum number of epochs.
        lr: learning rate.
        wd: weight decay passed to the parameter-group helper and the optimizer.
        repeat: only used in the early-stopping checkpoint filename.
        verbose: print per-epoch metrics.
        save_model: currently unused.
        save_results: if true, pass the per-epoch curves to ``save_output`` at the end.
        early_stopping: stop once the summed validation loss has not improved
            for ``grace_period`` consecutive epochs.
        grace_period: patience used by ``early_stopping``.
        hyperbolic: currently ignored (see the NOTE in the body). The Riemannian
            optimizer is always used.
        clip_grad: if > 0, ``clip_gradients_value`` is called after backward.
        loss_1, loss_2: when ``loss_1`` is given, parameter groups come from
            ``get_params_groups``; otherwise ``get_param_groups`` is used.
        num_classes: currently unused.
        **kwargs: must contain ``'device'``. All kwargs are forwarded to ``save_output``.

    Returns:
        ``(mean validation loss, validation accuracy)`` of the last epoch that ran.
    """
    device = kwargs['device']
    CE = nn.CrossEntropyLoss(label_smoothing=0.0)

    # NOTE(review): this unconditionally overrides the ``hyperbolic`` argument, so the
    # plain-Adam branch below is never taken. Kept to preserve current behaviour;
    # delete this line to make the argument effective.
    hyperbolic = True

    if not hyperbolic:
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr)
    else:
        print("hyperbolic")
        if loss_1 is not None:
            param_groups = get_params_groups(net, loss_1, loss_2, weight_decay=wd)
        else:
            param_groups = get_param_groups(net, lr, wd)
        optimizer = RiemannianAdam(param_groups, lr=lr, weight_decay=wd)

    # The scheduler is constructed but never stepped (the call in the loop is
    # disabled), so it never decays the learning rate.
    scheduler = StepLR(optimizer, step_size=100, gamma=0.2)

    best_test = 0
    best_val = 0               # best validation accuracy seen so far
    test_from_best_val = 0     # test score at the epoch with the lowest validation loss
    test_from_best_acc = 0     # test score at the epoch with the highest validation accuracy
    bestLoss = 1e10
    bestLoss_early_stopping = 1e10
    no_improvement_count = 0

    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []
    test_accs = []

    for ite in range(iterations):
        # ------------------------------------ train ------------------------------------
        net.train()
        acc_tr = 0
        tr_len = 0

        for x_orig, yb in trainloader:
            tr_len += yb.shape[0]

            x_orig = x_orig.to(device)
            yb = yb.to(device)

            out = net(x_orig)
            loss = CE(out, yb)

            optimizer.zero_grad()
            # NOTE(review): retain_graph=True keeps the autograd graph alive after backward.
            # Presumably clip_gradients_value back-propagates `losses` again; confirm,
            # otherwise drop it to save memory.
            loss.backward(retain_graph=True)

            if clip_grad > 0:
                clip_gradients_value(net, clip_grad, losses=[loss])

            optimizer.step()
            acc_tr += (torch.max(out, 1).indices == yb).sum().item()

            torch.cuda.empty_cache()
        # scheduler.step()  # disabled: the learning rate is never decayed

        # ----------------------------------- validate ----------------------------------
        net.eval()
        acc_val = 0
        val_len = 0
        TL = 0  # sum over batches of the per-batch mean cross-entropy

        with torch.no_grad():
            for x_orig, yb in validloader:
                x_orig = x_orig.to(device)
                yb = yb.to(device)

                val_len += yb.shape[0]

                out = net(x_orig)
                acc_val += (torch.max(out, 1).indices == yb).sum().item()
                TL += CE(out, yb)

        # ------------------------------ logging and checks -----------------------------
        if early_stopping:
            if TL < bestLoss_early_stopping:
                bestLoss_early_stopping = TL
                no_improvement_count = 0
            else:
                no_improvement_count += 1

            if no_improvement_count >= grace_period:
                if verbose:
                    print(f'Early stopping at iteration {ite} as no improvement has been observed in {grace_period} '
                          f'iterations for the validation loss.')
                # os.path.join(None, ...) would raise TypeError, so only checkpoint when a path is given.
                if model_path is not None:
                    final_path = os.path.join(model_path, f'repeat{repeat}_sub{sub}_epochs{epochs}_lr{lr}_wd{wd}.pt')
                    torch.save(net, final_path)
                break

        val_loss = TL / len(validloader)
        val_acc = acc_val / val_len

        if verbose:
            print('')
            print(f'Iteration{ite}=====')
            # `loss` is the loss of the last training batch of this epoch, not an epoch average.
            print(f'train_loss:{loss:.4f}    val_loss:{val_loss:.4f}')
            print(f'train_acc:{acc_tr / tr_len:.4f}    val_acc:{val_acc:.4f}')

        improved_val_acc = val_acc > best_val
        if improved_val_acc:
            best_val = val_acc

        train_losses.append(loss.detach().cpu())
        train_accs.append(acc_tr / tr_len)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        # ------------------------------------- test ------------------------------------
        if dataset.startswith('bcicha'):
            test_acc = testNetwork_auc(net, testloader, device)
        else:
            test_acc = testNetwork(net, testloader, device)
        test_accs.append(test_acc)

        # Keep the test score from the best-validation-accuracy epoch. Previously this
        # fell back to the loss-based `test_from_best_val` on non-improving epochs.
        if improved_val_acc:
            test_from_best_acc = test_acc

        best_test = max(best_test, test_acc)

        # Update the loss-based bests before printing so the report reflects this epoch.
        if val_loss < bestLoss:
            bestLoss = val_loss.item()
            test_from_best_val = test_acc

        if verbose:
            print(f'test_acc:{round(test_acc, 3)}')
            print(f"best_test:{round(best_test, 3)}")
            print(f"best_loss:{round(bestLoss, 3)}")
            print(f"test_from_best_val:{round(test_from_best_val, 3)}")
            print(f"test_from_best_acc:{round(test_from_best_acc, 3)}")

    if save_results:
        def _as_array(values):
            # Move any tensors to the CPU so numpy can convert them.
            return np.array([v.cpu() if isinstance(v, torch.Tensor) else v for v in values])

        save_output(model_name=model, dataset=dataset, sub=sub, bs=bs, lr=lr, wd=wd, alpha=subject_weights,
                    epochs=epochs,
                    results=[
                        _as_array(train_losses),
                        _as_array(train_accs),
                        _as_array(val_losses),
                        _as_array(val_accs),
                        _as_array(test_accs)],
                    **kwargs)

    return TL / len(validloader), acc_val / val_len


def testNetwork(net, testloader, device ):
    net.eval()
    acc = 0
    test_len = 0

    for row in testloader:
        x_orig, yb = row

        x_orig = x_orig.to(device)
        yb = yb.to(device)

        test_len += yb.shape[0]
        with torch.no_grad():
            pred = net(x_orig)
            acc += (torch.max(pred, 1).indices == yb).sum().item()

    return acc / test_len


def testNetwork_auc(net, testloader, device):
    """Return the ROC-AUC of ``net`` on ``testloader``, using class index 1 as the positive class.

    The network output is treated as logits: ``trainNetwork`` feeds the same
    outputs into ``nn.CrossEntropyLoss``. A softmax over dim 1 turns them into
    probabilities, and column 1 is used as the ranking score. Ranking on the raw
    class-1 logit, as before, ignores the other logits and can give a different AUC.

    Args:
        net: model called as ``net(x)``; its output must have at least 2 columns.
        testloader: iterable of ``(inputs, labels)`` batches.
        device: target passed to ``Tensor.to`` for the inputs.

    Returns:
        ``roc_auc_score(labels, p(class 1))``.

    Raises:
        ValueError: if ``testloader`` yields no batches.
    """
    net.eval()
    softmax = nn.Softmax(dim=1)
    scores = []
    labels = []

    with torch.no_grad():
        for x_orig, yb in testloader:
            x_orig = x_orig.to(device)
            prob = softmax(net(x_orig))
            scores.append(prob[:, 1].detach().cpu())
            labels.append(yb.detach().cpu())

    if not scores:
        raise ValueError("testloader yielded no batches; cannot compute ROC-AUC")

    # Concatenate once at the end instead of growing a tensor every batch (quadratic).
    return ras(torch.cat(labels).numpy(),
               torch.cat(scores).numpy())
