import numpy as np
import torch
import wandb


class MixupLoss(torch.nn.Module):
    """Mixup training loss wrapped around an arbitrary criterion.

    Every call pairs each sample of the batch with a randomly chosen partner
    from the same batch, feeds the convex combination of the two inputs to the
    model, and returns the same convex combination of the criterion evaluated
    against the original targets and against the partners' targets.

    Args:
        mixup_alpha: Parameter of the symmetric ``Beta(alpha, alpha)``
            distribution the mixing weight is sampled from. When ``fixed`` is
            true, this value is used directly as the mixing weight instead.
        criterion: Callable invoked as ``criterion(output, target)``.
        fixed: If true, no sampling happens and the weight is ``mixup_alpha``.
    """

    def __init__(self, mixup_alpha, criterion, fixed=False) -> None:
        super().__init__()
        self.fixed = fixed
        self.criterion = criterion
        self.alpha = mixup_alpha

    def forward(self, model, x, target):
        """Run ``model`` on a mixed batch and return the mixup loss.

        Args:
            model: Callable applied to the mixed inputs.
            x: Batch of inputs, indexed along its first dimension.
            target: Batch of targets; its length sets the permutation size.

        Returns:
            Whatever ``criterion`` returns, combined with the mixing weight.
        """
        # The permutation is drawn before the weight; keep that RNG order.
        partner = torch.randperm(len(target))
        weight = self.alpha if self.fixed else np.random.beta(self.alpha, self.alpha)
        keep = 1 - weight

        blended = weight * x + keep * x[partner]
        prediction = model(blended)

        own_term = weight * self.criterion(prediction, target)
        partner_term = keep * self.criterion(prediction, target[partner])
        return own_term + partner_term


class LSLogLoss(torch.nn.Module):
    """Binary cross-entropy against label-smoothed targets.

    Targets greater than ``1e-6`` are treated as positives and replaced by
    ``1 - alpha / 2``; all other targets are replaced by ``alpha / 2``. The
    smoothed targets are then scored with ``torch.nn.BCELoss``.

    Args:
        alpha: Smoothing strength; ``0`` leaves hard 0/1 targets.
    """

    def __init__(self, alpha=0.5) -> None:
        super().__init__()
        self.criterion = torch.nn.BCELoss()
        self.alpha = alpha

    def forward(self, out, target):
        """Return the BCE loss of ``out`` against the smoothed ``target``.

        Args:
            out: Predicted probabilities, as required by ``BCELoss``.
            target: Labels with the same shape as ``out``. Any dtype that can
                be compared with a float is accepted (float, integer, bool).

        Returns:
            The scalar loss produced by ``BCELoss``.
        """
        positive = target > 1e-6
        # Build the smoothed labels in the prediction's dtype rather than the
        # target's: cloning an integer target would truncate the fractional
        # smoothed values to 0 and BCELoss would then reject the dtype.
        smoothed = torch.full_like(target, self.alpha / 2, dtype=out.dtype)
        smoothed[positive] = 1 - self.alpha / 2
        return self.criterion(out, smoothed)


class BinaryMixupLoss(torch.nn.Module):
    """Mixup loss for binary targets scored with ``torch.nn.BCELoss``.

    Inputs and targets are both blended with a randomly chosen partner from
    the same batch, and a single BCE term is computed against the blended
    (soft) targets.

    Args:
        alpha: Parameter of the symmetric ``Beta(alpha, alpha)`` distribution
            the mixing weight is sampled from, or the mixing weight itself
            when ``fixed`` is true.
        fixed: If true, always mix with weight ``alpha`` instead of sampling.
    """

    def __init__(self, alpha=0.5, fixed=False) -> None:
        super().__init__()
        self.alpha = alpha
        self.fixed = fixed
        self.criterion = torch.nn.BCELoss()

    def forward(self, model, x, target):
        """Run ``model`` on a mixed batch and return BCE against mixed targets.

        Args:
            model: Callable applied to the mixed inputs; its output is passed
                straight to ``BCELoss``.
            x: Batch of inputs, indexed along its first dimension.
            target: Batch of targets; its length sets the permutation size.

        Returns:
            The loss produced by ``BCELoss``.
        """
        # The permutation is drawn before the weight; keep that RNG order.
        partner = torch.randperm(len(target))
        weight = self.alpha if self.fixed else np.random.beta(self.alpha, self.alpha)
        keep = 1 - weight

        blended = weight * x + keep * x[partner]
        prediction = model(blended)
        soft_target = weight * target + keep * target[partner]
        return self.criterion(prediction, soft_target)


