import math
import numpy as np
import torch

from dataset import get_data_name, max_margin_predictor


# The averaged model is evaluated every EVAL_FREQ rounds (and on the last round).
EVAL_FREQ = 4
# Progress is printed every PRINT_FREQ rounds (and on the last round). The
# printed metrics come from the evaluation step, so PRINT_FREQ is kept a
# multiple of EVAL_FREQ.
PRINT_FREQ = 8 * EVAL_FREQ


def _parse_two_stage(two_stage):
    """Parse the ``"<stage1_len>,<stage2_lr>"`` option into ``(int, float)``.

    Returns ``(None, None)`` when `two_stage` is None. Raises ValueError when
    the comma is missing or either field cannot be converted.
    """
    if two_stage is None:
        return None, None

    # `partition` reports a missing separator explicitly; `str.find` would
    # return -1 and silently slice off the last character instead.
    len_text, sep, lr_text = two_stage.partition(",")
    if not sep:
        raise ValueError(
            f"two_stage must look like '<stage1_len>,<stage2_lr>', got {two_stage!r}"
        )
    try:
        return int(len_text), float(lr_text)
    except ValueError as err:
        raise ValueError(
            f"two_stage must look like '<stage1_len>,<stage2_lr>', got {two_stage!r}"
        ) from err


def _evaluate(model, criterion, local_xs, local_ys, num_clients):
    """Return the client-averaged ``(loss, accuracy)`` of `model`."""
    mean_loss = 0
    mean_acc = 0
    with torch.no_grad():
        for client in range(num_clients):
            targets = local_ys[client]
            predictions = model(local_xs[client]).squeeze()
            loss = float(criterion(predictions, targets))
            # A sample counts as correct when prediction and label share a sign.
            acc = float(torch.sum(predictions * targets > 0))
            acc /= int(predictions.numel())
            mean_loss += loss / num_clients
            mean_acc += acc / num_clients
    return mean_loss, mean_acc


def train(
    input_d,
    client_samples,
    num_clients,
    rounds,
    interval,
    lr,
    bias,
    two_stage,
    data_type,
    local_var,
    mu_het,
    label_het,
    quiet
):
    """Train a linear model with local gradient steps and per-round averaging.

    Each round, every client starts from the previous round's averaged model,
    runs `interval` full-batch SGD steps on its own data with soft-margin
    loss, and the resulting client models are averaged into the next model.

    Args:
        input_d: Input feature dimension of the linear model.
        client_samples: Forwarded to `get_data_name` to locate the dataset.
        num_clients: Number of clients trained and averaged each round.
        rounds: Number of averaging rounds.
        interval: Number of local gradient steps per client per round.
        lr: Learning rate of the local SGD optimizer.
        bias: Whether the linear model has a bias term.
        two_stage: None, or a string ``"<stage1_len>,<stage2_lr>"``; the
            learning rate is switched to `stage2_lr` at round `stage1_len`.
        data_type: Forwarded to `get_data_name` to locate the dataset.
        local_var: Forwarded to `get_data_name` to locate the dataset.
        mu_het: Forwarded to `get_data_name` to locate the dataset.
        label_het: Forwarded to `get_data_name` to locate the dataset.
        quiet: If true, suppress progress printing.

    Returns:
        A dict with keys "rounds" (list of evaluated round indices), "loss"
        and "acc" (dicts mapping each evaluated round to the client-averaged
        loss / accuracy of the averaged model).

    Raises:
        ValueError: If `two_stage` is not None and is malformed.
    """

    # Parse two stage options first, so a malformed option fails before any
    # data is loaded or any training is done.
    stage1_len, stage2_lr = _parse_two_stage(two_stage)

    # Load data and convert it to tensor format. Both arrays are indexed by
    # client along their first axis.
    # NOTE(review): labels are presumably +1/-1, as SoftMarginLoss and the
    # sign-based accuracy suggest -- confirm against the data generator.
    x_name, y_name = get_data_name(
        num_clients, local_var, input_d, client_samples, mu_het, label_het, data_type
    )
    local_xs = torch.Tensor(np.load(x_name))
    local_ys = torch.LongTensor(np.load(y_name))

    # Initialize model, loss and optimizer. `w` is the working copy trained on
    # each client; `prev_w` / `next_w` hold the averaged model of the previous
    # / current round.
    w = torch.nn.Linear(input_d, 1, bias=bias)
    prev_w = torch.nn.Linear(input_d, 1, bias=bias)
    next_w = torch.nn.Linear(input_d, 1, bias=bias)
    criterion = torch.nn.SoftMarginLoss()
    optimizer = torch.optim.SGD(w.parameters(), lr=lr)

    # Initialize results.
    results = {
        "rounds": [],
        "loss": {},
        "acc": {},
    }

    # Train.
    for r in range(rounds):

        # The last average becomes the starting point; reset the accumulator.
        prev_w.load_state_dict(next_w.state_dict())
        with torch.no_grad():
            for param in next_w.parameters():
                param.zero_()

        # Transition to second stage learning rate, if necessary.
        if stage1_len is not None and r == stage1_len:
            for group in optimizer.param_groups:
                group["lr"] = stage2_lr

        # Train each local model.
        for client in range(num_clients):
            # `load_state_dict` copies in place, so the optimizer keeps
            # pointing at the same parameter tensors.
            w.load_state_dict(prev_w.state_dict())

            # Run GD for `interval` steps.
            for _ in range(interval):
                optimizer.zero_grad()
                predictions = w(local_xs[client]).squeeze()
                loss = criterion(predictions, local_ys[client])
                loss.backward()
                optimizer.step()

            # Accumulate this client's share of the average.
            with torch.no_grad():
                for avg_param, local_param in zip(next_w.parameters(), w.parameters()):
                    avg_param.add_(local_param / num_clients)

        # Evaluate model.
        is_last_round = r == rounds - 1
        evaluated = r % EVAL_FREQ == 0 or is_last_round
        if evaluated:
            current_loss, current_acc = _evaluate(
                next_w, criterion, local_xs, local_ys, num_clients
            )
            results["rounds"].append(r)
            results["loss"][r] = current_loss
            results["acc"][r] = current_acc

        # Only print on rounds that were just evaluated, so the metrics shown
        # always belong to round `r`.
        if evaluated and not quiet and (r % PRINT_FREQ == 0 or is_last_round):
            print(f"Round {r}:")
            print(f"    loss: {current_loss:.5f}")
            print(f"    acc:  {current_acc:.5f}")

    return results
