import datetime
import logging
import os

import numpy as np
import torch
from torch import nn

from dataset import *
from inference import *
from sample import *
from train import *
from utils import *


# Read the run configuration from the command line; `args` is mutated and
# passed around by the rest of the script.
parser = get_parser()
args = parser.parse_args()

# Expose every GPU listed in args.gpu (comma-separated) to CUDA, unless the
# environment already defines CUDA_VISIBLE_DEVICES -- an existing value wins.
gpu_env_missing = 'CUDA_VISIBLE_DEVICES' not in os.environ
if gpu_env_missing:
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, args.gpu))

# Use CUDA when available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')

# Dataset-specific settings written onto `args`. Keys are attribute names;
# a dataset that is not listed (or an attribute not listed for a dataset)
# keeps whatever value the argument parser produced.
_DATASET_OVERRIDES = {
    'waterbirds': {'num_classes': 2, 'epochs': 300},
    'celeba': {'num_classes': 2, 'epochs': 50},
    'multinli': {'num_classes': 3, 'epochs': 5, 'optimizer': 'adamw'},
    'civilcomments': {'num_classes': 2, 'epochs': 5, 'optimizer': 'adamw'},
    # cmnist only fixes the class count; epochs/optimizer come from the parser.
    'cmnist': {'num_classes': 5},
}
for _attr, _value in _DATASET_OVERRIDES.get(args.dataset, {}).items():
    setattr(args, _attr, _value)

# Checkpoint layout:
#   <checkpoint_path>/<dataset>_<arch>/<infer_optimizer>_<lr>_<wd>[_aug][_balance][_classdro]
# where lr / weight decay use one-digit scientific notation (e.g. 0.001 -> '1e-03').
# Per-seed outputs then go to <that directory>/<seed>.
run_name = f'{args.infer_optimizer}_{args.infer_lr:.0e}_{args.infer_weight_decay:.0e}'
if args.infer_augment:
    run_name += '_aug'
if args.balance_classes_infer:
    run_name += '_balance'
if args.class_dro_infer:
    run_name += '_classdro'
args.checkpoint_path = os.path.join(args.checkpoint_path, f'{args.dataset}_{args.arch}', run_name)
args.save_dir = os.path.join(args.checkpoint_path, str(args.seed))

# Create the per-seed output directory. exist_ok avoids the check-then-create
# race of a separate os.path.exists() test.
os.makedirs(args.save_dir, exist_ok=True)

# set_logger is project-defined; its return value becomes args.logger and is
# used for all subsequent logging in this script.
logger = logging.getLogger(__name__)
args.logger = set_logger(args, logger)
# Record the full configuration together with a wall-clock start timestamp.
args.logger.info(str(args) + ' ' + str(datetime.datetime.now()) + '\n')

# Training split, with augmentation toggled by args.infer_augment.
train_dataset = load_dataset(args, split='train', augment=args.infer_augment)
args.train_size = len(train_dataset)

# Save cadence: a fifth of the epoch budget when save_unit is 'epoch',
# otherwise the number of full batches in the training set.
if args.save_unit == 'epoch':
    # NOTE(review): this is 0 whenever args.epochs < 5 -- confirm consumers
    # of save_freq tolerate a zero value.
    args.save_freq = args.epochs // 5
else:
    args.save_freq = args.train_size // args.batch_size

# Count of training samples per value of group_array (index = group value).
group_sizes = np.bincount(train_dataset.group_array)
args.logger.info(f'Group sizes: {group_sizes}')

# Training loss is unreduced (one value per element); validation loss uses
# the default mean reduction.
train_criterion = nn.CrossEntropyLoss(reduction='none')
val_criterion = nn.CrossEntropyLoss()

def _make_group_loaders(args, split, group_sizes):
    """Build one unshuffled DataLoader per group of ``split``.

    Groups are enumerated as ``[i, j]`` with ``i`` in ``range(args.num_classes)``
    (outer index) and ``j`` in ``range(len(group_sizes) // args.num_classes)``
    (inner index); each pair is passed to ``load_dataset`` as ``group``.

    Args:
        args: Run configuration; reads num_classes, batch_size and num_workers.
        split: Split name forwarded to ``load_dataset`` (e.g. 'val', 'test').
        group_sizes: Per-group training counts; only its length is used.

    Returns:
        A list of DataLoaders in the enumeration order described above.
    """
    # NOTE(review): np.bincount stops at the largest group value present, so if
    # the highest-index training group(s) are empty this undercounts the inner
    # range -- confirm that cannot happen for the supported datasets.
    num_inner = len(group_sizes) // args.num_classes
    return [
        torch.utils.data.DataLoader(
            load_dataset(args, split=split, group=[i, j]),
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=True,
        )
        for i in range(args.num_classes)
        for j in range(num_inner)
    ]


# Per-group validation and test loaders, built with identical group ordering.
val_loader = _make_group_loaders(args, 'val', group_sizes)
test_loader = _make_group_loaders(args, 'test', group_sizes)

# Run the project-defined train_refer_model once per entry of
# args.infer_steps, passing the criteria and per-group loaders built earlier.
# An empty args.infer_steps runs nothing.
for infer_step in args.infer_steps:
    train_refer_model(
        args,
        train_criterion,
        val_criterion,
        val_loader,
        test_loader,
        infer_step,
    )