def reset_weights(m):
    """Call ``m.reset_parameters()`` if ``m`` exposes such a callable.

    Objects without a callable ``reset_parameters`` attribute are left
    untouched, which makes this safe to pass to ``Module.apply``.
    """
    if not callable(getattr(m, "reset_parameters", None)):
        return
    m.reset_parameters()


def get_grad_norm(model):
    """Return the global L2 norm of the gradients of ``model``'s parameters.

    The result is the square root of the sum of squared per-parameter
    gradient 2-norms. Parameters whose ``grad`` is ``None`` (frozen, unused
    in the last backward pass, or never back-propagated) contribute nothing
    instead of raising.

    Args:
        model: Object exposing ``parameters()`` that yields tensors.

    Returns:
        The norm as a Python float; ``0.0`` when no parameter has a gradient.
    """
    squared_sum = 0.0
    for param in model.parameters():
        if param.grad is None:
            continue
        squared_sum += param.grad.detach().norm(2).item() ** 2
    return squared_sum**0.5


def get_model_param_tensor(model):
    """Concatenate all of ``model``'s parameters into one 1-D tensor.

    Parameters are flattened individually and joined in the order produced by
    ``model.parameters()``.
    """
    return torch.cat([param.flatten() for param in model.parameters()])


def get_model_evaluations(model, data_loader, device="cpu"):
    """Return the softmax outputs of ``model`` for every sample in a loader.

    The model is switched to eval mode and run without gradient tracking.
    Per-batch outputs are softmaxed along dimension 1 and concatenated along
    dimension 0 in loader order, so the result covers the whole loader rather
    than only its final batch.

    Args:
        model: Callable module producing 2-D ``(batch, classes)`` outputs.
        data_loader: Iterable yielding ``(data, _)`` pairs; the second element
            is ignored.
        device: Device each batch is moved to before the forward pass.

    Returns:
        A tensor of class probabilities, or ``None`` if the loader is empty.
    """
    model.eval()
    softmax = torch.nn.Softmax(dim=1)
    batch_outputs = []
    with torch.no_grad():
        for data, _ in data_loader:
            batch_outputs.append(softmax(model(data.to(device))))
    if not batch_outputs:
        return None
    return torch.cat(batch_outputs, dim=0)


def train(
    model,
    train_loader,
    loss_fn,
    optimizer,
    device="cpu",
):
    """Train ``model`` for one pass over ``train_loader``.

    Mixup-style losses (``MixupLoss`` / ``BinaryMixupLoss``) run the forward
    pass themselves and are therefore called as ``loss_fn(model, data,
    target)``; every other loss is called as ``loss_fn(model(data), target)``.

    Args:
        model: Module to optimise; put into training mode.
        train_loader: Sized iterable of ``(data, target)`` batches.
        loss_fn: Loss callable, dispatched as described above.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        device: Device each batch is moved to.

    Returns:
        The mean of the per-batch loss values over the pass.
    """
    model.train()
    loss_runs_model = isinstance(loss_fn, (MixupLoss, BinaryMixupLoss))
    running_mean = 0
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        if loss_runs_model:
            loss = loss_fn(model, inputs, labels)
        else:
            loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer.step()

        # Each batch contributes 1/num_batches of its loss to the mean.
        running_mean += loss.item() / len(train_loader)

    return running_mean


def test(model, test_loader, device="cpu"):
    """Return the classification error of ``model`` on a loader, in percent.

    Outputs with more than one column are decoded with ``argmax`` over
    dimension 1; single-column outputs are rounded to the nearest integer.
    The count of correct predictions is divided by
    ``len(test_loader.dataset)``.

    Args:
        model: Module to evaluate; put into eval mode.
        test_loader: Iterable of ``(data, target)`` batches with a ``dataset``
            attribute that supports ``len``.
        device: Device each batch is moved to.

    Returns:
        ``100 * (1 - accuracy)`` as a float.
    """
    model.eval()
    num_correct = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            scores = model(inputs)
            if scores.shape[1] > 1:
                # Multi-class: predicted label is the index of the top score.
                guesses = scores.argmax(dim=1, keepdim=True)
            else:
                # Single output column: round to the nearest label.
                guesses = scores.round()
            num_correct += guesses.eq(labels.view_as(guesses)).sum().item()

    accuracy = num_correct / len(test_loader.dataset)
    return 100 * (1 - accuracy)


