import sys

from reconstruction.losses import CutFillConfig, apply_cut_and_fill

sys.path.append("..")
import os

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler, MultiStepLR, StepLR
import torcheeg.transforms as transforms
from torcheeg.transforms import Compose, ToTensor, Resize, CWTSpectrum, BandSignal

from sklearn.metrics import roc_auc_score as ras
import numpy as np
from utils.save_results import save_output
from utils.utils import (get_param_groups, get_params_groups, clip_gradients_value, count_params)


from hyperbolic_lib.lib.geoopt.optim import RiemannianSGD, RiemannianAdam, RiemannianAdamW


def trainNetworkMultiple(net, trainloader, validloader, testloader, model_path=None, model=None, dataset=None, bs=64,
                 iterations=500, lr=5 * 1e-4, wd=None, repeat=None, sub=None, epochs=None, subject_weights=None,
                 verbose=True, save_model=False, save_results=True, early_stopping=False, grace_period=20,
                 hyperbolic=False, clip_grad=0, loss_1=None, loss_2=None, num_classes=0, **kwargs):
    """Train ``net`` for up to ``iterations`` epochs, validating and testing after each one.

    Every epoch:
      1. one pass over ``trainloader`` with cross-entropy loss (optionally with
         cut-and-fill augmentation and gradient clipping), then a MultiStepLR step;
      2. evaluation on ``validloader`` (summed loss and accuracy);
      3. optional early-stopping check on the summed validation loss;
      4. a checkpoint under ``./checkpoints/`` whenever validation accuracy improves;
      5. evaluation on ``testloader`` -- ROC-AUC (``testNetwork_auc``) when
         ``dataset`` starts with ``'bcicha'``, accuracy (``testNetwork``) otherwise.

    Args:
        net: model called as ``net(x_orig, x_sub)``; its output is fed to
            ``nn.CrossEntropyLoss`` and arg-maxed over dim 1.
        trainloader, validloader, testloader: iterables yielding
            ``(x_orig, x_sub, y)`` batches; ``len()`` of ``validloader`` is used
            to average the validation loss.
        model_path: directory for the early-stopping checkpoint. If ``None``,
            no checkpoint is written when early stopping triggers.
        model, dataset, bs, lr, wd, repeat, sub, epochs, subject_weights:
            used for optimizer set-up and/or checkpoint and result file names.
        iterations: maximum number of epochs.
        verbose: print per-epoch metrics.
        save_model, num_classes: not used by this function.
        save_results: pass the per-epoch curves to ``save_output`` at the end.
        early_stopping, grace_period: stop after ``grace_period`` consecutive
            epochs without improvement of the summed validation loss.
        hyperbolic: currently ignored (see the NOTE below).
        clip_grad: if > 0, gradients are clipped via ``clip_gradients_value``.
        loss_1, loss_2: if ``loss_1`` is given, ``get_params_groups`` builds the
            optimizer parameter groups from ``net``, ``loss_1`` and ``loss_2``.
        **kwargs: must contain ``device``, ``dropout``, ``windows``,
            ``pre_processor``, ``pre_encoder``, ``cutfill``, ``learn_decoder``,
            ``learn_predecoder``, ``learn_lora`` and ``lora_lr``; all of
            ``kwargs`` is forwarded to ``save_output``.

    Returns:
        ``(mean validation loss, validation accuracy)`` of the last epoch that ran.
    """
    best_test = 0
    best_val = 0
    # Test metric recorded at the epoch with the lowest mean validation loss.
    test_from_best_val = 0
    # Test metric recorded at the epoch with the highest validation accuracy.
    # It must persist across epochs, so it is initialised once here instead of
    # being re-derived from ``test_from_best_val`` every epoch.
    test_from_best_acc = 0

    device = kwargs['device']
    CE = nn.CrossEntropyLoss()
    # NOTE(review): this unconditionally overrides the ``hyperbolic`` argument, so
    # the SGD branch below is unreachable and RiemannianAdam is always used.
    # Kept to preserve current behaviour; delete this line to honour the argument.
    hyperbolic = True
    cfg = CutFillConfig(min_frac=0.05, max_frac=0.15, fill_value=0.0, loss_on_full=False)

    if not hyperbolic:
        optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=lr, weight_decay=wd,
                                    momentum=0.6, nesterov=False)
    else:
        print("hyperbolic")
        if loss_1 is not None:
            param_groups = get_params_groups(net, loss_1, loss_2, weight_decay=wd)
        else:
            param_groups = get_param_groups(net, lr, wd)
        optimizer = RiemannianAdam(param_groups, lr=lr, weight_decay=wd)

    # Learning rate is multiplied by 0.1 at each of these epochs. Set to None to disable.
    scheduler = MultiStepLR(optimizer, milestones=[50, 100, 150, 200, 250], gamma=0.1)

    bestLoss = 1e10
    bestLoss_early_stopping = 1e10

    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []
    test_accs = []
    test_losses = []

    no_improvement_count = 0

    # Run configuration, used to build the checkpoint file name.
    dropout = kwargs['dropout']
    windows = kwargs['windows']
    processor = kwargs['pre_processor']
    encoder = kwargs['pre_encoder']
    cut = kwargs['cutfill']
    decoder = kwargs['learn_decoder']
    predecoder = kwargs['learn_predecoder']
    lora = kwargs['learn_lora']
    lora_lr = kwargs['lora_lr']

    count_params(net)
    for ite in range(iterations):
        net.train()
        acc_val = 0
        acc_tr = 0
        tr_len = 0
        val_len = 0

        # ------------------------------------ train ------------------------------------
        for x_orig, x_sub, yb in trainloader:
            tr_len += yb.shape[0]

            if cut:
                # Inputs with fewer than 4 dims get a channel axis at dim 1 before
                # cut-and-fill. NOTE(review): validation/test inputs are not
                # unsqueezed -- confirm ``net`` (or ``apply_cut_and_fill``) handles both ranks.
                x_temp = x_orig.unsqueeze(1) if len(x_orig.shape) < 4 else x_orig
                x_orig, _, _ = apply_cut_and_fill(x_temp, cfg)

            x_orig = x_orig.to(device)
            x_sub = x_sub.to(device)
            yb = yb.to(device)

            out = net(x_orig, x_sub)
            loss = CE(out, yb)

            optimizer.zero_grad()
            loss.backward()

            if clip_grad > 0:
                clip_gradients_value(net, clip_grad, losses=[loss])

            optimizer.step()
            acc_tr += (torch.max(out, 1).indices == yb).sum().item()

            torch.cuda.empty_cache()
        if scheduler is not None:
            scheduler.step()

        # ----------------------------------- validate ----------------------------------
        net.eval()
        TL = 0  # summed (not averaged) validation loss over batches
        with torch.no_grad():
            for x_orig, x_sub, yb in validloader:
                x_orig = x_orig.to(device)
                x_sub = x_sub.to(device)
                yb = yb.to(device)

                val_len += yb.shape[0]

                out = net(x_orig, x_sub)
                acc_val += (torch.max(out, 1).indices == yb).sum().item()
                TL += CE(out, yb)

        # ------------------------------- early stopping --------------------------------
        if early_stopping:
            if TL < bestLoss_early_stopping:
                bestLoss_early_stopping = TL
                no_improvement_count = 0
            else:
                no_improvement_count += 1

            if no_improvement_count >= grace_period:
                if verbose:
                    print(f'Early stopping at iteration {ite} as no improvement has been observed in {grace_period} '
                          f'iterations for the validation loss.')
                # os.path.join(None, ...) would raise TypeError; skip saving without a target directory.
                if model_path is not None:
                    final_path = os.path.join(model_path, f'repeat{repeat}_sub{sub}_epochs{epochs}_lr{lr}_wd{wd}.pt')
                    torch.save(net, final_path)
                break

        # ---------------------------- logging and checkpoints --------------------------
        mean_val_loss = TL / len(validloader)
        val_acc = acc_val / val_len

        if verbose:
            print('')
            print(f'Iteration{ite}=====')
            # ``loss`` is the loss of the last training batch of this epoch.
            print(f'train_loss:{loss:.4f}    val_loss:{mean_val_loss:.4f}')
            print(f'train_acc:{acc_tr / tr_len:.4f}    val_acc:{val_acc:.4f}')

        val_acc_improved = val_acc > best_val
        if val_acc_improved:
            best_val = val_acc
            if sub == 'all':
                torch.save(net, f"./checkpoints/{dataset}_{model}_win{windows}_bs{bs}_lr{lr}_wd{wd}_dp{dropout}_llr{lora_lr}_proc{processor}_enc{encoder}_cut{cut}_dec{decoder}_predec{predecoder}_lora{lora}_all.pt")
            else:
                torch.save(net, f"./checkpoints/{dataset}_{model}_best_{sub}.pt")

        train_losses.append(loss.detach().cpu())
        train_accs.append(acc_tr / tr_len)
        val_losses.append(mean_val_loss)
        val_accs.append(val_acc)

        # ------------------------------------- test ------------------------------------
        if dataset.startswith('bcicha'):
            test_acc, test_loss = testNetwork_auc(net, testloader, device)
        else:
            test_acc, test_loss = testNetwork(net, testloader, device)
        test_accs.append(test_acc)
        test_losses.append(test_loss)

        # Only overwrite when validation accuracy improved this epoch; otherwise keep
        # the value recorded at the previous best-accuracy epoch.
        if val_acc_improved:
            test_from_best_acc = test_acc

        best_test = max(best_test, test_acc)
        if verbose:
            print(f'test_acc:{round(test_acc, 3)}')
            print(f"best_test:{round(best_test, 3)}")
            print(f"best_loss:{round(bestLoss, 3)}")
            print(f"test_from_best_val:{round(test_from_best_val, 3)}")
            print(f"test_from_best_acc:{round(test_from_best_acc, 3)}")

        if mean_val_loss < bestLoss:
            bestLoss = mean_val_loss.item()
            test_from_best_val = test_acc

    if save_results:
        save_output(
            model_name=model, dataset=dataset, sub=sub, bs=bs, lr=lr, wd=wd, alpha=subject_weights,
            epochs=epochs,
            results=[
                np.array([v.cpu().item() if isinstance(v, torch.Tensor) else v for v in train_losses]),
                np.array([v.cpu().item() if isinstance(v, torch.Tensor) else v for v in train_accs]),
                np.array([v.cpu().item() if isinstance(v, torch.Tensor) else v for v in val_losses]),
                np.array([v.cpu().item() if isinstance(v, torch.Tensor) else v for v in val_accs]),
                np.array([v.cpu().item() if isinstance(v, torch.Tensor) else v for v in test_losses]),
                np.array([v.cpu().item() if isinstance(v, torch.Tensor) else v for v in test_accs])
            ],
            **kwargs
        )

    # Recomputed from TL/acc_val so this is also valid when early stopping broke
    # out of the loop before ``mean_val_loss`` / ``val_acc`` were updated.
    return TL / len(validloader), acc_val / val_len


