import os
import sys
import time
import numpy as np
import torch

import wandb

from utilities.utilities import Utilities as Utils
from runners.baseRunner import baseRunner


class scratchRunner(baseRunner):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.artifact = None

        entity, project = wandb.run.entity, wandb.run.project
        self.initial_artifact_name = f"initial-{entity}-{project}-{self.config.arch}-{self.config.dataset}-{self.config.run_id}"

    def find_existing_model(self):
        """Finds an existing wandb artifact and downloads the initial model file."""
        # Create a new artifact, this is idempotent, i.e. no artifact is created if this already exists
        try:
            self.artifact = wandb.run.use_artifact(f"{self.initial_artifact_name}:latest")
            seed = self.artifact.metadata["seed"]
            self.artifact.download(root=self.tmp_dir)
            self.checkpoint_file = os.path.join(self.tmp_dir, 'initial_model.pt')
            self.seed = seed

        except Exception as e:
            print(e)

        outputStr = f"Found {self.initial_artifact_name} with seed {seed}" if self.artifact is not None else "Nothing found."
        sys.stdout.write(f"Trying to find reference initial model in project: {outputStr}\n")

    def run(self):
        """Function controlling the workflow of scratchRunner"""
        # If not existing, start a new model, otherwise use existing one with same run-id
        self.find_existing_model()

        if self.seed is None:
            # Generate a random seed
            self.seed = int((os.getpid() + 1) * time.time()) % 2 ** 32

        wandb.config.update({'seed': self.seed})  # Push the seed to wandb

        # Set a unique random seed
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        # Remark: If you are working with a multi-GPU model, this function is insufficient to get determinism. To seed all GPUs, use manual_seed_all().
        torch.cuda.manual_seed(self.seed)  # This works if CUDA not available

        torch.backends.cudnn.benchmark = True

        self.trainLoader, self.valLoader, self.testLoader = self.get_dataloaders()
        self.model = self.get_model(reinit=True, temporary=True)
        # Save initial model before training
        if self.artifact is None:
            self.artifact = wandb.Artifact(self.initial_artifact_name, type='model', metadata={'seed': self.seed})
            sys.stdout.write(f"Creating {self.initial_artifact_name}.\n")
            self.checkpoint_file = self.save_model(model_type='initial', temporary=True)
            self.artifact.add_file(f"{self.tmp_dir}/initial_model.pt")
            wandb.run.use_artifact(self.artifact)

        self.define_constraints()
        self.define_optimizer_scheduler()  # This was moved before define_strategy to have the optimizer available
        self.define_wd_schedule(n_schedule_epochs=self.config.n_epochs, retrain=False)
        self.strategy = self.define_strategy()
        self.strategy.after_initialization()

        self.strategy.at_train_begin()

        # Do proper training
        self.train()


        self.strategy.start_forward_mode()

        self.strategy.end_forward_mode()

        self.strategy.at_train_end()

        # Save trained (unpruned) model, to be used by pretrainedRunner
        self.checkpoint_file = self.save_model(model_type='trained')
        wandb.summary['trained_model_file'] = 'trained_model.pt'

        self.strategy.final()

        # Upload iteration-lr dict from self.strategy to be used during retraining
        Utils.dump_dict_to_json_wandb(dumpDict=self.strategy.lr_dict, name='iteration-lr-dict')
