
import torch
from devinterp.optim.sgld import SGLD
from devinterp.slt.sampler import estimate_learning_coeff
from devinterp.utils import evaluate_ce
from tqdm import tqdm


def train_and_analyze(model, train_loader, test_loader, criterion, optimizer, device, epochs, eval_step=100, wandb_run=None,
                      include_layer_movement=True, float_input=True, llc_estimate_frequency=1, llc_estimation_params=None):
    """
    Train a model while tracking weight movement and estimating the learning coefficient (LLC).

    Every step runs a standard forward/backward/optimizer update on one training batch. Periodically the
    current weights are compared against the initial weights (total and, optionally, per layer) and the model
    is evaluated on a single test batch. After selected epochs the LLC is estimated with `devinterp` using SGLD.

    :param model: A Pytorch model with two additional methods, `get_weights` and `flatten_weights`. The result of
    `get_weights` is indexed by position and passed to `len`; `flatten_weights` is called on that result and should
    return a 1D tensor of the flattened weights. NOTE(review): this assumes `get_weights` returns copies rather than
    live references to the parameters, otherwise every movement metric would read zero -- confirm against the model.
    :param train_loader: A Pytorch DataLoader object. Used both for training and for LLC estimation.
    :param test_loader: A Pytorch DataLoader object. Only one batch is drawn from it per evaluation.
    :param criterion: The loss function to use during training. This should be a Pytorch loss function.
    :param optimizer: The optimizer to use during training. This should be a Pytorch optimizer.
    :param device: The device the model/data are on during training.
    :param epochs: Number of epochs to train for.
    :param eval_step: Number of steps between each evaluation on a batch of test data and weight movement logging.
    Evaluation fires when `(steps_done + 1) % eval_step == 0`, i.e. after `eval_step - 1`, `2 * eval_step - 1`, ...
    completed steps.
    :param wandb_run: a wandb run object. If provided, logs will be saved to this run. If `None`, nothing is sent to
    wandb and the LLC estimate is printed instead.
    :param include_layer_movement: Whether or not to log the movement of the individual layers during training.
    :param float_input: Whether or not the input needs to be converted to a float. Mostly used for data which needs
    to be embedded.
    :param llc_estimate_frequency: How frequently to estimate the LLC: it is estimated after every epoch whose
    zero-based index is a multiple of this value (so always after the first epoch). Must be non-zero.
    :param llc_estimation_params: The parameters used when estimating the LLC. If `None`, uses a set of default parameters.
    For more information see the `devinterp` docs https://github.com/timaeus-research/devinterp
    :return: None. Results are reported through stdout and, when given, `wandb_run`.
    """
    model.to(device)

    # Snapshot of the starting point; all movement metrics are measured relative to it.
    W_0 = model.get_weights()
    W_0_flat = model.flatten_weights(W_0)
    total_weights = W_0_flat.size()[0]
    total_steps = 0
    if wandb_run is not None:
        wandb_run.log({"num_params": total_weights})

    # Resolve the default LLC settings once instead of re-checking on every epoch.
    if llc_estimation_params is None:
        llc_estimation_params = dict(
            optimizer_kwargs={"lr": 1e-5, "localization": 100.0},
            num_chains=1,
            num_draws=400,
            num_burnin_steps=0,
            num_steps_bw_draws=1,
        )

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        train_progress = tqdm(enumerate(train_loader), total=len(train_loader), desc="Training")

        for i, (inputs, targets) in train_progress:
            model.train()
            inputs, targets = inputs.to(device), targets.to(device)
            if float_input:
                inputs = inputs.type(torch.float)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Train metrics describe this batch only, using the outputs computed before the update.
            _, predicted = torch.max(outputs, 1)
            train_total = targets.size(0)
            train_correct = (predicted == targets).sum().item()
            train_loss_val = loss.item()
            train_acc = 100 * train_correct / train_total

            train_progress.set_postfix({"Loss": train_loss_val, "Accuracy": train_acc})
            total_steps += 1

            # Periodic evaluation. NOTE: `total_steps` was already incremented above, so this fires
            # one step before each multiple of `eval_step`; kept as-is so logged step indices stay stable.
            if (total_steps + 1) % eval_step != 0:
                continue

            # Weight movement relative to the initial snapshot, and distance from the origin.
            W_n = model.get_weights()
            W_n_flat = model.flatten_weights(W_n)
            total_weight_movement = torch.norm(W_n_flat - W_0_flat)
            total_distance_from_zero = torch.norm(W_n_flat)

            all_layer_movement = {}
            all_layer_distance_from_zero = {}
            if include_layer_movement:
                num_layers = len(W_0)
                all_layer_movement = {
                    f"layer_{idx}": torch.norm(W_n[idx] - W_0[idx]) for idx in range(num_layers)
                }
                # The last layer is additionally reported under a fixed key.
                all_layer_movement["final_layer"] = torch.norm(W_n[num_layers - 1] - W_0[num_layers - 1])
                all_layer_distance_from_zero = {
                    f"layer_{idx}": torch.norm(W_n[idx]) for idx in range(num_layers)
                }

            # Evaluate on one test batch. A fresh iterator is built each time, so this is always the
            # first batch the loader yields.
            test_inputs, test_targets = next(iter(test_loader))
            test_inputs, test_targets = test_inputs.to(device), test_targets.to(device)
            model.eval()
            if float_input:
                test_inputs = test_inputs.type(torch.float)
            with torch.no_grad():
                test_outputs = model(test_inputs)
                test_loss = criterion(test_outputs, test_targets).item()
                _, test_predicted = torch.max(test_outputs, 1)
                test_total = test_targets.size(0)
                test_correct = (test_predicted == test_targets).sum().item()
                test_accuracy = 100 * test_correct / test_total

            print(f'Epoch [{epoch + 1}/{epochs}], Step [{i + 1}/{len(train_loader)}], '
                  f'Train Loss: {train_loss_val:.4f}, Test Loss: {test_loss:.4f},'
                  f'Train Acc: {train_acc:.4f}, Test Acc: {test_accuracy:.4f}')
            if wandb_run is not None:
                wandb_run.log({
                    "Step": total_steps,
                    "Train Loss": train_loss_val,
                    "Train Accuracy": train_acc,
                    "Test Loss (Batch)": test_loss,
                    "Test Accuracy": test_accuracy,
                    'total_weight_movement': total_weight_movement.item(),
                    'total_distance_from_zero': total_distance_from_zero.item(),
                    **{f'LM_{k}': v.item() for k, v in all_layer_movement.items()},
                    **{f'DZ_{k}': v.item() for k, v in all_layer_distance_from_zero.items()},
                })
            model.train()

        if epoch % llc_estimate_frequency == 0:
            print("Estimating learning coefficient")
            estimate = estimate_learning_coeff(
                model,
                train_loader,
                evaluate=evaluate_ce,
                sampling_method=SGLD,
                device=device,
                verbose=False,
                **llc_estimation_params,
            )
            # `wandb_run` is optional (defaults to None); logging unconditionally would raise
            # AttributeError. Without a run, print the estimate so it is not silently discarded.
            if wandb_run is not None:
                wandb_run.log({"llc_estimate": estimate})
            else:
                print(f"LLC estimate: {estimate}")


