import math
import os
import time
import matplotlib.pyplot as plt
import numpy as np

class logfile:
    """Leveled logger that prints to the console and appends to text files.

    Files created under ``path`` (a prefix that is concatenated directly with
    the file names, so it should end with a path separator):

    * ``log.txt``        - every message, regardless of level
    * ``infoLog.txt``    - ``info`` messages (level 0)
    * ``warningLog.txt`` - ``warning`` messages (level 1)
    * ``imptLog.txt``    - ``impt`` messages (level 1)
    * ``errorLog.txt``   - ``error`` messages (level 2)
    * ``debugLog.txt``   - ``debug`` messages (level 2)
    * ``<tag>.txt``      - ``clas`` messages for that tag

    A message of level L is printed when ``showlv <= L`` and written to its
    per-level file when ``loglv <= L``; see ``config``.
    """

    def __init__(self, path='./logFile/', repalce=False):
        """
        :param path: directory prefix for all log files (created if missing).
        :param repalce: (sic, "replace") when True, existing log files are
            deleted; otherwise new messages are appended to them.
        """
        self.path = path
        self.log_path = path + 'log.txt'
        self.info_path = path + 'infoLog.txt'
        self.warn_path = path + 'warningLog.txt'
        self.error_path = path + 'errorLog.txt'
        self.debug_path = path + 'debugLog.txt'
        self.impt_path = path + 'imptLog.txt'
        self.log_cnt = 0
        self.info_cnt = 0
        self.warn_cnt = 0
        self.error_cnt = 0
        self.debug_cnt = 0
        self.impt_cnt = 0
        self.showlv = 0
        self.loglv = 0
        self.showtime = True
        # makedirs also creates missing parent directories; an empty prefix
        # means "current directory", which always exists.
        if path:
            os.makedirs(path, exist_ok=True)

        for p in [self.log_path, self.info_path, self.warn_path,
                  self.error_path, self.debug_path, self.impt_path]:
            if os.path.exists(p):
                if repalce:
                    # os.remove instead of `rm -rf` via the shell: no shell
                    # injection through the path, and portable to Windows.
                    os.remove(p)
                    print(f"*{p} exist* : deleted previous")
                else:
                    print(f"*{p} exist* : continue")

    def config(self, showlv: int, loglv: int, showtime: bool):
        """Set the console threshold, the file threshold and timestamp display."""
        self.showlv = showlv
        self.loglv = loglv
        self.showtime = showtime

    @staticmethod
    def _timestamp():
        """Return the current local time formatted as ``DD-HH-MM-SS``."""
        return time.strftime("%d-%H-%M-%S", time.localtime())

    @staticmethod
    def _append(file_path, content):
        """Append ``content`` to ``file_path`` as UTF-8 text."""
        with open(file_path, 'a', encoding='utf-8') as f:
            f.write(content)

    def _emit(self, label, count, file_path, level, text, end, tm=None):
        """Print and/or persist one ``label(count) : text`` line, then mirror it to log.txt.

        :param level: severity of the message, compared with ``showlv``/``loglv``.
        :param tm: timestamp to reuse; generated here when None and ``showtime`` is on.
        """
        if tm is None and self.showtime:
            tm = self._timestamp()
        line = f"{label}({count}) : {text}"
        if tm is not None:
            line = f"|{tm}|" + line
        if self.showlv <= level:
            print(line, end=end)
        if self.loglv <= level:
            self._append(file_path, line + '\n')
        self.log(text, tm)

    def log(self, text: str, tm: 'str | None' = None):
        """Append ``text`` to the combined ``log.txt`` (always, independent of levels).

        :param tm: timestamp prefixed as ``|tm|``; omitted when None.
        """
        line = f"Log({self.log_cnt}) : {text} \n"
        if tm is not None:
            line = f"|{tm}|" + line
        self.log_cnt += 1
        self._append(self.log_path, line)

    def info(self, text: str, end='\n'):
        """Log ``text`` at level 0 (console and ``infoLog.txt``)."""
        cnt, self.info_cnt = self.info_cnt, self.info_cnt + 1
        self._emit('Info', cnt, self.info_path, 0, text, end)

    def warning(self, text: str, end='\n'):
        """Log ``text`` at level 1 (console and ``warningLog.txt``)."""
        cnt, self.warn_cnt = self.warn_cnt, self.warn_cnt + 1
        self._emit('Warning', cnt, self.warn_path, 1, text, end)

    def error(self, text: str, end='\n'):
        """Log ``text`` at level 2 (console and ``errorLog.txt``)."""
        cnt, self.error_cnt = self.error_cnt, self.error_cnt + 1
        self._emit('Error', cnt, self.error_path, 2, text, end)

    def debug(self, text: str, end='\n'):
        """Log ``text`` at level 2 (console and ``debugLog.txt``).

        NOTE(review): debug shares the error threshold (2) - confirm intent.
        """
        cnt, self.debug_cnt = self.debug_cnt, self.debug_cnt + 1
        self._emit('Debug', cnt, self.debug_path, 2, text, end)

    def clas(self, text: str, end='\n', tag: str = 'default'):
        """Append ``text`` to ``<path><tag>.txt`` and then log it via ``impt``.

        The tag-file line is labelled with the current info counter, which is
        not incremented here.
        """
        tm = self._timestamp() if self.showtime else None
        line = f"Info({self.info_cnt}) : {text}"
        if tm is not None:
            line = f"|{tm}|" + line
        self._append(self.path + tag + '.txt', line + '\n')
        # Reuse the same timestamp so the tag file and imptLog.txt agree.
        self.impt(text, end=end, tm=tm)

    def impt(self, text: str, end='\n', tm=None):
        """Log ``text`` at level 1 (console and ``imptLog.txt``).

        :param tm: timestamp to reuse (as passed by ``clas``); when None one is
            generated if ``showtime`` is on.
        """
        cnt, self.impt_cnt = self.impt_cnt, self.impt_cnt + 1
        self._emit('Impt', cnt, self.impt_path, 1, text, end, tm)


