import re
import os
from datetime import datetime

import torch
import torch.nn.functional as F

import hydra

from sam import SAM
import models
from train_utils import *
from data_utils import get_dataloaders, shapes_dict

from utils.utils import *
from aus import aus
from eiil import eiil
from group_dro import GroupWeightedLoss

# os.environ['WANDB_SILENT'] = 'true'

def _build_model(args, num_classes):
    """Instantiate the network selected by ``args.model``.

    The model is returned without being moved to a device; the caller decides
    placement.
    """
    # Extra architecture options forwarded verbatim to models.get_model.
    arch_kwargs = {
        'patch_size': args.patch_size,
        'hidden_size': args.hidden_dim,
        'num_hidden_layers': args.num_hidden_layers,
        'num_attention_heads': args.num_attention_heads,
        'mlp_ratio': args.mlp_ratio,
        'pretrained': args.pretrained,
    }
    return models.get_model(args.model,
                            num_classes,
                            False,
                            shapes_dict[args.dataset],
                            args.model_width,
                            'relu',
                            droprate=args.droprate,
                            hidden_dim=args.hidden_dim,
                            **arch_kwargs)


def _build_optimizer(args, model):
    """Return SGD over the model parameters, wrapped in SAM when ``args.sam_rho`` is non-zero."""
    if args.sam_rho != 0.0:
        return SAM(model.parameters(),
                   torch.optim.SGD,
                   lr=args.lr_max,
                   momentum=args.momentum,
                   rho=args.sam_rho,
                   sam_no_grad_norm=args.sam_no_grad_norm,
                   adaptive=args.sam_adaptive)
    return torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=args.momentum)


def _load_model_weights(model, ckpt_path):
    """Load the ``'last'`` state dict stored at ``ckpt_path`` into ``model``.

    Entries whose key contains ``'model_preact_hl1'`` are dropped before loading.
    """
    model_dict = torch.load(ckpt_path)['last']
    model.load_state_dict({k: v for k, v in model_dict.items() if 'model_preact_hl1' not in k})


