import copy
import random

import torch
import torch.nn as nn
import torch.optim as optim

from architectures import architectures

#
# Select parameters
#

# MODULE_TO_USE = 'fnn'
MODULE_TO_USE = 'smfr'  # Architecture to train: 'fnn' (FnnWithBlocks) or 'smfr' (SMFR); see get_module()

# Training parameters
# If True, the value assigned by each formula is incremented by one (modulo 10).
TASK_VARIANT__INCREMENT_RESULT = True
SMFR_REGULARIZATION_LOSS_BOUNDARY = 20  # Set this to None to deactivate the SMFR's regularization loss
# Each "epoch" is a single, freshly sampled batch.
NUM_EPOCHS = 100000
BATCH_SIZE = 100
# Seeds both the task-data RNG and torch (weight initialization) below.
RANDOM_SEED = 0


########################################################################################################################
# Modify the parameters ABOVE this line to try different experiments
########################################################################################################################

# Define some constants
BLOCK_SIZE = 10  # Width of every input/output block (one-hot encoding)
NUM_TASK_INPUTS = 6  # Per iteration: the 5 state blocks plus 1 formula-indicator block
NUM_TASK_OUTPUTS = 5  # The 5 state blocks after the update
NUM_DIGITS = 10  # Number of encodable values: states are digits 0-9 (formula indices 0-4 reuse the table)
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Seed torch as well as the data RNG; otherwise weight initialization differs between runs
# even though RANDOM_SEED is fixed.
torch.manual_seed(RANDOM_SEED)
RNG = random.Random(RANDOM_SEED)
# NUMBER_TO_TENSOR[v] is the float32 one-hot row (length BLOCK_SIZE) encoding the value v.
NUMBER_TO_TENSOR = list(torch.eye(NUM_DIGITS, BLOCK_SIZE, device=DEVICE).unbind(0))
print(f"Using device: {DEVICE}")

def run_a_test():
    """Train the module selected by MODULE_TO_USE and periodically report generalization.

    Every training batch unrolls the module for 2 iterations. Every 1000 epochs the
    latest training loss is printed, followed by the success rate obtained when the
    module is unrolled for 1 to 10 iterations (the out-of-distribution test).
    """
    model, reg_losses = get_module(MODULE_TO_USE)
    # The task is trained with an MSE loss: per the author, the values otherwise become
    # strange at intermediate iterations because the sigmoid kills the gradient.
    loss_fn = nn.MSELoss()
    opt = optim.Adam(model.parameters(), lr=3e-4)
    report_interval = 1000
    for completed in range(1, NUM_EPOCHS + 1):
        batch_loss, _ = apply_batch(
            model, reg_losses, opt, loss_fn, num_iterations=2, is_training=True
        )
        if completed % report_interval != 0:
            continue
        print(f"Epoch [{completed}/{NUM_EPOCHS}], Loss: {batch_loss.item():.4f}")
        _print_ood_success_rates(model, reg_losses, opt, loss_fn)


def _print_ood_success_rates(model, reg_losses, opt, loss_fn):
    """Print the success rate for unrolls of 1..10 iterations, with gradients disabled."""
    # NOTE(review): the module is never switched to eval(); if the architectures contain
    # dropout or batch-norm layers, consider bracketing this with model.eval()/model.train().
    with torch.no_grad():
        for n_iter in range(1, 11):
            _, rate = apply_batch(
                model, reg_losses, opt, loss_fn, num_iterations=n_iter, is_training=False
            )
            print(f"    {n_iter} iterations: Success rate={rate:.4f}")


