import os
import torch
import torch.optim as optim
import utils
from utils import mkdir_p, parse_args
from utils import get_lr, save_checkpoint, create_save_path
from solvers.runners import train, test
from solvers.loss import loss_dict
from models import model_dict
from datasets import dataloader_dict, dataset_nclasses_dict, dataset_classname_dict
from time import localtime, strftime
import logging
import numpy as np
import random
import torch.backends.cudnn as cudnn

def build_optimizer(name, params, lr, momentum, weight_decay):
    """Return the optimizer selected by *name* for the given parameters.

    Args:
        name: Optimizer identifier; ``"sgd"`` or ``"adam"``.
        params: Iterable of parameters handed to the optimizer.
        lr: Learning rate.
        momentum: Momentum factor (used by SGD only).
        weight_decay: Weight-decay coefficient.

    Raises:
        ValueError: If *name* is not a supported optimizer. Failing here gives
            a clear message instead of a later ``NameError`` on an unbound
            ``optimizer`` variable.
    """
    if name == "sgd":
        return optim.SGD(params,
                         lr=lr,
                         momentum=momentum,
                         weight_decay=weight_decay)
    if name == "adam":
        return optim.Adam(params,
                          lr=lr,
                          weight_decay=weight_decay)
    raise ValueError(
        f"Unsupported optimizer {name!r}; expected 'sgd' or 'adam'."
    )


def _setup_logging(args):
    """Configure root logging for this process.

    On rank 0 the checkpoint directory is created and INFO-level logs go to
    both ``train.log`` inside it and the console. Every other rank only logs
    at ERROR level.

    Returns:
        The checkpoint directory path on rank 0, ``None`` on other ranks.
    """
    if args.rank != 0:
        logging.basicConfig(level=logging.ERROR)
        return None

    current_time = strftime("%d-%b", localtime())
    model_save_pth = f"{args.checkpoint}/{args.dataset}/{current_time}{create_save_path(args)}"
    if not os.path.isdir(model_save_pth):
        mkdir_p(model_save_pth)

    logging.basicConfig(
        level=logging.INFO,
        format="%(levelname)s:  %(message)s",
        handlers=[
            logging.FileHandler(os.path.join(model_save_pth, "train.log")),
            logging.StreamHandler()
        ],
        # Replace any handlers that were installed before this point.
        force=True
    )
    logging.info(f"Rank {args.rank}: Logging initialised on main process.")
    logging.info(f"Setting up logging folder : {model_save_pth}")
    return model_save_pth


def main():
    """Train ``args.model`` on ``args.dataset``, evaluating every epoch.

    Builds the model, data loaders, optimizer, loss and LR scheduler from the
    parsed command-line arguments, then runs the train/validate loop. Rank 0
    logs per-epoch statistics, saves a checkpoint after each epoch and reports
    the statistics of the epoch with the best top-1 validation accuracy.
    """
    args = parse_args()
    utils.init_distributed_mode(args)

    # Rank and local GPU index are read from the environment; both default
    # to 0 when the variables are unset.
    args.rank = int(os.environ.get("RANK", 0))
    args.gpu = int(os.environ.get("LOCAL_RANK", 0))
    is_main = args.rank == 0

    # Set up logging early; only rank 0 owns the checkpoint directory.
    model_save_pth = _setup_logging(args)

    # Seed every RNG the script uses.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    # NOTE(review): cuDNN benchmark mode auto-tunes kernels and may pick
    # different algorithms between runs, so results are not guaranteed to be
    # bit-reproducible despite the seeding above. Kept for speed; confirm.
    cudnn.benchmark = True

    device = torch.device(args.device)

    num_classes = dataset_nclasses_dict[args.dataset]
    if is_main:
        logging.info(f"Using model : {args.model}")
    model = model_dict[args.model](num_classes=num_classes)
    model.to(device)

    # Keep a handle on the bare module so the optimizer and checkpoints are
    # independent of the DistributedDataParallel wrapper.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if is_main:
        logging.info(f"Using dataset : {args.dataset}")
    trainloader, valloader = dataloader_dict[args.dataset](args)
    # May be None or a sampler without ``set_epoch``; checked before use.
    train_sampler = getattr(trainloader, "sampler", None)

    if is_main:
        logging.info(f"Setting up optimizer : {args.optimizer}")
    optimizer = build_optimizer(args.optimizer,
                                model_without_ddp.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # The 'OLS' loss takes a different constructor signature and needs
    # per-epoch bookkeeping calls (see the training loop).
    uses_ols = args.loss == 'OLS'
    if uses_ols:
        criterion = loss_dict[args.loss](num_classes, device)
    else:
        criterion = loss_dict[args.loss](gamma=args.gamma, alpha=args.alpha, beta=args.beta, loss=args.loss, delta=args.delta)

    # Validation is always scored with plain cross-entropy.
    test_criterion = loss_dict["cross_entropy"]()

    if is_main:
        logging.info(f"Step sizes : {args.schedule_steps} | lr-decay-factor : {args.lr_decay_factor}")
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.schedule_steps, gamma=args.lr_decay_factor)

    start_epoch = args.start_epoch
    best_acc = 0.
    best_acc_stats = {"top1": 0.0}

    for epoch in range(start_epoch, args.epochs):
        if is_main:
            logging.info('Epoch: [%d | %d] LR: %f', epoch + 1, args.epochs, get_lr(optimizer))

        # A DistributedSampler reuses the same shuffle order every epoch
        # unless it is told the current epoch number.
        if args.distributed and hasattr(train_sampler, "set_epoch"):
            train_sampler.set_epoch(epoch)

        if uses_ols:
            criterion.reset_epoch_state()

        train_loss, top1_train = train(trainloader, model, optimizer, criterion)

        if uses_ols:
            criterion.normalise_loss_lams()

        val_loss, top1, top3, top5, sce_score, ece_score, aece_score = test(valloader, model, test_criterion)

        scheduler.step()

        # Reporting and checkpointing happen on rank 0 only.
        if not is_main:
            continue

        logging.info("End of epoch {} stats: train_loss : {:.4f} | val_loss : {:.4f} | top1_train : {:.4f} | top1 : {:.4f} | SCE : {:.5f} | ECE : {:.5f} | AECE : {:.5f}".format(
            epoch + 1,
            train_loss,
            val_loss,
            top1_train,
            top1,
            sce_score,
            ece_score,
            aece_score
        ))

        is_best = top1 > best_acc
        best_acc = max(best_acc, top1)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model_without_ddp.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'dataset': args.dataset,
            'model': args.model
        }, is_best, checkpoint=model_save_pth)

        if is_best:
            best_acc_stats = {
                "top1": top1,
                "top3": top3,
                "top5": top5,
                "SCE": sce_score,
                "ECE": ece_score,
                "AECE": aece_score
            }

    if is_main:
        logging.info("Training completed...")
        logging.info("The stats for best trained model on test set are as below:")
        logging.info(best_acc_stats)


if __name__ == "__main__":
    main()


    