import torch 
import group_DRO
from torch.utils.data import DataLoader
import torchvision
from group_DRO import run_expt
import argparse
import os
from wandb_wrapper import WandbWrapper

parser = argparse.ArgumentParser()
# Settings
parser.add_argument('-d', '--dataset', choices=["CUB", "CelebA", "BAR"], required=True)
parser.add_argument('-s', '--shift_type', choices=["confounder"], required=True)
parser.add_argument('--cuda', choices=["0", "1"], required=True)
# Confounders
parser.add_argument('-t', '--target_name')
parser.add_argument('-c', '--confounder_names', nargs='+')
# Resume?
parser.add_argument('--resume', default=False, action='store_true')
# Label shifts
parser.add_argument('--minority_fraction', type=float)
parser.add_argument('--imbalance_ratio', type=float)
# Data
parser.add_argument('--fraction', type=float, default=1.0)
parser.add_argument('--root_dir', default=None)
parser.add_argument('--reweight_groups', action='store_true', default=False)
parser.add_argument('--augment_data', action='store_true', default=False)
parser.add_argument('--val_fraction', type=float, default=1.0)
# Objective
parser.add_argument('--robust', default=False, action='store_true')
parser.add_argument('--alpha', type=float, default=0.2)
parser.add_argument('--generalization_adjustment', default="0.0")
parser.add_argument('--automatic_adjustment', default=False, action='store_true')
parser.add_argument('--robust_step_size', default=0.01, type=float)
parser.add_argument('--use_normalized_loss', default=False, action='store_true')
parser.add_argument('--btl', default=False, action='store_true')
parser.add_argument('--hinge', default=False, action='store_true')

# Model
parser.add_argument('--model', default='resnet50')
parser.add_argument('--train_from_scratch', action='store_true', default=False)
parser.add_argument('--n_epochs', type=int, default=4)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--scheduler', action='store_true', default=False)
parser.add_argument('--weight_decay', type=float, default=5e-5)
parser.add_argument('--gamma', type=float, default=0.1)
parser.add_argument('--minimum_variational_weight', type=float, default=0)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--show_progress', default=False, action='store_true')
parser.add_argument('--log_dir', default='./logs')
parser.add_argument('--log_every', default=50, type=int)
parser.add_argument('--save_step', type=int, default=10)
parser.add_argument('--save_best', action='store_true', default=False)
parser.add_argument('--save_last', action='store_true', default=False)

## Reproducing commands
# https://worksheets.codalab.org/worksheets/0x621811fe446b49bb818293bae2ef88c0
if __name__ == "__main__":
    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
    print(f"Using Device CUDA:{torch.cuda.get_device_name()}, {torch.cuda.mem_get_info()}")
    wb = WandbWrapper(project_name="GDRO_Debiasing", config=args)
    run_expt.main(args, wb)


"""
python3 gdro_debiasing.py
    -s confounder 
    -d CUB 
    -t bird 
    -c place 
    --model resnet50 
    --weight_decay 1 
    --lr 1e-05 
    --batch_size 128 
    --n_epochs 300 
    --save_step 1000 
    --save_best 
    --save_last 
    --reweight_groups 
    --robust 
    --alpha 0.01 
    --gamma 0.1 
    --generalization_adjustment 2
    --cuda 
    --log_dir
    --seed 
"""

