#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import torch.distributed as dist

from tqdm import tqdm
import sys

import os
# NOTE(review): presumably set so cuBLAS behaves deterministically when
# deterministic algorithms are requested (training() calls
# utils.set_seed(..., deterministic=True)) -- confirm against utils.set_seed.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# From custom scripts
from src import train

from src import Config
from src import get_metrics

from src import utils
from src import save_load_utils as sl
from src import plotting_functions as pf


def training(rank, world_size, conf, train_tracked_metrics=None, val_tracked_metrics=None):
    """Train, test and save a model for one process of a (possibly distributed) run.

    The function sets up the device, logger, data loaders, model and metrics,
    runs ``conf.num_epochs`` epochs of training and validation, evaluates on
    the test loader, and finally writes a checkpoint and the logged results.

    Args:
        rank: Index of this process; rank 0 prints the configuration.
        world_size: Total number of processes. Values greater than 1 enable
            DistributedDataParallel and per-epoch sampler reshuffling.
        conf: Run configuration. ``seed``, ``num_epochs`` and ``lr_scheduler``
            are read here; the object is also forwarded to the ``train``,
            ``utils`` and ``sl`` helpers.
        train_tracked_metrics: Passed to ``get_metrics`` to build the metric
            collection used in the training step (may be None).
        val_tracked_metrics: Passed to ``get_metrics`` to build the metric
            collection used in the validation step (may be None).
    """
    device = utils.ddp_setup(rank, world_size)

    distributed = world_size > 1
    if rank == 0:
        print(conf)

    # Create logger on the main process
    if device.type == 'cpu':  # Log directory for laptop debugging
        LOG_DIR = './runs/'
    else:  # Log location for HPC
        LOG_DIR = './'

    # Logger for the scalar values recorded during training (see utils.get_writer)
    writer = utils.get_writer(LOG_DIR, rank, distributed)

    # Reproducibility
    if conf.seed is not None:
        utils.set_seed(conf.seed, deterministic=True)

    # Get datasets
    train_dataloader, val_dataloader, test_dataloader = train.get_loaders(conf, distributed=distributed, download=True)

    # Instantiate the model, optimizer, learning rate scheduler, and loss function
    model, opt, lr_scheduler, criterion = train.get_train_objs(conf, device=device)
    if distributed:
        # DistributedDataParallel requires device_ids=None for CPU modules;
        # passing [cpu_device] raises. Every other device type keeps the
        # original single-device list.
        ddp_device_ids = None if device.type == 'cpu' else [device]
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=ddp_device_ids)

    # Count the classes by collecting every label from one full pass over the
    # training loader.
    # NOTE(review): with a distributed sampler each rank presumably only sees
    # its own shard here -- confirm every class is present on every rank.
    labels = torch.cat([y for _, y in train_dataloader])
    num_classes = len(torch.unique(labels))

    train_metric_collection = get_metrics(train_tracked_metrics, device, model=model, num_classes=num_classes)
    val_metric_collection   = get_metrics(  val_tracked_metrics, device, model=model, num_classes=num_classes)

    # In distributed learning only one process should print statements
    is_main_process = not dist.is_initialized() or dist.get_rank() == 0
    # Training loop
    pbar = tqdm(range(conf.num_epochs), leave=True, file=sys.stdout, disable=not is_main_process)
    for epoch in pbar:
        if distributed:
            # Tell the sampler which epoch this is (needed for per-epoch reshuffling)
            train_dataloader.sampler.set_epoch(epoch)
        train.train_step(epoch, model, criterion, opt, train_dataloader, device, writer, metric_collection=train_metric_collection, distributed=distributed)
        val_loss = train.validation_step(epoch, model, criterion, val_dataloader,  device, writer, metric_collection=val_metric_collection, pbar=pbar)
        # ReduceLROnPlateau needs the monitored quantity; other schedulers step unconditionally
        if conf.lr_scheduler == 'ReduceLROnPlateau':
            lr_scheduler.step(val_loss)
        else:
            lr_scheduler.step()
    pbar.close()

    # NOTE(review): everything below runs on every rank after the distributed
    # state has been cleaned up -- confirm that testing and saving from all
    # ranks is intended.
    utils.cleanup_distributed()

    # Test the model. Only display progress bar if working on cpu (i.e. debugging)
    test_acc = train.test(model, test_dataloader, device, verbose=device.type=='cpu', num_classes=num_classes)
    writer.add_scalar('Test/accuracy',test_acc)
    writer.close()

    #%% Display a breakdown of the model's sparsity per layer
    pf.display_sparsity_table(model)

    #%%% Save results
    SAVE_NAME = utils.create_name(conf)
    # Save:
    # - Checkpoint of the model, optimizer, and lr_scheduler state_dicts, and epoch
    sl.create_checkpoint(conf, model, opt, lr_scheduler, LOG_DIR=LOG_DIR, SAVE_NAME=SAVE_NAME)

    # Save:
    # - CSV of the scalar quantities from said event file for convenient access
    # - Figure visualising the scalar quantities from said event file
    sl.save_results(SAVE_NAME=SAVE_NAME, LOG_DIR=LOG_DIR)

def main(rank=0, world_size=1):
    """Assemble the experiment configuration and launch ``training``.

    Args:
        rank: Index of this process, forwarded to ``training``.
        world_size: Total number of processes, forwarded to ``training``.
    """
    settings = {
        # ----- General training parameters
        # For Tiny ImageNet change path in create_datasets_and_loaders and
        # preprocess the val set before training
        'dataset': 'CIFAR10',
        'batch_size': 128,
        # One of 'WideResNet-28-10', 'ResNet18' or 'VGG16'
        'model': 'ResNet18',
        # Maximum number of training samples. Use None for full dataset.
        'max_samples': None,
        'seed': 0,
        # ----- Initial sparsification of model
        'model_init_sparsity': 0.99,
        'r': [1, 5, 5],
        # Whether to consider group sparsity for kernels
        'conv_group': False,
        # ----- Global optimizer parameters
        'num_epochs': 200,
        # 'LinBregSparse' corresponds to ML LinBreg with using the coarse
        # model every time. Other options: 'LinBreg', 'AdaBreg',
        # 'AdaBregSparse', 'LinBregSparseML', 'Adam', 'SGD'
        'optim': 'LinBregSparse',
        'learning_rate': 0.1,
        # Alternative: 'ReduceLROnPlateau'
        'lr_scheduler': 'CosineAnnealing',
        'loss': 'CrossEntropy',
        # ----- Proximal step parameters
        'reg': 'l1',
        'lambda0': 0.005,  # Reg param for convolutional weights
        'lambda1': 0.005,  # Reg param for linear layer weights
        'delta': 1.0,
        'momentum': 0.0,
        # ----- Parameters for sparse learning
        'full_update_frequency': 100,
        'full_update_duration': 1,
        'full_update_mode': 'step',
    }
    conf = Config(**settings)

    # Metrics tracked per split; either list may be replaced by None
    # (the default value in ``training``).
    tracked_for_training = ['accuracy', 'sparsity']
    tracked_for_validation = [
        'accuracy',
        'sparsity',
        'linear_sparsity',
        'conv_sparsity',
        'layer_sparsity',
    ]

    training(rank, world_size, conf, tracked_for_training, tracked_for_validation)


# Script entry point: calls main() with its defaults (rank=0, world_size=1).
if __name__ == "__main__":
    main()
    
