import os
import time

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from library import results_json
from library import configs
from library import models
from library import model_io

from difflogic import PackBitsTensor, CompiledLogicNet

# Floating-point bit width (config.train_config.training_bit_count) -> torch dtype
# that training inputs are cast to before the forward pass.
BITS_TO_TORCH_FLOATING_POINT_TYPE = dict(
    zip((16, 32, 64), (torch.float16, torch.float32, torch.float64))
)

def accuracy(y_pred: torch.Tensor, y: torch.Tensor) -> float:
    """Return the fraction of rows whose ``argmax`` over the last dim equals ``y``.

    ``y_pred`` is detached and the computation runs under ``torch.no_grad()``,
    so calling this never extends the autograd graph. The result is a Python float.
    """
    with torch.no_grad():
        predicted = y_pred.detach().argmax(dim=-1)
        hits = predicted.eq(y)
        return hits.float().mean().item()
def accuracy_no_groupsum(y_pred: torch.Tensor, y: torch.Tensor, train_config) -> float:
    """Accuracy when each class is represented by a contiguous block of outputs.

    The second dimension of ``y`` is split into ``train_config.num_classes``
    equal blocks. Within each block, values are summed, and the class with the
    largest block sum is the prediction (for ``y_pred``) or the target (for
    ``y``). Returns the fraction of rows where both agree, as a Python float.
    """
    n_cls = train_config.num_classes
    width = y.shape[1] // n_cls

    def _block_argmax(t: torch.Tensor) -> torch.Tensor:
        # Collapse each class block to one score, then pick the best class per row.
        return t.view(-1, n_cls, width).sum(dim=2).argmax(dim=1)

    matches = _block_argmax(y_pred) == _block_argmax(y)
    return matches.float().mean().item()

def eval(model: torch.nn.Module, loader: torch.utils.data.DataLoader, train_mode: bool, config: configs.DifflogicConfig, results: results_json.ResultsJSON=None, measure=accuracy, last_epoch = False) -> float:
    """Return the batch-size-weighted mean of ``measure`` over ``loader``.

    Each batch ``(x, y)`` is moved to ``config.data_config.device`` and pushed
    through the model one layer at a time (``model[i]`` for ``i < len(model)``)
    under ``torch.no_grad()``. The model's training flag is set to
    ``train_mode`` for the duration of the call. It is restored afterwards,
    including when evaluation raises.

    Args:
        model: layer container supporting ``len()`` and integer indexing
            (e.g. ``torch.nn.Sequential``).
        loader: iterable of ``(x, y)`` batches, or ``None``.
        train_mode: value passed to ``model.train(mode=...)`` while evaluating.
        config: only ``config.data_config.device`` is read.
        results: unused. Logit-stat / raw-value storage is currently
            disabled; the parameter is kept for call-site compatibility.
        measure: per-batch metric ``measure(output, y) -> float``.
        last_epoch: unused; see ``results``.

    Returns:
        The weighted mean as a Python float, or ``-1`` when ``loader`` is
        ``None`` or yields no samples.
    """
    device = config.data_config.device
    if loader is None:
        return -1

    orig_mode = model.training
    model.train(mode=train_mode)
    scores = []
    weights = []
    try:
        with torch.no_grad():
            for x, y in loader:
                y = y.to(device=device)
                x = x.to(device=device)
                # Layers are applied one by one (instead of model(x)) so that
                # intermediate outputs can be tapped. The hooks that stored
                # logit statistics / raw outputs via `results` are disabled.
                for i in range(len(model)):
                    x = model[i](x)
                scores.append(measure(x, y))
                weights.append(len(y))
    finally:
        # Always restore the caller's train/eval mode.
        model.train(mode=orig_mode)

    # np.average raises ZeroDivisionError when the weights sum to zero.
    if sum(weights) == 0:
        return -1
    return np.average(scores, weights=weights).item()


def packbits_eval(model, loader, device='cuda'):
    """Return the accuracy of ``model`` on ``loader`` using bit-packed inputs.

    Each batch ``x`` is moved to ``device``, flattened to ``(batch, -1)``,
    rounded and cast to ``bool``, then wrapped in ``PackBitsTensor`` before the
    forward pass. The prediction is ``argmax(-1)`` of the model output.

    The result is the fraction of all compared entries that match, so a
    smaller final batch is not over-weighted. This is consistent with the
    batch-size weighting used by ``eval``.

    The model runs in eval mode under ``torch.no_grad()``. Its original
    training flag is restored afterwards, including when evaluation raises.

    Args:
        model: model accepting a ``PackBitsTensor``.
        loader: iterable of ``(x, y)`` batches, or ``None``.
        device: device the inputs and labels are moved to (default ``'cuda'``).

    Returns:
        Accuracy as a Python float, or ``-1`` if ``loader`` is ``None`` or
        yields no samples.
    """
    if loader is None:
        return -1

    orig_mode = model.training
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            model.eval()
            for x, y in loader:
                packed = PackBitsTensor(x.to(device).reshape(x.shape[0], -1).round().bool())
                matches = model(packed).argmax(-1) == y.to(device)
                correct += matches.sum().item()
                total += matches.numel()
    finally:
        model.train(mode=orig_mode)

    if total == 0:
        return -1
    return correct / total

