import re
import os
from datetime import datetime

import torch
import torch.nn.functional as F

import hydra

from sam import SAM
import models
from train_utils import *
from data_utils import get_dataloaders, shapes_dict

from utils.utils import *
from aus import aus

#os.environ['WANDB_SILENT'] = 'true'

def _build_optimizer(args, model):
    """Create the SGD optimizer for ``model``, wrapped in SAM when ``args.sam_rho`` is non-zero.

    Shared by the initial setup and the post-auto-upsampling restart so both
    phases construct the optimizer identically.
    """
    if args.sam_rho != 0.0:
        return SAM(model.parameters(),
                   torch.optim.SGD,
                   lr=args.lr_max,
                   momentum=args.momentum,
                   rho=args.sam_rho,
                   sam_no_grad_norm=args.sam_no_grad_norm)
    return torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=args.momentum)


def _load_checkpoint_weights(model, ckpt_path, device):
    """Load the ``'last'`` state dict stored in ``ckpt_path`` into ``model``.

    Entries whose key contains ``'model_preact_hl1'`` are dropped before loading.
    ``map_location`` remaps stored tensors onto ``device`` so a checkpoint saved
    on one device (e.g. a GPU) can be loaded on another.
    """
    model_dict = torch.load(ckpt_path, map_location=device)['last']
    model.load_state_dict({k: v for k, v in model_dict.items() if 'model_preact_hl1' not in k})


@hydra.main(version_base=None, config_path='.', config_name='train_params')
def main(cfg):
    """Validate the config, set up run directories, build model/optimizer, and train.

    When ``aus`` is enabled, the model is first trained for ``aus_epochs`` under a
    separate ``aus_``-prefixed run. ``aus`` then builds a new train loader. If
    ``aus_train_after`` is set, weights are reloaded from a checkpoint and training
    runs again on the new loader under the original run name.
    """
    # Remove the log file Hydra writes for this job; ignore it if it is not present.
    try:
        os.remove(f'{os.path.splitext(os.path.basename(__file__))[0]}.log')
    except FileNotFoundError:
        pass
    os.umask(0)

    # Argument intake and validation
    args = DotDict(cfg)

    assert args.exp_name and args.run_name, \
        'Please specify an experiment name (exp_name) and run name (run_name).'
    if args.us_syn_only:
        assert args.us_type=='syn', \
            'Must use us_syn_only set to true with syn us_type.'
    if args.us_type=='syn':
        assert args.us_syn_dataset, \
            'Must provide a us_syn_dataset if upsampling with synthetic data.'
    assert not (args.aus and args.extra_per_epoch), \
        'Cannot add random examples to dataset if auto-upsampling.'
    assert not (args.us and args.aus), \
        'Cannot manually upsample and auto-upsample in the same run.'
    # Validate before any (expensive) auto-upsampling training is done.
    if args.aus and args.aus_train_after:
        assert args.epochs_after_aus in ('full', 'remaining'), \
            "epochs_after_aus must be 'full' or 'remaining'."

    # Setup save directories (local, logs, wandb)
    args.exp_name = re.sub(r'\s+', '-', args.exp_name.strip()) if not args.debug else 'debug_runs'
    args.input_run_name = re.sub(r'\s+', '-', args.run_name.strip())

    date_str = datetime.now().strftime('%Y.%m.%d.%H.%M.%S')
    args.run_name = args.user + '_' + \
                    date_str + '_' + \
                    args.input_run_name
    args.local_run_path = os.path.join(args.team_path, args.exp_name, args.run_name)
    os.makedirs(args.local_run_path, mode=0o777)

    if args.aus:
        args.aus_run_name = 'aus_' + args.user + '_' + \
                        date_str + '_' + \
                        args.input_run_name

        args.aus_local_run_path = os.path.join(args.team_path, args.exp_name, args.aus_run_name)
        os.makedirs(args.aus_local_run_path, mode=0o777)

    log_std(args.local_run_path, incl_stderr=False)

    if args.debug:
        print('Running script in DEBUG mode. Wandb is DISABLED.')
        print('exp_name has been overwritten to \'debug_runs\'.\n')

    # Set environment meta-config (gpu, seed, etc.).
    # CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been initialized yet.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    print(f'Using GPU: {args.gpu}')
    if torch.cuda.is_available():
        # NOTE: this reports the device's total memory, not its currently free memory.
        print(f'GPU memory available: {(torch.cuda.get_device_properties("cuda").total_memory / 10**9):.2f} GB')
    else:
        print('CUDA is not available; falling back to CPU.')

    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    set_all_seeds(args.seed)

    # Data init
    train_loader, val_loader, test_loader, num_classes = get_dataloaders(args)
    args.num_classes = num_classes

    # Model init
    model = models.get_model(args.model,
                             num_classes,
                             False,
                             shapes_dict[args.dataset],
                             args.model_width,
                             'relu',
                             droprate=args.droprate).to(args.device)

    if args.model_ckpt_path:
        print('Using weights from existing checkpoint.')
        _load_checkpoint_weights(model, args.model_ckpt_path, args.device)
    else:
        model.apply(models.init_weights(args.model))

    # Per-example losses (reduction='none'); reduction is left to the callers.
    criterion = lambda logits, y: F.cross_entropy(logits, y, reduction='none')

    optimizer = _build_optimizer(args, model)

    # Auto-Upsample
    if args.aus:
        print('Training for auto-upsampling...')

        # Temporarily swap in the aus epoch budget and run identity so this
        # first train() call uses the aus_ run; the originals are restored below.
        epochs = args.epochs
        args.epochs = args.aus_epochs

        run_name = args.run_name
        args.run_name = args.aus_run_name

        local_run_path = args.local_run_path
        args.local_run_path = args.aus_local_run_path

        train(args,
            model,
            criterion,
            optimizer,
            train_loader,
            val_loader,
            test_loader)

        train_loader = aus(args, model, criterion, train_loader)

        if not args.aus_train_after:
            return

        # Reload weights either from epoch 0 or from the end of the aus phase.
        # NOTE(review): assumes train() wrote 'epochs={n}.pt' files under args.ckpt_path — confirm.
        start_epoch = 0 if args.rewind_to_start else args.aus_epochs
        args.model_ckpt_path = os.path.join(args.ckpt_path, f'epochs={start_epoch}.pt')
        _load_checkpoint_weights(model, args.model_ckpt_path, args.device)

        # Fresh optimizer so no momentum state carries over from the aus phase.
        optimizer = _build_optimizer(args, model)

        # 'full' keeps the original epoch budget; 'remaining' subtracts the aus epochs.
        if args.epochs_after_aus == 'remaining':
            epochs = epochs - args.aus_epochs
        args.epochs = epochs
        args.run_name = run_name
        args.local_run_path = local_run_path
        args.aus = False

        print(f'Training from epoch {start_epoch} with upsampled data for {epochs} epochs.')

    # Train loop
    train(args,
        model,
        criterion,
        optimizer,
        train_loader,
        val_loader,
        test_loader)
    
if __name__ == '__main__':
    # The @hydra.main decorator builds and injects `cfg`, so main() is
    # invoked here without arguments.
    main()  # pylint: disable=no-value-for-parameter