from tqdm import tqdm
import torch
import sys
import math

from torch import float16
import pandas as pd
import time
import git

from .trainer_utils import full_count_params, count_params_test, count_params_train, get_dynamic_ranks, accuracy_top_k


# Sub-directory (appended to ``path``) that receives a best-effort backup copy
# of every CSV / checkpoint that is written.
_BACKUP_DIR = '/drive/MyDrive/nips2023_results/'


def _result_columns(metric_name):
    """Return the column names of the per-epoch results table, in row order."""
    return ['epoch', 'tau', 'learning_rate', 'train_loss',
            'train_' + metric_name + '(%)', 'validation_loss',
            'validation_' + metric_name + '(%)',
            'top5_validation_' + metric_name + '(%)',
            'test_' + metric_name + '(%)',
            'ranks', '# effective parameters conv', 'cr_test_conv (%)',
            '# effective parameters train conv', 'cr_train_conv (%)',
            '# effective parameters train with grads conv',
            'cr_train_grads_conv (%)',
            '# effective parameters', 'cr_test (%)',
            '# effective parameters train', 'cr_train (%)',
            '# effective parameters train with grads', 'cr_train_grads (%)',
            'timing batch forward', 'git commit']


def _count_correct(outputs, labels):
    """Return, as a float, how many rows of ``outputs`` have their arg-max equal to ``labels``."""
    # Count with an integer sum: a float16 accumulator cannot represent every
    # integer above 2048 and would silently round large per-batch counts.
    predictions = torch.argmax(outputs.detach(), dim=1)
    return float((predictions == labels).sum().item())


def _print_progress(i, k, acc_hist, loss_hist):
    """Redraw the in-place status bar for batch ``i`` of ``k``."""
    filled = int(float(i) / float(k) * 15)
    # A single-batch loader would otherwise divide by zero (k - 1 == 0).
    percent = (100 / (k - 1) * i) if k > 1 else 100.0
    sys.stdout.write('\r')
    sys.stdout.write(
        "Current Acc {:.3f}, Loss {:.3f}. "
        "Epoch progress: [{:{}}] {:.1f}%".format(acc_hist, loss_hist, "=" * filled, 15, percent))
    sys.stdout.flush()


def _evaluate(NN, loader, criterion, metric, device, with_top5=False):
    """Accumulate loss / metric (and optionally top-5) over ``loader``.

    The caller is responsible for ``NN.eval()`` and ``torch.no_grad()``.
    Every per-batch value is divided by ``len(loader) * loader.batch_size``.
    NOTE(review): this treats every batch as full and ignores the reduction
    used by ``criterion`` - confirm that this is the intended scaling.

    Returns a tuple ``(loss, metric, top5)``; ``top5`` stays 0.0 unless
    ``with_top5`` is set.
    """
    norm = len(loader) * loader.batch_size
    loss_sum = 0.0
    metric_sum = 0.0
    top5_sum = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = NN(inputs).detach()
        loss_sum += float(criterion(outputs, labels).item()) / norm
        metric_sum += float(metric(outputs, labels)) / norm
        if with_top5:
            top5_sum += float(accuracy_top_k(outputs, labels, topk=(5,))[5]) / norm
    return loss_sum, metric_sum, top5_sum


def _validate(NN, validation_loader, test_loader, criterion, metric, device):
    """Evaluate on the validation set and, when given, on the test set.

    Returns ``(val_loss, val_metric, val_top5, test_metric)``; ``test_metric``
    is -1 when ``test_loader`` is None.
    """
    NN.eval()
    with torch.no_grad():
        val_loss, val_metric, val_top5 = _evaluate(
            NN, validation_loader, criterion, metric, device, with_top5=True)
        if test_loader is not None:
            _, test_metric, _ = _evaluate(NN, test_loader, criterion, metric, device)
        else:
            test_metric = -1
    return val_loss, val_metric, val_top5, test_metric


def _dlra_train_epoch(NN, optimizer, train_loader, criterion, metric, device, epoch_status_bar):
    """Run one training epoch using ``NN.populate_gradients`` and ``optimizer.step(closure=...)``.

    Returns ``(train_loss, train_metric, average_batch_time)``.
    Exits the process with status 1 when the accumulated loss becomes NaN.
    """
    k = len(train_loader)
    # Normalise with the *training* loader's batch size, not with whatever
    # loader was evaluated last.
    norm = k * train_loader.batch_size
    loss_hist = 0.0
    acc_hist = 0.0
    average_batch_time = 0.0

    NN.train()
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        start = time.time()
        inputs, labels = inputs.to(device), labels.to(device)

        # Bind the current batch as defaults so the closure never sees a later one.
        def closure(inputs=inputs, labels=labels):
            return NN.populate_gradients(inputs, labels, criterion, step='core')

        NN.populate_gradients(inputs, labels, criterion)

        with torch.no_grad():
            NN.set_step('test')
            outputs = NN(inputs).detach()
            loss_hist += float(criterion(outputs, labels).item()) / norm
            acc_hist += float(metric(outputs, labels)) / norm

        average_batch_time += (time.time() - start) / k

        if epoch_status_bar:
            _print_progress(i, k, acc_hist, loss_hist)

        if math.isnan(loss_hist):
            print("Training diverged! Loss is nan")
            sys.exit(1)
        optimizer.step(closure=closure)

    return loss_hist, acc_hist, average_batch_time