@hydra.main(version_base=None, config_path='.', config_name='train_params')
def main(cfg):
    """Run one training job described by the Hydra config ``train_params``.

    Steps, in order:
      1. Validate the configuration and derive run names / output directories.
      2. Select the device, seed everything, build dataloaders, model and
         optimizer.
      3. If ``aus`` or ``eiil`` is enabled, run a warmup training phase, then
         rebuild the train loader via ``aus(...)`` / ``eiil(...)``. Unless
         ``aus_train_after`` is falsy (in which case the function returns),
         reload a checkpoint, rebuild the optimizer and restore the main-run
         settings.
      4. Run the main ``train(...)`` call.

    Args:
        cfg: Configuration object supplied by Hydra; it is wrapped in a
            ``DotDict`` and mutated throughout this function.

    Raises:
        AssertionError: If the configuration violates one of the checks below.
    """
    # Remove hydra logger. The log may not live in the working directory
    # (depends on the Hydra run-dir settings), so a missing file is not an error.
    try:
        os.remove(f'{os.path.splitext(os.path.basename(__file__))[0]}.log')
    except FileNotFoundError:
        pass
    # Clear the umask so the 0o777 modes passed to os.makedirs below are not masked.
    os.umask(0)

    # Argument intake and validation
    args = DotDict(cfg)

    assert args.exp_name and args.run_name, \
        'Please specify an experiment name (exp_name) and run name (run_name).'
    if args.us_syn_only:
        assert args.us_type=='syn', \
            'Must use us_syn_only set to true with syn us_type.'
    if args.us_type=='syn':
        assert args.us_syn_dataset, \
            'Must provide a us_syn_dataset if upsampling with synthetic data.'
    assert not (args.aus and args.extra_per_epoch), \
        'Cannot add random examples to dataset if auto-upsampling.'
    assert sum([args.us, args.aus, args.crs]) in {0,1}, \
        'Can only use one upsample method per run.'
    if args.crs:
        assert args.crs_epoch_bin%2, \
            'crs_epoch_bin must be an odd number.'
    assert all(-1 < i < args.aus_clusters for i in args.downsample_cluster_indices), \
            'Downsample cluster indices out of range for number of clusters chosen.'
    assert all(-1 < i < args.aus_clusters for i in args.upsample_cluster_indices), \
            'Upsample cluster indices out of range for number of clusters chosen.'
    assert not set(args.downsample_cluster_indices) & set(args.upsample_cluster_indices), \
            'Cannot downsample and upsample the same cluster.'

    # Setup save directories (local, logs, wandb)
    args.exp_name = re.sub(r'\s+', '-', args.exp_name.strip()) if not args.debug else 'debug_runs'
    args.input_run_name = re.sub(r'\s+', '-', args.run_name.strip())

    date_str = datetime.now().strftime('%Y.%m.%d.%H.%M.%S')
    # Run name = user, timestamp, sanitized run name, first 7 chars of the hostname.
    args.run_name = args.user + '_' + \
                    date_str + '_' + \
                    args.input_run_name + '_' + \
                    os.uname()[1][:7]
    args.local_run_path = os.path.join(args.team_path, args.exp_name, args.run_name)
    os.makedirs(args.local_run_path, mode=0o777, exist_ok=True)

    # The warmup phase of aus / eiil gets its own run name and directory.
    if args.aus:
        args.aus_run_name = 'aus_' + args.user + '_' + \
                        date_str + '_' + \
                        args.input_run_name

        args.aus_local_run_path = os.path.join(args.team_path, args.exp_name, args.aus_run_name)
        os.makedirs(args.aus_local_run_path, mode=0o777)
    elif args.eiil:
        args.eiil_run_name = 'eiil_' + args.user + '_' + \
                        date_str + '_' + \
                        args.input_run_name

        args.eiil_local_run_path = os.path.join(args.team_path, args.exp_name, args.eiil_run_name)
        os.makedirs(args.eiil_local_run_path, mode=0o777)

    # NOTE(review): log_std is defined elsewhere; presumably it mirrors stdout
    # into the run directory - confirm.
    log_std(args.local_run_path, incl_stderr=False)

    if args.debug:
        print('Running script in DEBUG mode. Wandb is DISABLED.')
        print('exp_name has been overwritten to \'debug_runs\'.\n')

    # Set environment meta-config (gpu, seed, etc.)
    if not args.accelerate:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
        print(f'Using GPU: {args.gpu}')
        # Only query device properties when CUDA is actually usable; the device
        # selection below explicitly supports falling back to CPU.
        if torch.cuda.is_available():
            print(f'GPU memory available: {(torch.cuda.get_device_properties("cuda").total_memory / 10**9):.2f} GB')
        else:
            print('CUDA is not available; falling back to CPU.')

    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    set_all_seeds(args.seed)

    # Data init
    train_loader, val_loader, test_loader, num_classes = get_dataloaders(args)
    args.num_classes = num_classes

    # Model init. With args.accelerate set, the model is not moved to
    # args.device here.
    model = _build_model(args, num_classes)
    if not args.accelerate:
        model = model.to(args.device)

    if args.model_ckpt_path:
        print('Using weights from existing checkpoint.')
        _load_model_weights(model, args.model_ckpt_path)
    else:
        if ('vit' not in args.model) and (not args.pretrained):
            model.apply(models.init_weights(args.model))

    # Per-sample (unreduced) cross-entropy.
    criterion = lambda logits, y: F.cross_entropy(logits, y, reduction='none')
    val_criterion = None

    optimizer = _build_optimizer(args, model)

    # Auto-Upsample
    if args.aus or args.eiil:
        # Remember the main-run settings; they are overwritten for the warmup
        # phase and restored afterwards.
        epochs = args.epochs
        run_name = args.run_name
        local_run_path = args.local_run_path

        if args.aus:
            args.run_name = args.aus_run_name
            args.epochs = args.aus_epochs
            args.local_run_path = args.aus_local_run_path
        elif args.eiil:
            lr_max = args.lr_max
            args.lr_max = args.eiil_infer_lr
            # The eiil warmup always uses plain SGD at the inference learning rate.
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=args.momentum)
            args.run_name = args.eiil_run_name
            args.epochs = args.eiil_infer_epoch
            args.local_run_path = args.eiil_local_run_path
            # GroupDRO loss must be off for warmup training
            assert args.group_dro == False

        print(f'Warmup training for {args.epochs} epochs')

        train(args,
              model,
              criterion,
              optimizer,
              train_loader,
              val_loader,
              test_loader)

        if args.aus:
            train_loader = aus(args, model, criterion, train_loader)
        elif args.eiil:
            train_loader = eiil(args, model, train_loader)
            # Validation keeps the plain criterion; training switches to the
            # group-weighted one.
            val_criterion = criterion
            criterion = GroupWeightedLoss(
                criterion,
                num_groups=train_loader.dataset.num_groups,
                group_weight_lr=args.group_weight_lr,
                device=args.device)
            # Turn on GroupDRO loss
            print(f"Using GroupDRO loss for {train_loader.dataset.num_groups} groups with group_weight_lr {args.group_weight_lr}")
            args.group_dro = True
            args.l2_reg = args.group_dro_weight_decay
            args.lr_max = lr_max

        if not args.aus_train_after:
            return

        # Resume either from the initial checkpoint or from the end of warmup.
        if args.rewind_to_start:
            start_epoch = 0
        else:
            start_epoch = args.epochs
        # NOTE(review): args.ckpt_path is not set in this function; presumably
        # train() sets it and writes 'epochs=<n>.pt' files there - confirm.
        args.model_ckpt_path = os.path.join(args.ckpt_path, f'epochs={start_epoch}.pt')
        _load_model_weights(model, args.model_ckpt_path)

        optimizer = _build_optimizer(args, model)

        # Resolve the post-warmup epoch mode once. The option was previously
        # read under two different spellings, so the 'remaining' mode could be
        # silently ignored; 'epoch_after_aus' is still honoured as a fallback.
        epochs_mode = args.epochs_after_aus
        if epochs_mode is None:
            epochs_mode = args.epoch_after_aus
        if epochs_mode == 'remaining':
            # NOTE(review): this subtracts aus_epochs even when the warmup was
            # the eiil one (eiil_infer_epoch epochs) - confirm this is intended.
            epochs = epochs - args.aus_epochs
        # 'full' (and any other value) keeps the original epoch budget.
        args.epochs = epochs
        args.run_name = run_name
        args.local_run_path = local_run_path
        args.aus = False

        print(f'Training from epoch {start_epoch} with upsampled data for {epochs} epochs.')

    # Train loop
    train(args,
          model,
          criterion,
          optimizer,
          train_loader,
          val_loader,
          test_loader,
          val_criterion=val_criterion)
    
# Script entry point: main() is called without arguments because the
# @hydra.main decorator supplies the `cfg` argument.
if __name__ == "__main__":
    main()