#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 30 12:06:29 2025

Saving and loading results/checkpoints
@author: scott
"""
import torch
from tbparse import SummaryReader
import glob
import os
import yaml

from src.utils import utils
from src.utils import plotting_functions as pf
from src.utils import config

from src.models import get_model
from src.optimizers import get_opt


def latest_event_file_loc(log_dir='runs/'):
    """Return the path of the most recently modified event file under ``log_dir``.

    The search is recursive: any file whose name starts with
    ``events.out.tfevents.`` in ``log_dir`` or one of its sub-directories is a
    candidate, and the one with the newest modification time wins.

    Raises:
        FileNotFoundError: if no matching file exists.
    """
    pattern = os.path.join(log_dir, '**', 'events.out.tfevents.*')
    candidates = glob.glob(pattern, recursive=True)
    if candidates:
        return max(candidates, key=os.path.getmtime)
    raise FileNotFoundError("No event files found.")


def create_checkpoint(conf, model, opt, lr_scheduler, epoch=None, SAVE_NAME=None, LOG_DIR='./'):
    """Save the run configuration and a training checkpoint next to each other.

    Args:
        conf: run configuration; must provide ``save(path)`` and, when ``epoch``
            is omitted, ``num_epochs``. Its ``EVENT_NAME`` attribute is set here.
        model, opt, lr_scheduler: objects whose ``state_dict()`` is stored.
        epoch: epoch number recorded in the checkpoint; defaults to
            ``conf.num_epochs``.
        SAVE_NAME: file-name stem; when empty it is built by
            ``utils.create_name(conf)``.
        LOG_DIR: prefix that is concatenated directly with ``SAVE_NAME`` (no
            separator is inserted, so it should end with a path separator).

    Returns:
        The ``SAVE_NAME`` that was actually used.

    Raises:
        FileNotFoundError: if no event file exists under ``LOG_DIR``.
    """
    if not SAVE_NAME:
        SAVE_NAME = utils.create_name(conf)
    if epoch is None:
        epoch = conf.num_epochs

    latest_event_file = latest_event_file_loc(LOG_DIR)

    # Use os.path.basename rather than splitting on '/': on Windows glob returns
    # backslash-separated paths and a '/' split would keep the whole path.
    # NOTE(review): only the file name is stored, while the event file is found
    # recursively -- an event file inside a sub-directory of LOG_DIR will not be
    # locatable from ROOT + EVENT_NAME alone; confirm the expected layout.
    conf.EVENT_NAME = os.path.basename(latest_event_file)

    # EVENT_NAME must be set before saving so that it is part of the stored config.
    conf.save(f"{LOG_DIR}{SAVE_NAME}")
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'scheduler_state_dict': lr_scheduler.state_dict(),
        'epoch': epoch,
    }, f'{LOG_DIR}{SAVE_NAME}_CHECKPOINT.pth')
    print(f"\nCheckpoint created.\n--- Event file:\n{latest_event_file}\n--- Log directory:\n{LOG_DIR}\n--- Save name:\n{SAVE_NAME}")

    return SAVE_NAME

def save_results(SAVE_NAME=None, LOG_DIR='./'):
    """Export the scalars of the newest event file under ``LOG_DIR``.

    Writes a plot (via ``pf.plot_df``) to ``{LOG_DIR}{SAVE_NAME}_PLOT`` and a
    CSV to ``{LOG_DIR}{SAVE_NAME}_RESULTS.csv``. ``LOG_DIR`` and ``SAVE_NAME``
    are concatenated as-is; note that leaving ``SAVE_NAME`` as ``None`` puts the
    literal text "None" into the file names.

    Raises:
        FileNotFoundError: if no event file exists under ``LOG_DIR``.
    """
    event_path = latest_event_file_loc(LOG_DIR)

    # Scalar data (e.g., Train/loss) read from the event file
    scalars = SummaryReader(event_path, pivot=True).scalars

    output_prefix = f"{LOG_DIR}{SAVE_NAME}"
    pf.plot_df(scalars, save=f"{output_prefix}_PLOT")
    scalars.to_csv(f"{output_prefix}_RESULTS.csv", index=False)


def load_checkpoint(LOAD_NAME, ROOT='./', device=None, verbose=True):
    """Load a saved configuration and its checkpoint dictionary.

    Reads ``{ROOT}{LOAD_NAME}.txt`` as YAML to rebuild the ``Config`` and loads
    ``{ROOT}{LOAD_NAME}_CHECKPOINT.pth`` with ``torch.load``.

    Args:
        LOAD_NAME: file-name stem shared by the config and checkpoint files.
        ROOT: prefix concatenated directly with ``LOAD_NAME`` (no separator added).
        device: target for ``map_location``; when falsy, CUDA is used if
            available, otherwise the CPU.
        verbose: forwarded to ``config.Config``.

    Returns:
        Tuple ``(conf, checkpoint)``.
    """
    base_path = f'{ROOT}{LOAD_NAME}'
    with open(f'{base_path}.txt', 'r') as conf_file:
        settings = yaml.safe_load(conf_file)

    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

    conf = config.Config(verbose=verbose, **settings)
    # NOTE: weights_only=False allows arbitrary pickled objects to be
    # deserialized -- only load checkpoint files from a trusted source.
    checkpoint = torch.load(
        f'{base_path}_CHECKPOINT.pth',
        map_location=device,
        weights_only=False,
    )
    return conf, checkpoint
    
def load_states(LOAD_NAME, ROOT='./', device=None):
    """Rebuild the model, optimizer and lr_scheduler of a run to continue training.

    Loads the configuration and checkpoint via ``load_checkpoint`` and passes
    the stored state dicts to ``get_model`` and ``get_opt``.

    Args:
        LOAD_NAME: file-name stem of the saved run.
        ROOT: prefix concatenated directly with ``LOAD_NAME``.
        device: forwarded to ``load_checkpoint`` and ``get_model`` unchanged
            (NOTE(review): ``None`` is passed through to ``get_model`` as-is;
            confirm that it handles a missing device).

    Returns:
        Tuple ``(conf, model, opt, lr_scheduler, epoch)`` where ``epoch`` is the
        value stored in the checkpoint (``None`` if the key is absent).
        Previously the function built these objects but returned nothing.
    """
    conf, checkpoint = load_checkpoint(LOAD_NAME, ROOT=ROOT, device=device)

    model = get_model(conf, device, checkpoint['model_state_dict'])
    opt, lr_scheduler = get_opt(
        conf,
        model,
        opt_state_dict=checkpoint['optimizer_state_dict'],
        scheduler_state_dict=checkpoint['scheduler_state_dict'],
    )
    # Hand the restored objects back to the caller; without this the whole
    # function had no observable result.
    return conf, model, opt, lr_scheduler, checkpoint.get('epoch')

def load_results(LOAD_NAME, ROOT='./', device=None, include_model=False, verbose=False):
    """Return the configuration and tensorboard reader associated with a given run.

    Args:
        LOAD_NAME: file-name stem of the saved run.
        ROOT: prefix concatenated directly with ``LOAD_NAME`` and with the
            event file name stored in ``conf.EVENT_NAME``.
        device: forwarded to ``load_checkpoint`` and, when a model is
            requested, to ``get_model``.
        include_model: when true, also rebuild the model from the checkpoint.
        verbose: forwarded to ``load_checkpoint``.

    Returns:
        ``(conf, reader)``, or ``(conf, reader, model)`` when ``include_model``
        is true.
    """
    conf, checkpoint = load_checkpoint(LOAD_NAME, ROOT=ROOT, device=device, verbose=verbose)

    reader = SummaryReader(f'{ROOT}{conf.EVENT_NAME}', pivot=True)
    if not include_model:
        return conf, reader

    # Match the (conf, device, state_dict) call used in load_states; the
    # previous two-argument call put the state dict in the device position.
    # NOTE(review): get_model's signature is defined elsewhere -- confirm order.
    model = get_model(conf, device, checkpoint['model_state_dict'])
    return conf, reader, model