### This is just an example of Algorithm 1. Please see train.py for usage.


import argparse
from datetime import datetime
from collections import defaultdict

import yaml
from tqdm import tqdm
import copy
import random

import functorch
from functorch import make_functional_with_buffers, grad

from helper import Helper
from utils.utils import *

# Module-level logger, registered under the name 'logger'.
# NOTE(review): `logging` is not imported explicitly among the imports above;
# presumably it is re-exported by `from utils.utils import *` — confirm.
logger = logging.getLogger('logger')


def anticipate(hlpr: Helper, epoch, model, train_loader):
    """Build an attacker model by optimizing through simulated future updates.

    A deep copy of ``model`` is functionalized; its parameters and buffers are
    handed to the optimizer returned by
    ``hlpr.task.make_anticipate_optimizer``. For every training batch the
    backdoored batch loss is evaluated after each of
    ``hlpr.params.anticipate_steps`` simulated updates, the per-step losses are
    summed, the sum is back-propagated and the optimizer takes one step.

    Args:
        hlpr: Helper exposing ``params`` (``anticipate_steps``,
            ``fl_attacker_local_epochs``, ``fl_no_models`` plus whatever
            ``train_with_functorch`` reads), ``task`` (``criterion``,
            ``get_batch``, ``make_anticipate_optimizer``) and
            ``attack.synthesizer.make_backdoor_batch``.
        epoch: Round index forwarded to the optimizer factory; simulated step
            ``k`` calls ``train_with_functorch`` with ``epoch + k``.
        model: Model that seeds the attacker copy and is re-functionalized at
            the start of every batch.
        train_loader: Iterable of raw batches, passed to ``hlpr.task.get_batch``
            and to ``train_with_functorch``.

    Returns:
        The deep copy of ``model`` with the optimized parameter and buffer
        tensors loaded back into it.
    """
    attack_model = copy.deepcopy(model)
    criterion = hlpr.task.criterion
    anticipate_steps = hlpr.params.anticipate_steps
    num_models = hlpr.params.fl_no_models

    _, attack_params, attack_buffers = make_functional_with_buffers(attack_model)
    _, weight_names, _ = functorch._src.make_functional.extract_weights(attack_model)
    _, buffer_names, _ = functorch._src.make_functional.extract_buffers(attack_model)

    optimizer = hlpr.task.make_anticipate_optimizer(attack_params + attack_buffers, epoch=epoch)

    if anticipate_steps < 1:
        # Without at least one simulated step no loss is ever produced; say so
        # once instead of failing later on `None.backward()`.
        logger.warning(
            "anticipate_steps=%s: no steps are simulated, attack parameters are left unchanged",
            anticipate_steps,
        )

    for _ in range(hlpr.params.fl_attacker_local_epochs):
        for batch_idx, data in enumerate(train_loader):
            # Every batch restarts the simulation from `model`'s own parameters.
            func_model, curr_params, curr_buffers = make_functional_with_buffers(model)

            batch = hlpr.task.get_batch(batch_idx, data)
            batch = hlpr.attack.synthesizer.make_backdoor_batch(batch, attack=True)

            optimizer.zero_grad()
            loss = None

            # do anticipate_steps steps
            for step in range(anticipate_steps):
                if step == 0:
                    # est other users' update
                    curr_params = train_with_functorch(
                        hlpr, epoch + step, func_model, curr_params, curr_buffers,
                        train_loader, num_users=num_models - 1,
                    )

                    # add attack params at step 0: weighted average of the
                    # attacker tensors (weight 1) and the estimated update
                    # (weight num_models - 1).
                    curr_params = [
                        (attack_p + other_p * (num_models - 1)) / num_models
                        for attack_p, other_p in zip(attack_params, curr_params)
                    ]
                    curr_buffers = [
                        (attack_b + other_b * (num_models - 1)) / num_models
                        for attack_b, other_b in zip(attack_buffers, curr_buffers)
                    ]
                else:
                    # do normal update
                    curr_params = train_with_functorch(
                        hlpr, epoch + step, func_model, curr_params, curr_buffers,
                        train_loader, num_users=num_models,
                    )

                # adversarial loss on the backdoored batch after this step
                logits = func_model(curr_params, curr_buffers, batch.inputs)
                step_loss = criterion(logits, batch.labels).mean()
                # Out-of-place accumulation keeps every step's loss in the graph.
                loss = step_loss if loss is None else loss + step_loss

            if loss is None:
                # No simulated step ran for this batch, so there is nothing to
                # differentiate; skip the update.
                continue

            loss.backward()
            optimizer.step()

    # copy the params back to the model
    # NOTE(review): `load_weights` is called without `as_params`; depending on
    # the functorch version the weights may be attached as plain tensors rather
    # than `nn.Parameter` — confirm callers do not rely on
    # `attack_model.parameters()`.
    functorch._src.make_functional.load_weights(attack_model, weight_names, attack_params)
    functorch._src.make_functional.load_buffers(attack_model, buffer_names, attack_buffers)

    return attack_model


def train_with_functorch(hlpr, epoch, func_model, params, buffers, train_loader, num_users=1):
    """Take plain gradient steps on functional parameters using one batch.

    Only the first item yielded by ``train_loader`` is consumed. On that item,
    ``hlpr.params.fl_local_epochs`` updates of the form ``p - g * lr`` are
    applied, where ``g`` is the gradient of the mean criterion value with
    respect to ``params``.

    Args:
        hlpr: Helper exposing ``params.lr``, ``params.gamma``,
            ``params.fl_local_epochs``, ``task.criterion`` and
            ``task.get_batch``.
        epoch: Exponent applied to ``hlpr.params.gamma`` when scaling
            ``hlpr.params.lr``.
        func_model: Callable invoked as ``func_model(params, buffers, inputs)``.
        params: Sequence of parameter tensors to update.
        buffers: Buffers forwarded to ``func_model`` unchanged.
        train_loader: Iterable of raw batches.
        num_users: Accepted for the caller's benefit; not read by this function.

    Returns:
        A list with the updated parameters, or ``params`` itself when the
        loader is empty or no local step is taken.
    """
    step_size = hlpr.params.lr * hlpr.params.gamma ** (epoch)
    loss_fn = hlpr.task.criterion

    def batch_loss(weights, state, inputs, targets):
        predictions = func_model(weights, state, inputs)
        return loss_fn(predictions, targets).mean()

    # Pull just the first (index, batch) pair; later batches are never touched.
    first_item = next(enumerate(train_loader), None)
    if first_item is None:
        return params
    batch_idx, raw_batch = first_item

    # Differentiates with respect to the first positional argument only.
    loss_gradient = grad(batch_loss)

    for _ in range(hlpr.params.fl_local_epochs):
        batch = hlpr.task.get_batch(batch_idx, raw_batch)
        gradients = loss_gradient(params, buffers, batch.inputs, batch.labels)
        params = [weight - delta * step_size for weight, delta in zip(params, gradients)]

    return params
