from utils.classifier import step_selection
import argparse
from erasure import train
import os
import random

# Command-line interface for the step-selection + fine-tuning pipeline.
# Kept at module level so other tooling can reuse / inspect the parser.
parser = argparse.ArgumentParser(
    prog='TrainESD',
    description='Finetuning stable diffusion to erase the concepts',
)
parser.add_argument('--save_path', help='Path to save model', type=str, required=False, default=None)
# The concept to erase is mandatory: without it, step_selection/train would be
# called with None instead of failing fast with a clear argparse error.
parser.add_argument('--target_concept', help='concept to erase', type=str, required=True)
parser.add_argument('--anchor_concept', help='target concept to erase from', type=str, required=False, default=None)
parser.add_argument('--device', help='cuda device to train on', type=str, required=False, default='cuda:0')
parser.add_argument('--niters', help='step selection iteration per timestep', type=int, required=False, default=30)
parser.add_argument('--lam', help='threshold of SSScore', type=float, required=False, default=0.8)
parser.add_argument('--train_method', help='Parameters to finetune(Choose from xattn_q/xattn/attn).', type=str, required=False, default=None)
parser.add_argument('--iterations', help='Number of iterations', type=int, default=200)
parser.add_argument('--lr', help='Learning rate', type=float, default=2e-5)
parser.add_argument('--negative_guidance', help='Negative guidance value', type=float, required=False, default=1)
parser.add_argument('--beta_1', help='weight of early-preserve loss', type=float, required=False, default=0.1)
parser.add_argument('--beta_2', help='weight of concept-retain loss', type=float, required=False, default=0.1)
parser.add_argument('--early_preserve_steps', help='Number of early-preserve steps (used with --beta_1)', type=int, required=False, default=3)


def main(argv=None):
    """Parse CLI arguments, run step selection, then fine-tune the model.

    Args:
        argv: Optional list of argument strings. ``None`` means ``sys.argv[1:]``
            (argparse's default), so the normal CLI behaviour is unchanged.
    """
    args = parser.parse_args(argv)

    #####   Step Selection   #####
    steps = step_selection(args.target_concept, args.anchor_concept, args.niters, args.device, args.lam)

    #####   Training   #####
    # Positional arguments are passed in the order train() expects; the
    # step-selection output and loss weights are passed by keyword.
    train(
        args.target_concept,
        args.anchor_concept,
        args.train_method,
        args.iterations,
        args.negative_guidance,
        args.lr,
        args.save_path,
        args.device,
        steps=steps,
        early_preserve_steps=args.early_preserve_steps,
        beta_1=args.beta_1,
        beta_2=args.beta_2,
    )


# Only run the (expensive) pipeline when executed as a script, not on import.
if __name__ == "__main__":
    main()
