import os

import numpy as np
import torch
from tqdm import tqdm

import wandb
from attack.attacker import Attack
from attack.evaluate import eval_attack
from utils import convert_dict_to_json_types, dump_json


def train(model, dataloader, args, device):
    """Attack the reference images taken from a single batch and collect logs.

    One batch is drawn from ``dataloader``. Its first ``args.N_ref`` samples
    are the reference images. When ``args.type == "supervised"``, the next
    ``args.N_trg`` samples are passed along as attack targets; otherwise no
    targets are used.

    Args:
        model: Model under attack; handed to ``Attack`` and ``train_fn``.
        dataloader: Iterable of ``(x, y)`` batches; only the first is used.
        args: Run config; reads ``N_ref``, ``N_trg``, ``type``, ``loss_type``
            and (for "clf" losses) ``model.n_classes``.
        device: Target device for the reference images and labels.

    Returns:
        Tuple ``(task_logs, task_total_logs)``, holding one entry per task
        (a single entry when ``loss_type`` does not contain "clf"). Each entry
        is the matching element of ``train_fn``'s return value.
    """
    # Only the first batch is consumed. The slicing below assumes it holds
    # at least N_ref samples (plus N_trg more in supervised mode).
    batch_x, batch_y = next(iter(dataloader))

    n_ref = args.N_ref
    ref_images = batch_x[:n_ref]
    ref_labels = batch_y[:n_ref]
    x_trg = (
        batch_x[n_ref : n_ref + args.N_trg] if args.type == "supervised" else None
    )

    attacker = Attack(model, args, device)

    ref_images = ref_images.to(device)
    ref_labels = ref_labels.to(device)
    # NOTE(review): x_trg is left on its original device; presumably Attack /
    # eval_attack move it as needed — confirm.

    # Classifier ("clf") losses may cover several tasks (e.g. class and colour
    # for coloured MNIST). Any other loss runs one pass with task=None.
    # The clf path is flagged upstream as not yet implemented.
    tasks = range(len(args.model.n_classes)) if "clf" in args.loss_type else [None]

    per_task = [
        train_fn(attacker, model, ref_images, ref_labels, args, x_trg, task)
        for task in tasks
    ]
    task_logs = [logs for logs, _ in per_task]
    task_total_logs = [totals for _, totals in per_task]
    return task_logs, task_total_logs


def _accumulate_totals(total_logs, logs):
    """Add each non-image metric in ``logs`` to its ``Av_``-prefixed running sum."""
    for key, value in logs.items():
        # Image logs cannot be summed; skip them.
        if not isinstance(value, wandb.Image):
            avg_key = "Av_" + key
            total_logs[avg_key] = total_logs.get(avg_key, 0.0) + value


def train_fn(attacker, model, x_ref, y_ref, args, x_trg=None, task=None):
    """Attack every reference image individually and aggregate the metrics.

    For each image in ``x_ref``, ``attacker.get_attack`` produces adversarial
    examples. These are concatenated and scored with ``eval_attack``. Each
    non-image metric is summed into a running total keyed ``"Av_" + name``,
    so a per-image average can be reported.

    Args:
        attacker: Object exposing ``get_attack(xi, all_trg=..., task=...)``
            that returns a sequence of tensors accepted by ``torch.cat``.
        model: Model under attack; forwarded to ``eval_attack``. Its
            ``classifier`` attribute is read only when per-head metrics
            (``Av_ref_acc_0``, ...) are present and wandb is enabled.
        x_ref: Reference images; iterated along the first dimension.
        y_ref: Labels paired with ``x_ref``; only used to bound iteration.
        args: Run config; reads ``save_dir`` and ``wandb``.
        x_trg: Optional target images forwarded to the attacker and evaluator.
        task: Optional task index. When set, logged keys get a
            ``task{task}/`` prefix.

    Returns:
        Tuple ``(single_image_logs, total_logs_average)``. The first is a list
        with one log dict per image; the second is a dict of ``Av_*`` metrics
        averaged over all attacked images. Both are passed through
        ``convert_dict_to_json_types``.
    """
    pref = f"task{task}/" if task is not None else ""
    single_image_logs = []
    total_logs = {}
    n_images = 0  # reference images processed so far (averaging denominator)

    for step, (xi, _yi) in tqdm(enumerate(zip(x_ref, y_ref)), total=len(x_ref), desc="Attacking every single image", leave=False):
        xi = xi.unsqueeze(0)  # restore the batch dimension for one image

        x_adv = torch.cat(attacker.get_attack(xi, all_trg=x_trg, task=task))
        logs = eval_attack(
            model,
            xi,
            x_adv,
            step,
            x_trg=x_trg,
            task=task,
            save_dir=args.save_dir,
        )
        n_images += 1

        # Totals are keyed by the unprefixed metric names.
        _accumulate_totals(total_logs, logs)

        if task is not None:
            logs = {pref + k: v for k, v in logs.items()}

        if args.wandb:
            wandb.log(logs)
            # Keep the run summary up to date with the running averages.
            for k, v in total_logs.items():
                wandb.run.summary[pref + k] = v / n_images

        single_image_logs.append(convert_dict_to_json_types(logs))

    if args.wandb:
        # Multi-head models log one accuracy per head (``<key>_<i>``). Report
        # their mean under the plain key, using the same task prefix as the
        # other summaries so that tasks do not overwrite each other.
        for k in ("Av_ref_acc", "Av_adv_acc"):
            if k not in total_logs and f"{k}_0" in total_logs:
                wandb.run.summary[pref + k] = np.mean(
                    [
                        total_logs[f"{k}_{i}"] / n_images
                        for i in range(len(model.classifier))
                    ]
                )

    # total_logs is empty whenever n_images == 0, so no division by zero.
    total_logs_average = {k: v / n_images for k, v in total_logs.items()}

    return single_image_logs, convert_dict_to_json_types(total_logs_average)
