import csv
import pandas as pd
import numpy as np
import time
import json
import psutil


class TrainingMonitor:
    """Track the metrics of one training run and persist them when it ends.

    The run is registered as a row of a ``;``-separated CSV log by
    ``save_train_start``. Metrics are then accumulated in memory during
    training. ``save_train_end`` writes the aggregated results back into that
    row (matched on the ``ID`` column) and dumps the raw per-dataset history
    to ``result_path + "train_param.json"``.
    """

    def __init__(self, csv_path, result_path, id, train_type, backbone_model):
        """Create a monitor; the run's wall-clock start time is taken here.

        Args:
            csv_path: Path of the ``;``-separated CSV log shared by all runs.
            result_path: Prefix for the JSON output file. It is concatenated
                with ``"train_param.json"`` as a plain string, so it presumably
                has to end with a path separator.
            id: Value identifying this run in the log's ``ID`` column (the
                name shadows the builtin but is part of the public API).
            train_type: Training-type label written to the log.
            backbone_model: Backbone-model label written to the log.
        """
        self.csv_path = csv_path
        self.result_path = result_path
        self.id = id
        self.train_type = train_type
        self.backbone_model = backbone_model
        self.epoch = 0  # last epoch reported via update_values
        self.start_time = time.time()
        self.train_time_dataset = {}  # dataset -> "HH:MM:SS" duration string
        self.loss = {}  # dataset -> list of loss values, one per update
        self.validation = {}  # dataset -> list of validation values, one per update
        self.dataset_list = []  # datasets in first-seen order, without duplicates
        self.error = False
        self.info_text = ""

    @staticmethod
    def _format_duration(seconds):
        """Format a duration given in seconds as ``HH:MM:SS``.

        Unlike ``time.strftime("%H:%M:%S", time.gmtime(seconds))`` the hour
        field does not wrap around at 24, so runs longer than a day are
        reported correctly (90000 s -> "25:00:00"). Fractional seconds are
        truncated and negative durations are clamped to zero.
        """
        total = max(0, int(seconds))
        hours, remainder = divmod(total, 3600)
        minutes, secs = divmod(remainder, 60)
        return f"{hours:02d}:{minutes:02d}:{secs:02d}"

    def save_train_start(self):
        """Append ``[id, train_type, backbone_model]`` as a new row of the CSV log.

        No header is written here. The log is assumed to already exist with a
        header whose leading columns match these fields and which also
        contains the result columns filled in by ``save_train_end``
        (TODO confirm).
        """
        with open(self.csv_path, "a", newline="") as file:  # append, never truncate
            writer = csv.writer(file, delimiter=";")
            writer.writerow([self.id, self.train_type, self.backbone_model])

    def update_values(self, dataset, epoch, loss, validation=None):
        """Record one epoch's metrics for *dataset*; meant to be called every epoch.

        Args:
            dataset: Key under which the values are grouped.
            epoch: Current epoch; stored as the run's latest epoch.
            loss: Loss value appended to ``self.loss[dataset]``.
            validation: Optional validation value; when not ``None`` it is
                appended to ``self.validation[dataset]``.
        """
        self.epoch = epoch
        self.loss.setdefault(dataset, []).append(loss)
        if validation is not None:
            self.validation.setdefault(dataset, []).append(validation)

    def add_dataset(self, dataset):
        """Remember *dataset* in ``dataset_list`` unless it is already there."""
        if dataset not in self.dataset_list:
            self.dataset_list.append(dataset)

    def add_train_time(self, dataset, start_time, end_time):
        """Store ``end_time - start_time`` (seconds) for *dataset* as ``HH:MM:SS``."""
        self.train_time_dataset[dataset] = self._format_duration(end_time - start_time)

    def set_error(self):
        """Flag the run as failed; written to the log's ``Error`` column."""
        self.error = True

    def add_info_text(self, text):
        """Set (replace, not append) the free-text info stored with the run."""
        self.info_text = text

    def save_train_end(self, max_retries=10, retry_delay=10):
        """Write the run's aggregated results to the CSV log and a JSON file.

        Every log row whose ``ID`` equals ``self.id`` gets its ``Epoch``,
        ``Train_Time``, ``Loss`` (mean over all recorded losses),
        ``Validation`` (mean over all recorded validation values), ``Error``
        and ``Info`` columns overwritten. The full per-dataset history is
        then written to ``result_path + "train_param.json"``.

        Args:
            max_retries: How many times to try rewriting the CSV while it is
                locked by another process (``PermissionError``).
            retry_delay: Seconds to wait between those attempts.
        """
        train_time = self._format_duration(time.time() - self.start_time)

        all_losses = [value for values in self.loss.values() for value in values]
        all_validations = [value for values in self.validation.values() for value in values]
        # Fall back to 0 when nothing was recorded: np.mean([]) would yield NaN
        # and emit a RuntimeWarning.
        mean_loss = np.mean(all_losses) if all_losses else 0
        mean_validation = np.mean(all_validations) if all_validations else 0

        df = pd.read_csv(self.csv_path, delimiter=";")
        # Must be a list: a tuple here would be read by .loc as a single MultiIndex key.
        result_columns = ["Epoch", "Train_Time", "Loss", "Validation", "Error", "Info"]
        df.loc[df["ID"] == self.id, result_columns] = [
            self.epoch,
            train_time,
            mean_loss,
            mean_validation,
            self.error,
            self.info_text,
        ]

        for attempt in range(1, max_retries + 1):
            try:
                df.to_csv(self.csv_path, sep=";", index=False)
            except PermissionError:  # raised while the file is in use elsewhere
                print("File is being used by another process. Retrying...")
                if attempt < max_retries:  # no point sleeping after the last try
                    time.sleep(retry_delay)
            else:
                print("File saved successfully.")
                break
        else:
            # Only reached when no attempt succeeded (the loop never hit `break`).
            print("Failed to save the file.")

        params = {
            "info": self.info_text,
            "dataset_list": self.dataset_list,
            "loss": self.loss,
            "validation": self.validation,
            "train_time_dataset": self.train_time_dataset,
        }
        with open(self.result_path + "train_param.json", "w", encoding="utf-8") as json_file:
            json.dump(params, json_file, indent=4)