class timer:
    """Wall-clock progress timers with smoothed remaining-time estimates.

    Two independent facilities:

    * named timers: ``timer_set`` registers a task and ``time_ckpt`` reports the
      estimated remaining time after each step;
    * a training tracker: ``train_init``, then ``train_epoch_ckpt`` /
      ``train_step_ckpt`` / ``train_eval_ckpt``, with ``train_show`` reporting
      the estimated remaining time of the whole run.

    Speeds are smoothed as ``p * old + (1 - p) * new`` with ``p = self.p``.
    """

    def __init__(self, mode_default='w'):
        """
        :param mode_default: default unit for formatted durations:
            's' for second
            'm' for minute
            'h' for hour
            'd' for day
            'w' for wise (minutes below one hour, hours below one day, else days)
        """
        t = time.time()
        self.timers = {}
        self.ckpt = [t]          # not read anywhere in this class
        self.ckpt_switch = True  # not read anywhere in this class
        self.mode_default = mode_default
        self.p = 0.9  # weight kept from the previous speed estimate

    def _ema(self, prev, sample, restart=False):
        """Blend ``sample`` into ``prev``; return ``sample`` itself when (re)starting."""
        if prev is None or restart:
            return sample
        return self.p * prev + (1 - self.p) * sample

    def time_trans(self, t, mode=None):
        """Format ``t`` seconds as a string in the unit selected by ``mode``.

        Any mode other than 's', 'm', 'h' or 'w' falls through to days.
        """
        if mode is None:
            mode = self.mode_default
        if mode == 's':
            return f"{t:.2f}s"
        t = t / 60
        if mode == 'm' or (mode == 'w' and t < 60):
            return f"{t:.2f}m"
        t = t / 60
        if mode == 'h' or (mode == 'w' and t < 24):
            return f"{t:.2f}h"
        t = t / 24
        return f"{t:.4f}d"

    def timer_set(self, name: str, total_step: int, cur_step: int = 0):
        """Start (or restart) the named timer at ``cur_step`` of ``total_step``."""
        self.timers[name] = [time.time(), total_step, cur_step, None]

    def time_ckpt(self, name, cur_step=None, mode=None):
        """Record progress of timer ``name`` and estimate the time left.

        :param cur_step: step just reached; defaults to previous step + 1 and
            must be strictly greater than the previous step.
        :return: (formatted remaining time, seconds per step, cur_step, total_step)
        :raises KeyError: if ``name`` was never registered with ``timer_set``.
        """
        last, total_step, last_step, speed = self.timers[name]
        if cur_step is None:
            cur_step = last_step + 1
        assert cur_step > last_step
        now = time.time()
        speed = self._ema(speed, (now - last) / (cur_step - last_step))
        res_time = speed * (total_step - cur_step)
        self.timers[name] = [now, total_step, cur_step, speed]
        return self.time_trans(res_time, mode), speed, cur_step, total_step

    def train_init(self, epoch, step, eval_step):
        """Reset the training tracker.

        :param epoch: number of epochs.
        :param step: steps per epoch.
        :param eval_step: an evaluation happens every ``eval_step`` steps, giving
            ``epoch * ceil(step / eval_step)`` evaluations in total.
        """
        self.train_epoch = epoch
        self.train_step = step
        self.train_eval_step = epoch * math.ceil(step / eval_step)
        self.train_step_speed = None
        self.train_eval_speed = None
        self.train_cur_step = 0
        self.train_cur_eval = 0
        self.train_cur_epoch = 0
        # Initialise the reference time so step/eval checkpoints work even if
        # train_epoch_ckpt has not been called yet (previously AttributeError).
        self.train_step_last = time.time()

    def train_epoch_ckpt(self):
        """Mark the start of a new epoch and reset the in-epoch step counter."""
        self.train_step_last = time.time()
        self.train_cur_epoch += 1
        self.train_cur_step = 0

    def train_step_ckpt(self):
        """Mark the end of one training step and update the step-speed estimate.

        :return: the current step index within the epoch (1-based).
        """
        self.train_cur_step += 1
        now = time.time()
        # The estimate restarts for the first 5 steps of every epoch
        # (presumably to discard warm-up steps).
        self.train_step_speed = self._ema(self.train_step_speed,
                                          now - self.train_step_last,
                                          restart=self.train_cur_step <= 5)
        self.train_step_last = now
        return self.train_cur_step

    def train_eval_ckpt(self, show=False, mode=None):
        """Mark the end of an evaluation and update the eval-duration estimate.

        The duration is measured from the last step/epoch/eval checkpoint.
        :return: ``train_show(mode)`` when ``show`` is true, otherwise None.
        """
        self.train_cur_eval += 1
        now = time.time()
        self.train_eval_speed = self._ema(self.train_eval_speed,
                                          now - self.train_step_last)
        self.train_step_last = now
        if show:
            return self.train_show(mode)

    def train_show(self, mode=None):
        """Estimate the remaining training time.

        :return: (seconds per step, seconds per eval, current step, steps per
            epoch, formatted remaining time). The time string ends with '*'
            while either speed is still unknown.
        """
        train_step_speed = 0 if self.train_step_speed is None else self.train_step_speed
        train_eval_speed = 0 if self.train_eval_speed is None else self.train_eval_speed
        remaining_steps = ((self.train_epoch - self.train_cur_epoch) * self.train_step
                           + (self.train_step - self.train_cur_step))
        remaining_evals = self.train_eval_step - self.train_cur_eval
        res = train_step_speed * remaining_steps + train_eval_speed * remaining_evals
        res = self.time_trans(res, mode)
        if self.train_step_speed is None or self.train_eval_speed is None:
            res = res + '*'
        return train_step_speed, train_eval_speed, self.train_cur_step, self.train_step, res


