import os
import time
import torch
import logging
import shutil
import hydra
import pretty_errors
import argparse
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

from omegaconf import OmegaConf
from easydict import EasyDict
from tqdm import tqdm
# from tqdm.auto import tqdm
from einops import rearrange
from train import train_epoch
from eval import eval_clean
from utils import setup_logging, set_seed, accuracy, AverageMeter, get_dataset, get_model


# Configure how pretty_errors renders tracebacks for this script.
pretty_errors.configure(
    # Separator line is drawn with '*' characters.
    separator_character='*',
    filename_display=pretty_errors.FILENAME_EXTENDED,
    line_number_first=True,
    display_link=True,
    # Source context: 5 lines before and 2 lines after the reported line.
    lines_before=5,
    lines_after=2,
    # The reported line gets a red '> ' prefix; context lines get a plain
    # two-space prefix, both followed by the library's default line colour.
    line_color=pretty_errors.RED + '> ' + pretty_errors.default_config.line_color,
    code_color='  ' + pretty_errors.default_config.line_color,
    truncate_code=True,
    display_locals=True,
)


def save_src_for_reproduce(configs, out_dir):
    """Dump the run configuration to ``<out_dir>/src/config.yaml``.

    The ``src`` sub-directory is created if it does not exist yet; an
    existing ``config.yaml`` in it is overwritten.

    Args:
        configs: Mapping holding the run configuration. It is converted with
            ``dict(configs)`` before being handed to ``OmegaConf.save``.
        out_dir: Directory of the current run; the config is written below it.
    """
    src_dir = os.path.join(out_dir, 'src')
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(src_dir, exist_ok=True)
    # Dump config to a YAML file.
    OmegaConf.save(dict(configs), os.path.join(src_dir, 'config.yaml'))
    

@hydra.main(version_base=None, config_path='config', config_name='NT')
def main(configs):
    """Train a classifier, evaluate it every epoch and log to TensorBoard.

    The configuration is loaded by Hydra from ``config/NT``. The function
    reads ``configs.TRAIN`` (``out_dir``, ``seed``, ``normalize``, ``tb_dir``,
    ``ckpt_dir``, ``lr``, ``momentum``, ``wd``, ``lr_drop``, ``gamma``,
    ``epoches``, ``save_interval``) and ``configs.dataset_cfg``
    (``classifier``, ``num_classes``, plus whatever ``get_dataset`` consumes).

    Side effects:
        * writes the config to ``<out_dir>/src/config.yaml``;
        * writes TensorBoard scalars under ``tb_dir``;
        * saves ``epoch_<epoch>.pth`` every ``save_interval`` epochs and
          ``best.pth`` whenever the evaluation accuracy matches or beats the
          best seen so far, both under ``ckpt_dir``.

    Args:
        configs: Configuration object supplied by Hydra.
    """
    set_seed(42)

    configs = EasyDict(configs)
    save_src_for_reproduce(configs, configs.TRAIN.out_dir)

    # Re-seed with the configured seed; this supersedes the fixed seed above.
    set_seed(configs.TRAIN.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = True

    # Model and dataloaders.
    train_loader, test_loader, norm_layer = get_dataset(configs.dataset_cfg, configs.TRAIN.normalize)
    classifier = get_model(configs.dataset_cfg.classifier, configs.dataset_cfg.num_classes)
    classifier = classifier.to(device)

    # Output directories (exist_ok avoids the exists()/makedirs() race).
    tb_dir = configs.TRAIN.tb_dir
    os.makedirs(tb_dir, exist_ok=True)
    ckpt_dir = configs.TRAIN.ckpt_dir
    os.makedirs(ckpt_dir, exist_ok=True)

    # Optimizer. NOTE: weight decay is applied to every parameter of the
    # classifier; there is no separate no-decay parameter group.
    optimizer = optim.SGD(
        classifier.parameters(),
        lr=configs.TRAIN.lr,
        momentum=configs.TRAIN.momentum,
        weight_decay=configs.TRAIN.wd,
    )

    # Scheduler: multiply the learning rate by gamma at each milestone epoch.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=configs.TRAIN.lr_drop, gamma=configs.TRAIN.gamma
    )
    # Loss function.
    criterion = torch.nn.CrossEntropyLoss()

    print(f"Start experiment: {configs.TRAIN.out_dir}")
    n_params = sum(p.numel() for p in classifier.parameters())
    print(f"No. of parameters: {n_params}")

    # TensorBoard writer; closed in the finally block so it is flushed and
    # released even if training raises.
    writer = SummaryWriter(tb_dir)
    try:
        process_bar = tqdm(range(configs.TRAIN.epoches))
        best_acc = 0
        for epoch in process_bar:
            train_acc1, train_loss = train_epoch(classifier, train_loader, criterion, optimizer, norm_layer, device)
            # Read the learning rate before stepping the scheduler so the
            # logged value is the one this epoch actually trained with.
            lr = optimizer.param_groups[0]['lr']
            scheduler.step()

            writer.add_scalar('Train/accuracy', train_acc1, epoch)
            writer.add_scalar('Train/loss', train_loss, epoch)
            writer.add_scalar('Train/lr', lr, epoch)

            # Compute the accuracy on the evaluation loader and record it.
            eval_clean_acc1 = eval_clean(classifier, test_loader, norm_layer, device)
            writer.add_scalar('Eval/accuracy', eval_clean_acc1, epoch)

            # Update the progress bar and echo the same summary to stdout.
            summary = f"Epoch: {epoch :d}, acc1: {train_acc1 :.2f}, test acc1: {eval_clean_acc1 :.2f}, loss: {train_loss:.4f}"
            print(summary)
            process_bar.set_description(summary)

            # Periodic checkpoint; the file name carries the 0-based epoch index.
            if (epoch + 1) % configs.TRAIN.save_interval == 0:
                torch.save(classifier.state_dict(), os.path.join(ckpt_dir, "epoch_" + str(epoch) + ".pth"))

            # Ties overwrite best.pth, so the latest best-scoring epoch is kept.
            if best_acc <= eval_clean_acc1:
                best_acc = eval_clean_acc1
                torch.save(classifier.state_dict(), os.path.join(ckpt_dir, "best.pth"))

        print("Training finished!")
        print(f"Best acc1: {best_acc:.4f}")
    finally:
        writer.flush()
        writer.close()
    
# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()