def _run_compiled_on_dataset(compiled_model, dataset):
    """Run ``compiled_model`` over ``dataset`` and return ``(accuracy, seconds)``.

    The dataset is read in order, in batches of up to 1e6 samples. Each batch
    is flattened, cast to ``bool`` and passed to the compiled model as a NumPy
    array. Accuracy is the fraction of samples whose ``argmax(-1)`` matches
    the label. The elapsed time covers data loading plus inference.
    """
    correct, total = 0, 0
    with torch.no_grad():
        start = time.perf_counter()
        for data, labels in torch.utils.data.DataLoader(dataset, batch_size=int(1e6), shuffle=False):
            data = torch.nn.Flatten()(data).bool().numpy()
            output = compiled_model(data, verbose=True)
            correct += (output.argmax(-1) == labels).float().sum().item()
            total += output.shape[0]
        elapsed = time.perf_counter() - start
    return correct / total, elapsed


def compile_and_eval(
        model: torch.nn.Module, 
        train_loader: torch.utils.data.DataLoader,
        validation_loader: torch.utils.data.DataLoader, 
        test_loader: torch.utils.data.DataLoader, 
        config: configs.DifflogicConfig
    ) -> None:
    """Compile ``model`` to a shared library and report its test-set accuracy.

    For every bit width in ``config.compilation_config.num_bits``, the model
    is compiled ``num_repetitions`` times with ``CompiledLogicNet`` into
    ``<lib_dir>/<experiment_id:08d>_<num_bits>.so``. Each compiled model is
    run over ``test_loader.dataset``. The per-repetition accuracies and
    inference timings are printed together with their mean and std.

    ``train_loader`` and ``validation_loader`` are currently unused.
    """
    print('\n' + '='*80)
    print(' Converting the model to C code and compiling it...')
    print('='*80)

    compilation_config = config.compilation_config
    total_neurons = config.model_config.num_neurons * config.model_config.num_layers
    # Optimisation is disabled for large nets, presumably to bound compile time/memory.
    opt_level = 1 if total_neurons < 50_000 else 0

    experiment_id = config.experiment_config.experiment_id
    if experiment_id is None:
        experiment_id = 0

    lib_dir = compilation_config.lib_dir
    os.makedirs(lib_dir, exist_ok=True)

    for num_bits in compilation_config.num_bits:
        # Write into the configured lib_dir (the directory created above).
        save_lib_path = os.path.join(lib_dir, '{:08d}_{}.so'.format(experiment_id, num_bits))
        accuracies = []
        timings = []
        for _ in range(compilation_config.num_repetitions):
            compiled_model = CompiledLogicNet(
                model=model,
                num_bits=num_bits,
                cpu_compiler=compilation_config.cpu_compiler,
                verbose=compilation_config.verbose,
            )
            compiled_model.compile(
                opt_level=opt_level,
                save_lib_path=save_lib_path,
                verbose=compilation_config.verbose,
            )
            acc, elapsed = _run_compiled_on_dataset(compiled_model, test_loader.dataset)
            accuracies.append(acc)
            timings.append(elapsed)
        print('COMPILED MODEL', num_bits)
        print('Accuracies:', accuracies, 'Mean:', np.mean(accuracies), 'Std:', np.std(accuracies))
        print('Timings:', timings, 'Mean:', np.mean(timings), 'Std:', np.std(timings))
        print('='*80)


def train_step(model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor, loss_fn, optimizer: torch.optim.Optimizer, measure, train_config) -> tuple:
    """Run one optimisation step on a single batch.

    Computes ``loss_fn(model(x), y)``. When ``train_config.l1_regularization``
    is non-zero, that coefficient times the L1 norm of all model parameters is
    added to the loss. The function then zeroes gradients, back-propagates and
    steps ``optimizer``.

    Returns:
        ``(loss, metric)``: the (regularised) loss as a Python float, and
        ``measure(y_pred, y)`` on the predictions made before the update.
    """
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    l1_coeff = train_config.l1_regularization
    if l1_coeff != 0:
        l1_norm = sum(p.abs().sum() for p in model.parameters())
        loss = loss + l1_coeff * l1_norm
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item(), measure(y_pred, y)