def apply_batch(module, regularization_loss_accumulator, optimizer, criterion, num_iterations, is_training=True):
    """Sample one batch, unroll `module` for `num_iterations` steps, and optionally train on it.

    On each step the module receives the current state blocks followed by that step's
    formula-indicator block, and its outputs become the next state. After the last step
    the outputs are compared with the targets.

    Args:
        module: callable mapping a list of NUM_TASK_INPUTS tensors to NUM_TASK_OUTPUTS tensors.
        regularization_loss_accumulator: list the module may append regularization losses to
            during the forward pass. The entries are added to the training objective, and the
            list is cleared on every call (also when `is_training` is False).
        optimizer: optimizer that is stepped when `is_training` is True.
        criterion: loss applied to each (output block, target block) pair.
        num_iterations: number of formula applications (module calls) to unroll.
        is_training: if True, back-propagate the objective and take an optimizer step.

    Returns:
        (loss, average_success_rate): `loss` is the task loss summed over all output blocks,
        *excluding* regularization terms. `average_success_rate` is the result of
        get_success_rate_over_all_outputs for the final outputs.
    """
    original_state_tensors, target_tensors, input_tensors_list = get_data_for_algo_task(num_iterations)
    assert len(target_tensors) == NUM_TASK_OUTPUTS
    x = original_state_tensors
    for input_tensors in input_tensors_list:
        assert len(input_tensors) == 1
        # Append this step's indicator block to the current state blocks. list() also
        # accepts a module that returns its outputs as a tuple.
        x = list(x) + list(input_tensors)
        assert len(x) == NUM_TASK_INPUTS
        assert all(a.shape == (BATCH_SIZE, BLOCK_SIZE) for a in x)
        # Forward pass
        x = module(x)
        assert len(x) == NUM_TASK_OUTPUTS
    for output_tensor in x:
        assert output_tensor.shape == (BATCH_SIZE, BLOCK_SIZE)
    # Task loss summed over all output blocks.
    loss = sum(criterion(output_tensor, target_tensor) for output_tensor, target_tensor in zip(x, target_tensors))
    average_success_rate = get_success_rate_over_all_outputs(x, target_tensors)
    # Training objective = task loss + regularization losses collected during the forward pass.
    train_loss = loss
    for reg_loss in regularization_loss_accumulator:
        # Out-of-place add: `+=` on a tensor is in-place, and because `train_loss` aliases
        # `loss`, it would leak the regularization terms into the reported task loss.
        train_loss = train_loss + reg_loss
    regularization_loss_accumulator.clear()
    if is_training:
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
    return loss, average_success_rate


def get_success_rate_over_all_outputs(tensor_list_a, tensor_list_b):
    """Return the fraction of (sample, block) pairs whose argmax indices agree.

    Both arguments are lists of per-block tensors. Along dim 1 of each tensor, the
    argmax is taken as the decoded value. The result is a 0-dim float tensor holding
    the mean agreement over all samples and blocks.
    """
    def _decoded(blocks):
        # Columns are blocks, rows are samples.
        return torch.stack([block.argmax(dim=1) for block in blocks], dim=1)

    predicted = _decoded(tensor_list_a)
    expected = _decoded(tensor_list_b)
    wanted_shape = (BATCH_SIZE, NUM_TASK_OUTPUTS)
    assert predicted.shape == wanted_shape
    assert expected.shape == wanted_shape
    return (predicted == expected).float().mean()


def get_module(name_of_module):
    """Instantiate the architecture named `name_of_module` and move it to DEVICE.

    Args:
        name_of_module: 'fnn' for architectures.FnnWithBlocks or 'smfr' for architectures.SMFR.

    Returns:
        (module, regularization_loss_accumulator): the module on DEVICE, and the list given
        to it for collecting regularization losses. Only the SMFR receives the list; for
        'fnn' it is returned empty.

    Raises:
        ValueError: if `name_of_module` is neither 'fnn' nor 'smfr'.
    """
    reg_losses = []
    if name_of_module == 'fnn':
        model_class = architectures.FnnWithBlocks
        model_kwargs = dict(
            input_sizes=[BLOCK_SIZE] * NUM_TASK_INPUTS,
            output_sizes=[BLOCK_SIZE] * NUM_TASK_OUTPUTS,
            intermediate_sizes=[300, 300, 300],
        )
    elif name_of_module == 'smfr':
        model_class = architectures.SMFR
        model_kwargs = dict(
            block_size=BLOCK_SIZE,
            num_input_tensors=NUM_TASK_INPUTS,
            num_output_tensors=NUM_TASK_OUTPUTS,
            stack_width=7,
            stack_depth=5,
            multiplexer_attention_type='softmax',
            sizes_of_contained_fnns=[100, 100],
            regularization_loss_accumulator=reg_losses,
            regularization_loss_boundary=SMFR_REGULARIZATION_LOSS_BOUNDARY,
        )
    else:
        raise ValueError(name_of_module)
    model = model_class(**model_kwargs)
    model.to(DEVICE)
    return model, reg_losses


