import os
import torch
import json
import sys

from models import getModel
from optimizer import getOptimizer
from utility.loss import crossentropy, getEqualLoss
from utility.inputData import DataLoader
from utility.dataLogger import DataLogger
from utility.utils import initialize
from utility.LRScheduler import getLRScheduler, _LRScheduler
from utility.modelSaver import ModelSaver
from utility.args import Args

"""
run:
    python -m torch.distributed.run main.py
"""

# ---------------------------------------------------------------------------
# Command-line arguments specific to this training script.
# NOTE(review): with a plain argparse backend, `type=bool` converts any
# non-empty string (including "False") to True. Confirm how `Args` handles
# the two boolean flags below before relying on "--flag False" on the CLI.
# ---------------------------------------------------------------------------

# Run / logging control.
Args.add_argument(
    "--logDir",
    type=str,
    help="main directory to store logs",
)
Args.add_argument(
    "--logSubDir",
    type=str,
    help="subdir in logDir to store logs for this run",
)
Args.add_argument(
    "--epochs",
    type=int,
    help="Total number of epochs",
)
Args.add_argument(
    "--contin",
    type=bool,
    help="Whether to continue from checkpoint. In continue mode parameters are read from params.json file, input file is ignored.",
)

# Random-label ("sub-label") head configuration.
# subLabels: in multi-head mode the main script builds the flat sub-label index
#            as subLabel + subLabels * classLabel.
# regFactor: weight of the equal-loss term added to the main loss.
# subLR:     scale applied to the sub-label loss gradient for subPredictionsNet.
# copyDepth: not read in this file's visible code; presumably consumed by the
#            model/helpers -- TODO confirm.
Args.add_argument(
    "--subLabels",
    type=int,
    help="",
)
Args.add_argument(
    "--regFactor",
    type=float,
    help="",
)
Args.add_argument(
    "--subLR",
    type=float,
    help="",
)
Args.add_argument(
    "--copyDepth",
    type=int,
    help="",
)
Args.add_argument(
    "--rndHeadHL",
    type=int,
    nargs="*",
    help="list of random head hidden layer dimensions. (use empty list for linear probing)",
)

Args.add_argument(
    "--singleHead",
    type=bool,
    help="Do we employ just a single head for RND labels? "
         "If False (default), as many 'random heads' are employed as there are "
         "'random classes' (labels), i.e., Args.subLabels.",
)

def splitTargets(fullTargets, subLabels, singleHead):
    """Split a label tensor into class targets and random sub-label targets.

    Column 0 of ``fullTargets`` holds the class label and column 1 the random
    sub-label (assumes a 2-D integer tensor -- TODO confirm against DataLoader).

    With a single random head the sub-label is used as-is.  Otherwise each
    class owns its own block of ``subLabels`` random labels, so the flat
    sub-label index is ``subLabel + subLabels * classLabel``.

    Returns:
        ``(targets, subTargets)``.
    """
    targets = fullTargets[:, 0]
    subTargets = fullTargets[:, 1]
    if not singleHead:
        subTargets = subTargets + subLabels * targets
    return targets, subTargets


def checkResumeEpoch(startEpoch, maxEpoch):
    """Raise ``RuntimeError`` if a resumed run has no epoch left to train.

    The training loop runs epochs ``startEpoch .. maxEpoch`` inclusive, so a
    resume is valid whenever ``startEpoch <= maxEpoch``.
    """
    if startEpoch > maxEpoch:
        raise RuntimeError(f"Can't continue model from epoch {startEpoch} to max epoch {maxEpoch}.")


