import json
import os
import time
from datetime import datetime

import torch
import wandb

from neural_mpm.nn import UNet, FNO, create_model


class ModelLogger:
    """Manage a run's output folder and its model checkpoints.

    On construction, creates ``<parent_dir>/<project>/<run_name>_<run_id>/``
    plus a ``models`` sub-folder, writes ``config.json`` (only if it does not
    exist yet) and, optionally, ``wandb.json`` with a link to the active
    wandb run. Checkpoints are saved as ``<name>.ckpt`` (a ``state_dict``)
    with an optional ``<name>.json`` metadata side-car, and are reloaded
    through :meth:`load`.
    """

    def __init__(
            self,
            dataset_name: str,
            run_config: dict,
            save_interval: int = 10,
            save_only_last: bool = False,
            create_wandb_json: bool = True,
            parent_dir: str = 'outputs'
    ):
        """
        Args:
            dataset_name: Name of the dataset (stored on the instance).
            run_config: Run configuration. Must contain 'model' and
                'architecture'. If it contains 'run_id' it must also contain
                'run_name' (re-opening an existing run); otherwise the name
                and id come from the active wandb run, or from a timestamp
                with id 0 when wandb is not running. Optional 'project'
                selects the sub-folder (default 'experiments').
            save_interval: Minimum number of minutes between two periodic
                checkpoints written by try_saving.
            save_only_last: If True, each save first deletes the files
                already present in the models folder.
            create_wandb_json: Write wandb.json with the run link (only when
                a wandb run is active and the run id is not 0).
            parent_dir: Root output directory.
        """
        if 'run_id' not in run_config:
            run = wandb.run
            if run is not None:
                self.run_name = run.name
                self.run_id = run.id
            else:
                # wandb.init() was not called: fall back to a timestamp-based
                # run name and the sentinel id 0.
                self.run_name = datetime.now().strftime("(%d_%m) %H:%M:%S")
                self.run_id = 0
        else:
            self.run_name = run_config['run_name']
            self.run_id = run_config['run_id']

        self.dataset_name = dataset_name

        self.save_interval = save_interval
        self.save_only_last = save_only_last

        self.project_name = run_config.get('project', 'experiments')

        self.folder = os.path.join(parent_dir, self.project_name,
                                   f"{self.run_name}_{self.run_id}")
        self.model_folder = os.path.join(self.folder, "models")
        # makedirs on the nested folder also creates self.folder.
        os.makedirs(self.model_folder, exist_ok=True)

        self.run_config = run_config
        self.model_name = run_config['model']
        self.model_architecture = run_config['architecture']

        # Never overwrite the config of an existing run folder.
        config_path = os.path.join(self.folder, "config.json")
        if not os.path.exists(config_path):
            with open(config_path, "w") as f:
                json.dump(self.run_config, f, indent=4)

        # TODO: move checkpoint/save/load concerns into dedicated helpers;
        # this class mixes run bookkeeping and model persistence.
        # A run id of 0 means "no wandb run"; also guard against a run id
        # supplied via run_config while wandb itself is not running.
        if create_wandb_json and self.run_id != 0 and wandb.run is not None:
            wandb_info = {
                "run_name": self.run_name,
                "run_id": self.run_id,
                "link": f"https://wandb.ai/{wandb.run.entity}/{wandb.run.project}/runs/{self.run_id}"
            }
            with open(os.path.join(self.folder, "wandb.json"), "w") as f:
                json.dump(wandb_info, f, indent=4)

        self.last_time = None
        self.total_time_start = None

    @staticmethod
    def _format_elapsed(seconds: float) -> str:
        """Format a duration as "HH:MM:SS" with hours not wrapping at 24."""
        total = int(seconds)
        hours, rest = divmod(total, 3600)
        minutes, secs = divmod(rest, 60)
        return f"{hours:02d}:{minutes:02d}:{secs:02d}"

    def start_timer(self):
        """Start (or restart) the wall-clock timers used by the save methods."""
        self.total_time_start = time.time()
        self.last_time = time.time()

    def save_model(
            self,
            model,
            checkpoint_name: str = None,
            json_dict: dict = None,
    ):
        """
        Save a model checkpoint and optional JSON metadata.

        Args:
            model: Module whose ``state_dict()`` is saved.
            checkpoint_name: File stem; the checkpoint is written to
                ``<model_folder>/<checkpoint_name>.ckpt`` (the default None
                yields ``None.ckpt``).
            json_dict: Optional metadata written to
                ``<model_folder>/<checkpoint_name>.json``. If start_timer()
                was called, an ``elapsed_time`` entry ("HH:MM:SS" since the
                timer start) is added. The caller's dict is not modified.
        """
        if self.save_only_last:
            # Keep only the newest checkpoint: remove previously saved files
            # (sub-directories, if any, are left alone).
            for entry in os.listdir(self.model_folder):
                path = os.path.join(self.model_folder, entry)
                if os.path.isfile(path):
                    os.remove(path)

        ckpt_path = os.path.join(self.model_folder, f"{checkpoint_name}.ckpt")
        torch.save(model.state_dict(), ckpt_path)

        # TODO: save info about when (epoch + time) the model was saved,
        # for the best checkpoint only.
        print(f"Model saved at {ckpt_path}.")
        if json_dict:
            metadata = dict(json_dict)
            if self.total_time_start is not None:
                metadata['elapsed_time'] = self._format_elapsed(
                    time.time() - self.total_time_start)
            json_path = os.path.join(self.model_folder,
                                     f"{checkpoint_name}.json")
            with open(json_path, "w") as f:
                json.dump(metadata, f, indent=4)

    def try_saving(self, model):
        """
        Save a checkpoint if at least ``save_interval`` minutes elapsed.

        The checkpoint is named after the whole minutes elapsed since
        start_timer().

        Args:
            model: Model to save.

        Raises:
            ValueError: If start_timer() was not called first.
        """
        if self.last_time is None:
            raise ValueError("Timer not started.")

        current_time = time.time()
        minutes_since_last = int((current_time - self.last_time) // 60)

        if minutes_since_last >= self.save_interval:
            total_minutes = int((current_time - self.total_time_start) // 60)

            self.save_model(model, str(total_minutes))
            self.last_time = current_time

    def load(self, checkpoint='best'):
        """
        Rebuild the model from the run config and load a saved checkpoint.

        Args:
            checkpoint: File stem of the checkpoint in the models folder.

        Returns:
            The model returned by create_model with the weights loaded.
        """
        model = create_model(self.model_name, self.run_config)

        checkpoint_path = os.path.join(self.model_folder,
                                       f"{checkpoint}.ckpt")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

        print(f'Loaded model {checkpoint}')

        return model