def _finetune_train_epoch(NN, optimizer, train_loader, criterion, metric, device, epoch_status_bar):
    """Run one fine-tuning epoch: plain backward pass followed by ``optimizer.S_finetune_step()``.

    Returns ``(train_loss, train_metric, average_batch_time)``.
    """
    k = len(train_loader)
    norm = k * train_loader.batch_size
    loss_hist = 0.0
    acc_hist = 0.0
    average_batch_time = 0.0

    NN.train()
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        start = time.time()
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = NN(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        loss_hist += float(loss.item()) / norm
        acc_hist += float(metric(outputs.detach(), labels)) / norm
        optimizer.S_finetune_step()
        average_batch_time += (time.time() - start) / k

        if epoch_status_bar:
            _print_progress(i, k, acc_hist, loss_hist)

    return loss_hist, acc_hist, average_batch_time


def _compression_stats(model, count_bias, totals):
    """Return the twelve parameter-count / compression columns for one epoch.

    ``totals`` is ``(full, full_grads, full_linear, full_grads_linear)`` as
    computed once by ``full_count_params``. The result alternates a raw
    parameter count with ``100 * (1 - count / total)`` in the column order of
    ``_result_columns``: first the convolution-only figures, then the figures
    that also count linear layers.
    """
    full, full_grads, full_linear, full_grads_linear = totals

    def saved_percent(params, total):
        # Same two-stage rounding as before: ratio to 3 digits, percent to 4.
        return round(100 * (1 - round(params / total, 3)), 4)

    # Positional flags are forwarded exactly as before; presumably they are
    # (with_grads, count_linear) by analogy with full_count_params - TODO
    # confirm against trainer_utils.
    params_test = count_params_test(model, count_bias)
    params_train = count_params_train(model, count_bias)
    params_train_grads = count_params_train(model, count_bias, True)
    params_test_global = count_params_test(model, count_bias, True)
    params_train_global = count_params_train(model, count_bias, False, True)
    params_train_grads_global = count_params_train(model, count_bias, True, True)

    return [params_test, saved_percent(params_test, full),
            params_train, saved_percent(params_train, full),
            params_train_grads, saved_percent(params_train_grads, full_grads),
            params_test_global, saved_percent(params_test_global, full_linear),
            params_train_global, saved_percent(params_train_global, full_linear),
            params_train_grads_global, saved_percent(params_train_grads_global, full_grads_linear)]


def _print_ranks(ranks):
    """Print one line per layer rank followed by a blank separator."""
    for layer, rank in enumerate(ranks):
        print(f'rank layer {layer} {rank}')
    print('\n')


def _save_with_backup(save, path, primary_name, backup_name):
    """Call ``save(path + primary_name)`` and then try a backup copy under ``_BACKUP_DIR``.

    A failure of the primary save propagates; a failure of the backup is only
    reported, because the backup directory is optional.
    """
    save(path + primary_name)
    backup_target = path + _BACKUP_DIR + backup_name
    try:
        save(backup_target)
    except (OSError, RuntimeError):
        print("Tried: " + backup_target + ' for additional backup, but:')
        print('drive not found')


def train(NN, optimizer, train_loader, validation_loader, test_loader, criterion, metric, epochs,
          metric_name='accuracy', device='cpu', count_bias=False, path=None, epoch_status_bar=False,
          fine_tune=False, scheduler=None, save_weights=True, save_progress=False, save_name=''):
    """
    Train (or fine-tune) NN and collect one row of statistics per epoch.

    INPUTS:
    NN : neural network with custom layers and methods to optimize with dlra; this function uses
         NN.lr_model, NN.tau, NN.populate_gradients and NN.set_step
    optimizer : must provide zero_grad, step(closure=...), S_finetune_step and integrator.param_groups
    train/validation/test_loader : loader for datasets; test_loader may be None, in which case the
                                   test metric is recorded from the placeholder value -1
    criterion : loss function
    metric : kept for interface compatibility but currently IGNORED - the top-1 correct count is always used
    epochs : number of epochs to train
    metric_name : name of the used metric (only used for column names and log lines)
    device : device the batches are moved to
    count_bias : flag variable if to count biases in params_count or not
    path : path string for where to save the results; nothing is written when it is None
    epoch_status_bar : if True, redraw an in-place progress line after every training batch
    fine_tune : if True, run the fine-tuning loop (train first, then evaluate; files get an '_ft' suffix),
                otherwise run the main loop (evaluate first, then train)
    scheduler : optional; scheduler.step is called once per epoch with the training loss
    save_weights : if True, save NN with torch.save whenever the validation loss improves on the best so far
                   (the epoch-0 value only initialises the best, so epoch 0 itself is never saved)
    save_progress : if True, write the results table to CSV (every epoch, or every 5th and the last
                    epoch when fine-tuning)
    save_name : file name stem appended to path
    OUTPUTS:
    running_data : Pandas dataframe with the results of the run
    RAISES:
    SystemExit : (main loop only) when the accumulated training loss becomes NaN
    """
    running_data = pd.DataFrame(data=None, columns=_result_columns(metric_name))
    repo = git.Repo(search_parent_directories=True)
    git_commit = str(repo.head.object.hexsha)

    # Reference sizes used as denominators for the compression ratios.
    totals = (
        full_count_params(NN.lr_model, count_bias),
        full_count_params(NN.lr_model, count_bias, with_grads=True),
        full_count_params(NN.lr_model, count_bias, with_grads=False, count_linear=True),
        full_count_params(NN.lr_model, count_bias, with_grads=True, count_linear=True),
    )

    # Historical behaviour: the caller-supplied metric has always been replaced
    # by the top-1 correct count. Kept as is so recorded numbers stay comparable.
    metric = _count_correct

    if fine_tune:
        csv_suffix, weights_suffix, csv_every = '_ft.csv', '_ft.pt', 5
    else:
        csv_suffix, weights_suffix, csv_every = '.csv', '.pt', 1

    best_val_loss = None

    for epoch in tqdm(range(epochs)):

        if fine_tune:
            print(f'epoch {epoch}---------------------------------------------')
            loss_hist, acc_hist, average_batch_time = _finetune_train_epoch(
                NN, optimizer, train_loader, criterion, metric, device, epoch_status_bar)
            loss_hist_val, acc_hist_val, acc_top5_hist_val, acc_hist_test = _validate(
                NN, validation_loader, test_loader, criterion, metric, device)

            print(
                f'epoch[{epoch}]: loss: {loss_hist:9.4f} | {metric_name}: {acc_hist:9.4f} | val loss: {loss_hist_val:9.4f} | val {metric_name}:{acc_hist_val:9.4f}')
            print('=' * 100)

            ranks = get_dynamic_ranks(NN.lr_model)
            _print_ranks(ranks)

            stats = _compression_stats(NN.lr_model, count_bias, totals)
            print(f'cr: test {stats[1]} train_grads {stats[5]}')
        else:
            # The main loop evaluates *before* training, so the row of epoch e
            # holds the validation/test figures of the model entering epoch e.
            loss_hist_val, acc_hist_val, acc_top5_hist_val, acc_hist_test = _validate(
                NN, validation_loader, test_loader, criterion, metric, device)

            print(f'epoch {epoch}, acc_val {acc_hist_val}---------------------------------------------')
            loss_hist, acc_hist, average_batch_time = _dlra_train_epoch(
                NN, optimizer, train_loader, criterion, metric, device, epoch_status_bar)

            ranks = get_dynamic_ranks(NN.lr_model)
            print('\n')
            _print_ranks(ranks)

            stats = _compression_stats(NN.lr_model, count_bias, totals)
            print(f'cr: test {stats[1]} train_grads {stats[5]}')
            print(
                f'epoch[{epoch}/{epochs}]: loss: {loss_hist:9.4f} | {metric_name}: {acc_hist:9.4f} | val loss: {loss_hist_val:9.4f} | val {metric_name}:{acc_hist_val:9.4f}')
            print('=' * 100)

        lr = round(float(optimizer.integrator.param_groups[0]['lr']), 4)
        head = [epoch, NN.tau, lr, round(loss_hist, 3),
                round(acc_hist * 100, 4), round(loss_hist_val, 3),
                round(acc_hist_val * 100, 4), round(acc_top5_hist_val * 100, 4),
                round(acc_hist_test * 100, 4), ranks]
        running_data.loc[epoch] = head + stats + [average_batch_time, git_commit]

        if not fine_tune:
            # Kept from the original main loop: echo the output location every epoch.
            print(path)

        is_csv_epoch = epoch % csv_every == 0 or epoch == epochs - 1
        if path is not None and save_progress and is_csv_epoch:
            # NOTE(review): the backup name carries no '_ft' suffix, so a
            # fine-tuning run overwrites the backup of the main run - confirm
            # whether that is intended.
            _save_with_backup(running_data.to_csv, path, save_name + csv_suffix, save_name + '.csv')

        if scheduler is not None:
            scheduler.step(loss_hist)

        if epoch == 0:
            best_val_loss = loss_hist_val

        # Without a path there is nowhere to write the checkpoint; skip instead
        # of failing on ``None + str``.
        if save_weights and path is not None and loss_hist_val < best_val_loss:
            print('save')
            best_val_loss = loss_hist_val
            _save_with_backup(lambda target: torch.save(NN, target), path,
                              save_name + weights_suffix, save_name + '.pt')

    return running_data