if __name__ == "__main__":
    # Distributed training entry point: parse args, set up the process group,
    # build model/optimizer, then alternate train and test epochs.
    Args.parse_args()
    if hasattr(torch, "set_float32_matmul_precision"):
        torch.set_float32_matmul_precision("high")

    logDir = os.path.join(Args.logDir, Args.logSubDir)

    # Fall back to a fresh run when there is no previous run to continue from.
    if Args.contin and not os.path.isdir(logDir):
        print("WARNING: '--contin' flag set, but no dir found. setting contin=false now!")
        Args.contin = False

    if Args.contin:
        # Re-use the parameters stored by the interrupted run.
        with open(os.path.join(logDir, "params.json"), "r") as file:
            parameters = json.load(file)
        Args.parse_args_contin(parameters)

    # Rank is SLURM_PROCID when launched via SLURM; -1 lets env:// take it from the environment.
    # (backend="gloo" can be substituted for "nccl" when debugging.)
    torch.distributed.init_process_group(backend="nccl", init_method="env://", rank=int(os.getenv("SLURM_PROCID", -1)))
    local_rank = torch.distributed.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    initialize()  # set up seed and cudnn

    # NOTE(review): only rank 0 creates logDir/params.json and other ranks are not
    # synchronized with it until the barrier below -- confirm no other rank writes
    # into logDir before then.
    if torch.distributed.get_rank() == 0:
        os.makedirs(logDir, exist_ok=True)
        with open(os.path.join(logDir, "params.json"), "w") as file:
            json.dump(vars(Args.data), file, indent=4)

    dataLogger = DataLogger()
    dataset = DataLoader()

    model = getModel()(num_classes=dataset.numClasses)
    model = model.cuda(local_rank)
    # torch.compile is only used below Python 3.11 (gate kept from the original code).
    if hasattr(torch, "compile") and sys.version_info < (3, 11):
        model = torch.compile(model)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=False)

    optimizer = getOptimizer()(model.parameters())
    lrScheduler: _LRScheduler = getLRScheduler(optimizer)
    modelSaver = ModelSaver(model=model, optimizer=optimizer)

    startEpoch = 1
    if Args.contin:
        # loadModel presumably returns the epoch of the saved checkpoint; resume with the next one.
        startEpoch = modelSaver.loadModel("checkpoint.model") + 1
        model = model.cuda(local_rank)
        checkResumeEpoch(startEpoch, Args.epochs)
    else:
        modelSaver(0)

    equalLossFunc = getEqualLoss()
    subPredictionParams = [param for name, param in model.named_parameters() if "subPredictionsNet" in name]

    torch.distributed.barrier()  # wait until all workers are done with initialization

    dataLogger.printHeader()
    state = {
        "model": model,
        "lrScheduler": lrScheduler,
        "optimizer": optimizer,
    }

    for epoch in range(startEpoch, Args.epochs + 1):
        dataset.train.sampler.set_epoch(epoch)

        model.train()

        numBatches = len(dataset.train)
        dataLogger.startTrain(trainDataLen=numBatches)

        for i, batch in enumerate(dataset.train):
            lrScheduler.step(epoch - 1, (i + 1) / numBatches)
            inputs, fullTargets = (b.cuda(local_rank) for b in batch)
            targets, subTargets = splitTargets(fullTargets, Args.subLabels, Args.singleHead)

            subPredictions, predictions = model(inputs)

            loss = crossentropy(predictions, targets, smoothing=Args.label_smoothing)
            subLoss = crossentropy(subPredictions, subTargets)

            # subPredictionsNet is trained only on the subLR-scaled sub-label loss:
            # compute that gradient explicitly, then freeze those parameters so the
            # backward pass of finalLoss does not add to their .grad.
            gradients = torch.autograd.grad(Args.subLR * subLoss.mean(), subPredictionParams, retain_graph=True)
            for param, grad in zip(subPredictionParams, gradients):
                if param.grad is None:
                    param.grad = grad
                else:
                    param.grad += grad
                param.requires_grad = False

            equalLoss = equalLossFunc(subPredictions, subTargets=subTargets, targets=targets)
            finalLoss = loss.mean() + Args.regFactor * equalLoss.mean()
            finalLoss.backward()

            for param in subPredictionParams:
                param.requires_grad = True

            # Clip only once every gradient of this step exists; clipping before the
            # backward pass operated on zeroed/None gradients and had no effect.
            if Args.grad_clip_norm != 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), Args.grad_clip_norm)

            optimizer.step()
            optimizer.zero_grad()

            # Detach so the logged tensors do not keep (retained) autograd graph parts alive.
            with torch.no_grad():
                state["loss"] = loss.detach()
                state["predictions"] = predictions.detach()
                state["subPredictions"] = subPredictions.detach()
                state["equalLoss"] = equalLoss.detach()
                state["subLoss"] = subLoss.detach()
                state["targets"] = targets
                state["subTargets"] = subTargets
                dataLogger(state)

        dataLogger.flush()

        dataLogger.startTest()
        model.eval()
        with torch.no_grad():
            for batch in dataset.test:
                inputs, fullTargets = (b.cuda(local_rank) for b in batch)
                targets = fullTargets[:, 0]

                _, predictions = model(inputs)
                loss = crossentropy(predictions, targets, smoothing=Args.label_smoothing)
                state["loss"] = loss
                state["predictions"] = predictions
                state["targets"] = targets
                dataLogger(state)

            dataLogger.flush()
            modelSaver(epoch)

    dataLogger.printFooter()
    torch.distributed.destroy_process_group()