class counter:
    """Named tally counters that can be frozen.

    count(name, num): add ``num`` to record ``name`` while it is not stopped;
        a brand-new record starts at ``num``.
    set(name, num): overwrite the record's value and re-enable counting.
    stop(name, show): freeze the record; with ``show=True`` return its value.
    show(): return all records as "name : value" strings joined by ``end``.
    """

    def __init__(self):
        self.counter = {}
        # name -> True while the record still accepts count(); False once stopped.
        self.lock = {}

    def count(self, name: str, num: int = 1):
        """Add ``num`` to ``name`` unless it has been stopped; create it if new."""
        if name in self.counter:
            if self.lock[name]:
                self.counter[name] += num
            return
        self.set(name, num)

    def set(self, name: str, num: int = 0):
        """Set ``name`` to ``num`` and (re)enable counting for it."""
        self.counter[name] = num
        self.lock[name] = True

    def stop(self, name: str, show=False):
        """Stop counting ``name``.

        :return: the record's value when ``show`` is true (0, with a printed
            warning, if the record does not exist); otherwise None.
        """
        if name not in self.lock:
            if not show:
                return None
            print("stop before set or count")
            return 0
        self.lock[name] = False
        return self.counter[name] if show else None

    def show(self, end: str = '\n'):
        """Return every record formatted as "name : value", joined by ``end``."""
        return end.join(f"{key} : {value}" for key, value in self.counter.items())