def full_train_test_loop(
    model,
    task,
    train_loader,
    train_loss_fn,
    test_loader,
    optimizer,
    num_epochs,
    model_name,
    out_file,
    num_runs=10,
    wandb_run=None,
    device="cpu",
):
    """Train and evaluate ``model`` for several independent runs.

    Before each run the model's layers are re-initialised through
    ``reset_weights``. After every epoch the error on both loaders is
    recorded and, when ``wandb_run`` is given, logged together with the
    epoch's average batch loss. A summary of the final-epoch errors (mean and
    standard deviation across runs) is printed to ``out_file``.

    Only the model is reset between runs; the ``optimizer`` object is reused
    as passed in.

    Args:
        model: Module to train.
        task: Task label used in the printed header and the logging keys.
        train_loader: Loader used for training and for the train error.
        train_loss_fn: Loss passed through to ``train``.
        test_loader: Loader used for the test error.
        optimizer: Optimizer passed through to ``train``.
        num_epochs: Epochs per run.
        model_name: Model label used in the printed header.
        out_file: File-like object that receives the printed report.
        num_runs: Number of runs.
        wandb_run: Optional object with a ``log(dict)`` method.
        device: Device passed through to ``train`` and ``test``.

    Returns:
        ``(train_errors, test_errors)``, two arrays built from one row of
        per-epoch errors for each run.
    """
    per_run_train, per_run_test = [], []
    print(f"{model_name} results for {task}: ", file=out_file)

    for run in range(1, num_runs + 1):
        model.apply(reset_weights)
        train_curve, test_curve = [], []

        for _ in range(num_epochs):
            avg_batch_loss = train(
                model,
                train_loader,
                train_loss_fn,
                optimizer,
                device,
            )
            train_curve.append(test(model, train_loader, device))
            test_curve.append(test(model, test_loader, device))

            if wandb_run is None:
                continue
            wandb_run.log(
                {
                    f"{task}_Run_{run}_Avg_Batch_Loss": avg_batch_loss,
                    f"{task}_Run_{run}_Train_Error": train_curve[-1],
                    f"{task}_Run_{run}_Test_Error": test_curve[-1],
                }
            )

        per_run_train.append(np.array(train_curve))
        per_run_test.append(np.array(test_curve))

    train_errors = np.array(per_run_train)
    test_errors = np.array(per_run_test)

    # The summary uses only the last epoch of every run.
    final_train = train_errors[:, -1]
    print("-------------------------------------------------\n", file=out_file)
    print(f"Average Train Error: {final_train.mean():.2f}", file=out_file)
    print(f"Train Error Stdev: {final_train.std():.2f}", file=out_file)

    final_test = test_errors[:, -1]
    print("-------------------------------------------------\n", file=out_file)
    print(f"Average Test Error: {final_test.mean():.2f}", file=out_file)
    print(f"Test Error Stdev: {final_test.std():.2f}", file=out_file)

    return train_errors, test_errors


# Redundant, but this is research.
def single_train_test(
    model,
    train_loader,
    train_loss_fn,
    test_loader,
    optimizer,
    num_epochs,
    num_runs=5,
    device="cpu",
):
    """Train ``model`` several times and record end-of-run statistics.

    Each run re-initialises the model through ``reset_weights``, trains for
    ``num_epochs`` epochs, and then records the train error, the test error
    and a weight-norm ratio taken from ``model[1].weight``: the norm of entry
    ``[0, 0]`` divided by the norm of the slice ``[0, 1:]``.

    NOTE(review): ``model`` must be indexable with a ``weight`` on element 1;
    the ratio presumably compares a spurious first feature against the
    remaining ones -- confirm against the caller.

    Args:
        model: Module to train.
        train_loader: Loader used for training and for the train error.
        train_loss_fn: Loss passed through to ``train``.
        test_loader: Loader used for the test error.
        optimizer: Optimizer passed through to ``train``.
        num_epochs: Epochs per run.
        num_runs: Number of runs.
        device: Device passed through to ``train`` and ``test``.

    Returns:
        Three arrays with one entry per run:
        ``(train_errors, test_errors, spur_norms)``.
    """
    final_train, final_test, norm_ratios = [], [], []

    for _ in range(num_runs):
        model.apply(reset_weights)
        for _epoch in range(num_epochs):
            train(
                model,
                train_loader,
                train_loss_fn,
                optimizer,
                device,
            )

        final_train.append(test(model, train_loader, device))
        final_test.append(test(model, test_loader, device))

        layer_weight = model[1].weight.data
        first_norm = torch.linalg.norm(layer_weight[0, 0].detach().cpu()).item()
        rest_norm = torch.linalg.norm(layer_weight[0, 1:].detach().cpu()).item()
        norm_ratios.append(first_norm / rest_norm)

    return np.array(final_train), np.array(final_test), np.array(norm_ratios)