from options import *
import os
from tqdm import tqdm
import time
import matplotlib.pyplot as plt
import torch
import csv
import pandas as pd


def print_root(args):
    """Open the run's log stream and return it together with the run timestamp.

    The timestamp (month/day-hour/minute, formatted with Chinese unit
    characters) is used to name the run directory ``<log_root>_(<t>)``.

    Args:
        args: object with ``log_root`` and ``log_name`` attributes. When
            ``log_root == 'none'`` nothing is created and output goes to stdout.

    Returns:
        (f, t): ``f`` is a writable text stream (an append-mode UTF-8 file
        that the caller is responsible for closing, or ``sys.stdout``);
        ``t`` is the timestamp string used in the directory/file names.
    """
    # Imported locally so the function does not depend on `options` leaking `sys`.
    import sys

    t = time.strftime('%m月%d日-%H时%M分', time.localtime(time.time()))
    if args.log_root == 'none':
        return sys.stdout, t

    log_dir = f"{args.log_root}_({t})"
    os.makedirs(log_dir, exist_ok=True)
    # Explicit UTF-8 so log contents do not depend on the platform's locale encoding.
    f = open(f"{log_dir}/{args.log_name}_({t}).txt", 'a', encoding='utf-8')
    return f, t


def picture(args, result, time):
    """Plot per-epoch accuracy curves and save them as a PNG in the run directory.

    Train accuracy is drawn in red (dash-dot) and test accuracy in blue
    (dashed). Validation accuracy is drawn in green (solid) only when its
    list is non-empty and its last value is non-zero.

    Args:
        args: object with ``log_root`` and ``log_name``. Nothing is written
            when ``log_root == 'none'``, matching the other save helpers.
        result: mapping with ``result['epoch']`` and
            ``result[split]['acc']`` for split in ``train``/``test``/``val``.
        time: run timestamp used in the directory and file names.
    """
    if args.log_root == 'none':
        return

    out_dir = f"{args.log_root}_({time})"
    os.makedirs(out_dir, exist_ok=True)

    fig = plt.figure()
    try:
        epochs = result['epoch']
        plt.plot(epochs, result['train']['acc'], color='red', linestyle='-.')
        plt.plot(epochs, result['test']['acc'], color='blue', linestyle='--')
        title = 'acc vs. epoch(train:red, test:blue)'

        val_acc = result['val']['acc']
        # A zero final value is treated as "no validation run"; guard the empty case too.
        if val_acc and val_acc[-1] != 0:
            plt.plot(epochs, val_acc, color='green', linestyle='-')
            title = 'acc vs. epoch(train:red, test:blue, val:green)'

        plt.title(title)
        plt.savefig(f"{out_dir}/{args.log_name}_({time}).png")
    finally:
        # Release the figure; otherwise repeated calls keep every figure in memory.
        plt.close(fig)


def model_save(args,model,optimizer, epoch, loss, time):
    """Write a per-epoch training checkpoint under ``<log_root>_(<time>)/checkpoints``.

    The checkpoint file is named ``_epoch<epoch>.pt``. It stores the epoch,
    the model and optimizer ``state_dict``s (not the whole model object),
    ``args.task`` and ``loss``. When ``args.log_root == 'none'`` this is a
    no-op.
    """
    if args.log_root == 'none':
        return

    ckpt_dir = f"{args.log_root}_({time})/checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)

    checkpoint = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        task=args.task,
        loss=loss,
    )
    torch.save(checkpoint, f"{ckpt_dir}/_epoch{epoch}.pt")
        
def value_save(result, args, time):
    """Save per-epoch accuracies to ``<log_root>_(<time>)/<log_name>_(<time>).xlsx``.

    Columns written: ``epoch``, ``train_acc``, ``test_acc``, ``val_acc``,
    preceded by pandas' default integer index column. When
    ``args.log_root == 'none'`` nothing is written.

    Args:
        result: mapping with ``result['epoch']`` and ``result[split]['acc']``
            for split in ``train``/``test``/``val``; the lists must have equal
            length, or pandas raises ``ValueError``.
        args: object with ``log_root`` and ``log_name``.
        time: run timestamp used in the directory and file names.
    """
    if args.log_root == 'none':
        return

    out_dir = f"{args.log_root}_({time})"
    # Ensure the directory exists, like model_save/best_save do for theirs.
    os.makedirs(out_dir, exist_ok=True)
    file_name = f"{out_dir}/{args.log_name}_({time}).xlsx"

    df = pd.DataFrame({
        'epoch': result['epoch'],
        'train_acc': result['train']['acc'],
        'test_acc': result['test']['acc'],
        'val_acc': result['val']['acc'],
    })
    # Writing .xlsx needs an Excel engine (e.g. openpyxl) to be installed.
    df.to_excel(file_name)

def best_save(args,model,optimizer,epoch,loss,time):
    """Write the best-so-far checkpoint to ``<log_root>_(<time>)/best_checkpoints/_best_model.pt``.

    Each call overwrites the previous best checkpoint. The payload matches
    ``model_save``: epoch, model and optimizer ``state_dict``s, ``args.task``
    and ``loss``. When ``args.log_root == 'none'`` this is a no-op.
    """
    if args.log_root == 'none':
        return

    # Save into the directory created here. Previously the file went to
    # `checkpoints/`, which fails when model_save has not created it yet.
    best_dir = f"{args.log_root}_({time})/best_checkpoints"
    os.makedirs(best_dir, exist_ok=True)
    torch.save({
        'epoch': epoch,  # Save the epoch of the model
        'model_state_dict': model.state_dict(),  # Save the parameters rather than the entire model
        'optimizer_state_dict': optimizer.state_dict(),  # Save the optimizer parameters
        'task': args.task,  # Save the task
        'loss': loss,
    }, f"{best_dir}/_best_model.pt")