def train(
        model: torch.nn.Module, 
        loss_fn: torch.nn.Module,
        optimizer: torch.optim.Optimizer, 
        train_loader: torch.utils.data.DataLoader, 
        validation_loader: torch.utils.data.DataLoader,
        binarized_loader: torch.utils.data.DataLoader,
        test_loader: torch.utils.data.DataLoader,
        test_loader_bin: torch.utils.data.DataLoader,
        results: results_json.ResultsJSON,
        config: configs.DifflogicConfig
    ) -> None:
    """Train ``model`` for ``config.train_config.num_epochs`` epochs with periodic evaluation.

    Per epoch:
      * If ``decrease_tau`` is set and the model uses group-sum, the output
        layer's ``tau`` is taken from ``decrease_tau``. Selection starts at
        the last entry and moves one entry towards the front every 10 epochs.
      * Every batch is cast to the dtype for ``training_bit_count``, moved to
        ``config.data_config.device`` and passed to ``train_step``.
      * Every ``eval_freq`` epochs, all five loaders are evaluated in eval mode.
        With ``extensive_eval`` they are also evaluated in train mode; otherwise
        those entries are ``-1``. Results are printed and, when an experiment
        id is set, stored and written to disk.
      * Whenever the metric selected by ``save_model_on`` ('valid' or 'bin')
        beats the best value so far, the model is saved (experiment id set) or
        the improvement is announced.

    Raises:
        ValueError: if ``save_model_on`` is neither 'valid' nor 'bin'. This is
            checked before any training happens rather than after the first epoch.
    """
    import gc

    print("Training...")

    train_cfg = config.train_config
    # eval_results key whose improvement triggers saving the model.
    save_metric_key = {'valid': 'valid_acc_eval_mode', 'bin': 'bin_acc_eval_mode'}.get(train_cfg.save_model_on)
    if save_metric_key is None:
        raise ValueError("Specify when to save the model: e.g. 'valid' or 'bin' (config.train_config.save_model_on)!")

    has_experiment_id = config.experiment_config.experiment_id is not None
    input_dtype = BITS_TO_TORCH_FLOATING_POINT_TYPE[train_cfg.training_bit_count]
    device = config.data_config.device

    best_acc = 0
    measure = accuracy
    train_epoch_samples = len(train_loader.dataset)
    for epoch in range(train_cfg.num_epochs):
        if train_cfg.decrease_tau is not None and config.model_config.use_groupsum:
            tau_index = max(0, len(train_cfg.decrease_tau) - 1 - (epoch // 10))
            config.model_config.tau = train_cfg.decrease_tau[tau_index]
            model[-1].tau = config.model_config.tau
            print(f"Tau: {model[-1].tau}")

        epoch_loss = 0
        epoch_acc = 0
        for x, y in tqdm(
                train_loader,
                desc=f'epoch {epoch}, iteration',
                total=len(train_loader),
        ):
            x = x.to(dtype=input_dtype, device=device)
            y = y.to(device)

            batch_loss, batch_acc = train_step(model, x, y, loss_fn, optimizer, measure, train_cfg)
            epoch_loss += batch_loss
            # Weight each batch's accuracy by its share of the training set.
            epoch_acc += batch_acc * (len(x) / train_epoch_samples)

        print(f'Epoch {epoch} - Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.1%}')

        if epoch % train_cfg.eval_freq == 0:
            is_last_epoch = epoch == train_cfg.num_epochs - 1
            # All eval-mode passes run before any train-mode pass (original call order).
            train_accuracy_eval_mode = eval(model, train_loader, train_mode=False, config=config, measure=measure)
            valid_accuracy_eval_mode = eval(model, validation_loader, train_mode=False, config=config, results=results, measure=measure, last_epoch=is_last_epoch)
            binarized_accuracy_eval_mode = eval(model, binarized_loader, train_mode=False, config=config, results=results, measure=measure, last_epoch=is_last_epoch)
            test_accuracy_eval_mode = eval(model, test_loader, train_mode=False, config=config, results=results, measure=measure, last_epoch=is_last_epoch)
            test_binarized_accuracy_eval_mode = eval(model, test_loader_bin, train_mode=False, config=config, results=results, measure=measure, last_epoch=is_last_epoch)
            if train_cfg.extensive_eval:
                train_accuracy_train_mode = eval(model, train_loader, train_mode=True, config=config, measure=measure)
                valid_accuracy_train_mode = eval(model, validation_loader, train_mode=True, config=config, measure=measure)
                binarized_accuracy_train_mode = eval(model, binarized_loader, train_mode=True, config=config, measure=measure)
                test_accuracy_train_mode = eval(model, test_loader, train_mode=True, config=config, measure=measure)
                test_binarized_accuracy_train_mode = eval(model, test_loader_bin, train_mode=True, config=config, measure=measure)
            else:
                train_accuracy_train_mode = -1
                valid_accuracy_train_mode = -1
                binarized_accuracy_train_mode = -1
                test_accuracy_train_mode = -1
                test_binarized_accuracy_train_mode = -1

            eval_results = {
                'train_acc_eval_mode': train_accuracy_eval_mode,
                'train_acc_train_mode': train_accuracy_train_mode,
                'valid_acc_eval_mode': valid_accuracy_eval_mode,
                'valid_acc_train_mode': valid_accuracy_train_mode,
                'bin_acc_eval_mode': binarized_accuracy_eval_mode,
                'bin_acc_train_mode': binarized_accuracy_train_mode,
                'test_acc_eval_mode': test_accuracy_eval_mode,
                'test_acc_train_mode': test_accuracy_train_mode,
                'test_bin_acc_eval_mode': test_binarized_accuracy_eval_mode,
                'test_bin_acc_train_mode': test_binarized_accuracy_train_mode
            }

            if has_experiment_id:
                results.store_results(eval_results)
            print(eval_results)

            current_acc = eval_results[save_metric_key]
            if current_acc > best_acc:
                best_acc = current_acc
                if has_experiment_id:
                    results.store_final_results(eval_results)
                    model_io.save_model(model, config=config, model_path="./models/", model_name=config.experiment_config.experiment_name)
                    print('Model saved')
                else:
                    print('IS THE BEST UNTIL NOW.')

            if has_experiment_id:
                results.write_to_disk()

        # Release cached CUDA memory and run the garbage collector between epochs.
        torch.cuda.empty_cache()
        gc.collect()

def test(
        model: torch.nn.Module, 
        train_loader: torch.utils.data.DataLoader,
        validation_loader: torch.utils.data.DataLoader,
        test_loader: torch.utils.data.DataLoader, 
        results: results_json.ResultsJSON, 
        config: configs.DifflogicConfig,
    ) -> None:
    """Evaluate a trained ``model`` and record or print the accuracies.

    The test set is always evaluated in eval mode and in train mode. With
    ``config.test_config.extensive_eval``, the train and validation sets are
    evaluated as well; otherwise those entries are ``-1``. Optionally, the
    function adds bit-packed accuracies (``packbits_eval``) and compiles and
    evaluates the model (``compile_and_eval``). Results are stored and written
    to disk when an experiment id is set; otherwise they are printed.
    """
    measure = accuracy

    test_accuracy_eval_mode = eval(model, test_loader, train_mode=False, config=config, measure=measure)
    test_accuracy_train_mode = eval(model, test_loader, train_mode=True, config=config, measure=measure)
    if config.test_config.extensive_eval:
        train_accuracy_eval_mode = eval(model, train_loader, train_mode=False, config=config, measure=measure)
        valid_accuracy_eval_mode = eval(model, validation_loader, train_mode=False, config=config, measure=measure)
        train_accuracy_train_mode = eval(model, train_loader, train_mode=True, config=config, measure=measure)
        valid_accuracy_train_mode = eval(model, validation_loader, train_mode=True, config=config, measure=measure)
    else:
        train_accuracy_eval_mode = -1
        valid_accuracy_eval_mode = -1
        train_accuracy_train_mode = -1
        valid_accuracy_train_mode = -1

    r = {
        'test_acc_eval_mode': test_accuracy_eval_mode,
        'test_acc_train_mode': test_accuracy_train_mode,
        'train_acc_eval_mode': train_accuracy_eval_mode,
        'train_acc_train_mode': train_accuracy_train_mode,
        'valid_acc_eval_mode': valid_accuracy_eval_mode,
        'valid_acc_train_mode': valid_accuracy_train_mode,
    }

    if config.test_config.packbits_eval:
        r['train_acc_eval'] = packbits_eval(model, train_loader)
        # Evaluate the validation split (this entry previously re-used train_loader).
        # A missing validation loader maps to -1, matching eval()'s convention.
        r['valid_acc_eval'] = packbits_eval(model, validation_loader) if validation_loader is not None else -1
        r['test_acc_eval'] = packbits_eval(model, test_loader)

    if config.test_config.compile_model:
        compile_and_eval(
            model, train_loader, validation_loader, test_loader, config
        )

    if config.experiment_config.experiment_id is not None:
        results.store_results(r)
        results.store_final_results(r)
        results.write_to_disk()
    else:
        print(r)