def get_training_id(csv_path):
    """Return the ``ID`` value of the last data row of a ``;``-separated CSV log.

    Args:
        csv_path: Path of the CSV log; it must have a header containing an
            ``ID`` column.

    Returns:
        The ``ID`` of the final row as read by pandas, or ``0`` when the log
        has no data rows.
    """
    log = pd.read_csv(csv_path, delimiter=";")
    # A header-only log has no previous run to report.
    return 0 if log.empty else log["ID"].iloc[-1]

def check_disk_usage(path, low_space_gb=10):
    """Print a storage summary for the disk holding *path* and warn when space is low.

    Sizes are printed in units of 1024**3 bytes (labelled "GB"). A prominent
    warning banner follows when free space is below *low_space_gb*.

    Args:
        path: Any path on the disk to inspect; passed to ``psutil.disk_usage``.
        low_space_gb: Free-space threshold, in the same 1024**3-byte units,
            below which the warning is printed. Defaults to 10, the value
            that was previously hard-coded.

    Returns:
        The ``psutil.disk_usage`` result, so callers can also act on it
        programmatically (previously nothing was returned).
    """
    gib = 1024 ** 3
    # Query disk usage once and reuse it for both the summary and the warning.
    disk_usage = psutil.disk_usage(path)
    free_gb = disk_usage.free / gib

    print("=" * 40)
    print(f"{'STORAGE INFO':^40}")
    print(f"Total Space: {disk_usage.total / gib:.2f} GB")
    print(f"Used Space: {disk_usage.used / gib:.2f} GB")
    print(f"Free Space: {free_gb:.2f} GB")
    print(f"Percentage Used: {disk_usage.percent}%")
    print("=" * 40)

    if free_gb < low_space_gb:
        banner = "!" * 80
        print(banner)
        print(banner)
        print(f"LOW FREE SPACE < {low_space_gb} GB")
        print(banner)
        print(banner)

    return disk_usage


def print_console_output(info_id, seed, dataset, epoch, model_type, backbone_model):
    """Print a boxed, 40-column summary of the current run's settings to stdout.

    Args:
        info_id: Run identifier.
        seed: Random seed used for the run.
        dataset: Dataset being trained on.
        epoch: Epoch value to display.
        model_type: Model-type label.
        backbone_model: Backbone-model label.
    """
    heavy_rule = "=" * 40
    summary_lines = [
        heavy_rule,
        f"{'INFORMATION SUMMARY':^40}",
        heavy_rule,
        f"ID         : {info_id}",
        f"Dataset    : {dataset}",
        f"Seed       : {seed}",
        "-" * 40,
        f"Epoch: {epoch}",
        f"Model_type: {model_type}",
        f"Backbonemodel: {backbone_model}",
        heavy_rule,
    ]
    # One line per entry, exactly as separate print() calls would emit them.
    print(*summary_lines, sep="\n")