def testNetwork(net, testloader, device):
    net.eval()
    acc = 0
    test_len = 0
    total_loss = 0.0
    CE = nn.CrossEntropyLoss()

    for row in testloader:
        x_orig, x_sub, yb = row
        x_orig = x_orig.to(device)
        x_sub = x_sub.to(device)
        yb = yb.to(device)

        test_len += yb.shape[0]
        with torch.no_grad():
            pred = net(x_orig, x_sub)
            acc += (torch.max(pred, 1).indices == yb).sum().item()
            total_loss += CE(pred, yb).item()

    return acc / test_len, total_loss / len(testloader)

def testNetwork_auc(net, testloader, device):
    """Evaluate ``net`` on ``testloader`` with ROC-AUC and return ``(auc, mean batch loss)``.

    The per-sample score passed to ``roc_auc_score`` is the softmax probability
    of class 1. The raw class-1 logit alone is not a valid ranking score: the
    probability depends on all logits, so ranking by one logit can give a
    different, wrong AUC. Applying softmax is also rank-safe if ``net`` already
    emits probabilities or log-probabilities.

    Args:
        net: model called as ``net(x_orig, x_sub)``; assumed to return
            ``(batch, num_classes)`` scores with column 1 as the positive class.
        testloader: iterable of ``(x_orig, x_sub, y)`` batches with a ``len()``.
        device: device each batch is moved to.

    Returns:
        ``(auc, mean_loss)`` where ``mean_loss`` is the per-batch cross-entropy
        averaged over ``len(testloader)`` batches.

    Raises:
        ValueError: if ``testloader`` yields no batches.
    """
    net.eval()
    softmax = nn.Softmax(dim=1)
    CE = nn.CrossEntropyLoss()
    total_loss = 0.0
    scores = []
    labels = []

    with torch.no_grad():
        for x_orig, x_sub, yb in testloader:
            x_orig = x_orig.to(device)
            x_sub = x_sub.to(device)
            yb = yb.to(device)

            pred = net(x_orig, x_sub)
            total_loss += CE(pred, yb).item()

            scores.append(softmax(pred)[:, 1])
            labels.append(yb)

    if not scores:
        raise ValueError("testloader yielded no batches; cannot compute AUC")

    # Concatenate once instead of growing a tensor every batch (quadratic copying).
    y_pred = torch.cat(scores).cpu().numpy()
    y_true = torch.cat(labels).cpu().numpy()

    auc = ras(y_true, y_pred)
    return auc, total_loss / len(testloader)