class loss_marker:
    """Accumulate train/eval losses, average them at checkpoints, and plot them.

    ``add`` keeps a running sum per series; ``ckpt`` converts the running sum
    into one averaged point and optionally redraws ``<save_path>loss.jpg``.
    """

    def __init__(self, save_path: str, draw=True):
        """
        :param save_path: directory prefix; the plot goes to ``save_path + 'loss.jpg'``.
        :param draw: default for ``ckpt(draw=...)``.
        """
        self.save_path = save_path + 'loss.jpg'
        self.draw = draw
        self.train_loss_marker = []
        self.eval_loss_marker = []
        self.train_loss = 0.
        self.eval_loss = 0.
        self.train_loss_cnt = 0
        self.eval_loss_cnt = 0

    def add(self, name: str, loss: float, batch_size: int = 1):
        """Accumulate ``loss`` into the ``name`` series ('train' or 'eval').

        The checkpoint average is ``sum(loss) / sum(batch_size)``, so with
        ``batch_size > 1`` ``loss`` is presumably the batch *sum* - confirm
        with callers.

        :return: ``self``, so calls can be chained.
        """
        assert name in ['train', 'eval']
        if name == 'train':
            self.train_loss += loss
            self.train_loss_cnt += batch_size
        else:
            self.eval_loss += loss
            self.eval_loss_cnt += batch_size
        return self

    @staticmethod
    def _average(total, count):
        """Return ``total / count``, or 0 when nothing was accumulated."""
        return total / count if count != 0 else 0

    @staticmethod
    def _epoch_axis(n, epoch):
        """Return ``n`` x-positions evenly spaced over ``(0, epoch]`` (empty when n == 0)."""
        if n == 0:
            return np.empty(0)
        return np.linspace(epoch / n, epoch, n)

    def _draw(self, epoch):
        """Plot both series on a private figure and save it to ``self.save_path``."""
        # A dedicated figure avoids clearing whatever figure the caller has
        # active in pyplot; it is closed afterwards to free its memory.
        fig, ax = plt.subplots()
        try:
            plotted = False
            for series, color, label in ((self.train_loss_marker, 'red', 'train'),
                                         (self.eval_loss_marker, 'green', 'eval')):
                if not series:
                    continue  # an empty series previously caused ZeroDivisionError
                ax.plot(self._epoch_axis(len(series), epoch), np.array(series),
                        color=color, linestyle='-', linewidth=1, label=label)
                plotted = True
            if plotted:
                ax.legend()
            fig.savefig(self.save_path)
        finally:
            plt.close(fig)

    def ckpt(self, name=None, draw=None, epoch=1):
        """Close the current accumulation window and record its average.

        :param name: 'train', 'eval', or None for both series.
        :param draw: redraw the plot; defaults to the constructor's ``draw``.
        :param epoch: x-axis extent; recorded points are spread over ``(0, epoch]``.
        :return: ``last_loss()`` - latest (train, eval) averages, NaN if absent.
        """
        if draw is None:
            draw = self.draw
        if name is None or name == 'train':
            self.train_loss_marker.append(self._average(self.train_loss, self.train_loss_cnt))
            self.train_loss = 0.
            self.train_loss_cnt = 0
        if name is None or name == 'eval':
            self.eval_loss_marker.append(self._average(self.eval_loss, self.eval_loss_cnt))
            self.eval_loss = 0.
            self.eval_loss_cnt = 0
        if draw:
            self._draw(epoch)
        return self.last_loss()

    def last_loss(self):
        """Return the latest (train, eval) averages, using NaN for an empty series."""
        t = math.nan if len(self.train_loss_marker) == 0 else self.train_loss_marker[-1]
        e = math.nan if len(self.eval_loss_marker) == 0 else self.eval_loss_marker[-1]
        return t, e