# Variable indices (a, b, c, d, e) of the update formulas the task can apply.
# Each tuple encodes the formula "Set $e = $c if ($a > $b) else $d".
# The five formulas are cyclic rotations of one another.
_ALGO_TASK_FORMULAS = (
    (0, 1, 2, 3, 4),
    (1, 2, 3, 4, 0),
    (2, 3, 4, 0, 1),
    (3, 4, 0, 1, 2),
    (4, 0, 1, 2, 3),
)


def _apply_algo_task_formula(states, formula):
    """Return a copy of `states` with `formula` applied; the input list is not modified."""
    a, b, c, d, target = formula
    result = states[c] if states[a] > states[b] else states[d]
    if TASK_VARIANT__INCREMENT_RESULT:
        # Task variant: the assigned value is incremented, wrapping within the digits 0-9.
        result = (result + 1) % 10
    new_states = list(states)
    new_states[target] = result
    return new_states


def _sample_algo_task_episode(num_iterations):
    """Sample one example and return (start states, final states, formula index per iteration).

    Random draws happen in this order: the formula index of iteration 0, then the start
    states, then one formula index per later iteration.
    """
    formula_indices = []
    original_states = states = None
    for iteration in range(num_iterations):
        formula_index = RNG.randint(0, len(_ALGO_TASK_FORMULAS) - 1)
        formula_indices.append(formula_index)
        if iteration == 0:
            original_states = [RNG.randint(0, 9) for _ in range(NUM_TASK_OUTPUTS)]
            states = original_states
        states = _apply_algo_task_formula(states, _ALGO_TASK_FORMULAS[formula_index])
    return original_states, states, formula_indices


def _stack_one_hots(values):
    """Stack the NUMBER_TO_TENSOR one-hot row of each value into one batch tensor (dim 0)."""
    return torch.stack([NUMBER_TO_TENSOR[v] for v in values], dim=0)


def get_data_for_algo_task(num_iterations):
    """Sample a batch of the algorithmic task.

    Every sample starts from NUM_TASK_OUTPUTS random digit states (0-9). On each of the
    `num_iterations` steps, a randomly chosen formula "Set $e = $c if ($a > $b) else $d"
    (see _ALGO_TASK_FORMULAS) is applied to the states. With
    TASK_VARIANT__INCREMENT_RESULT set, the assigned value is incremented modulo 10.

    Args:
        num_iterations: number of formula applications per sample; must be >= 1.

    Returns:
        (original_state_tensors, target_tensors, input_tensors_list):
        - original_state_tensors: one tensor per state variable, holding the stacked
          NUMBER_TO_TENSOR one-hots of its start value across the batch.
        - target_tensors: the same layout for the state values after the last step.
        - input_tensors_list: one entry per iteration, each a single-element list holding
          the stacked one-hots of the formula index used in that step.

    Raises:
        ValueError: if num_iterations < 1 (no start state would be sampled).
    """
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be at least 1, got {num_iterations}")
    episodes = [_sample_algo_task_episode(num_iterations) for _ in range(BATCH_SIZE)]
    original_states_list, targets_list, formula_indices_list = zip(*episodes)
    # zip(*per_sample_lists) transposes samples x positions into positions x samples,
    # so each column below holds one position's values across the whole batch.
    original_state_tensors = [_stack_one_hots(column) for column in zip(*original_states_list)]
    target_tensors = [_stack_one_hots(column) for column in zip(*targets_list)]
    input_tensors_list = [[_stack_one_hots(column)] for column in zip(*formula_indices_list)]
    assert len(input_tensors_list) == num_iterations
    assert all(len(a) == 1 for a in input_tensors_list)
    return original_state_tensors, target_tensors, input_tensors_list

if __name__ == "__main__":
    # Only train when executed as a script, not when this file is imported.
    run_a_test()