# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import csv
import logging
import typing as tp
from pathlib import Path
import datetime
from collections import defaultdict

import torch
import wandb
from termcolor import colored
from torch.utils.tensorboard import SummaryWriter
from url_benchmark.hiplogs import HipLog


# Console formatting spec: a list of (metric key, display key, value type)
# triples. The value type is one of 'int', 'float' or 'time'
# (see MetersGroup._format).
Formating = tp.List[tp.Tuple[str, str, str]]

# Columns printed for training dumps.
COMMON_TRAIN_FORMAT = [
    ('frame', 'F', 'int'),
    ('step', 'S', 'int'),
    ('episode', 'E', 'int'),
    ('episode_length', 'L', 'int'),
    ('episode_reward', 'R', 'float'),
    ('fps', 'FPS', 'float'),
    ('total_time', 'T', 'time'),
]

# Columns printed for evaluation dumps (training columns minus 'fps').
COMMON_EVAL_FORMAT = [
    ('frame', 'F', 'int'),
    ('step', 'S', 'int'),
    ('episode', 'E', 'int'),
    ('episode_length', 'L', 'int'),
    ('episode_reward', 'R', 'float'),
    ('total_time', 'T', 'time'),
]

# Columns printed for inference dumps.
COMMON_INF_FORMAT = [
    ('frame', 'F', 'int'),
    ('step', 'S', 'int'),
    ('obj', 'o', 'float'),
    ('constraints', 'c', 'float'),
    ('lamb', 'lamb', 'float'),
]


# Module-level logger for warnings emitted by this module.
pylogger = logging.getLogger(name=__name__)


class AverageMeter:
    """Running mean of a stream of scalar values.

    ``update(value, n)`` adds ``value`` to the running total unchanged and
    ``n`` to the sample count. ``value`` is therefore treated as the *sum* over
    ``n`` samples; it is not multiplied by ``n``.
    """

    def __init__(self) -> None:
        self._sum, self._count = 0.0, 0

    def update(self, value: float, n: int = 1) -> None:
        """Accumulate ``value`` into the total and ``n`` into the count."""
        self._sum += value
        self._count += n

    def value(self) -> float:
        """Return total / count, clamping the count to at least 1 (0.0 when empty)."""
        denominator = max(self._count, 1)
        return self._sum / denominator


# Mapping from metric name to its averaged value.
Metrics = tp.Dict[str, float]


class MetersGroup:
    """Accumulate scalar metrics and periodically flush their averages.

    Metrics are accumulated with :meth:`log`. :meth:`dump` averages everything
    logged since the previous dump, optionally sends it to wandb, appends one
    row to a CSV file and resets the meters.
    """

    def __init__(self, csv_file_name: tp.Union[Path, str], formating: Formating, use_wandb: bool) -> None:
        """
        Args:
            csv_file_name: CSV file the averaged metrics are appended to.
            formating: (key, display key, type) triples used by the console printer.
            use_wandb: whether :meth:`dump` also forwards metrics to ``wandb.log``.
        """
        self._csv_file_name = Path(csv_file_name)
        self._formating = formating
        self._meters: tp.Dict[str, AverageMeter] = defaultdict(AverageMeter)
        # The CSV file and writer are created lazily on the first dump, once the
        # column names (the keys of the first dumped row) are known.
        self._csv_file: tp.Optional[tp.TextIO] = None
        self._csv_writer: tp.Optional[csv.DictWriter[str]] = None
        self.use_wandb = use_wandb

    def log(self, key: str, value: float, n: int = 1) -> None:
        """Add ``value`` (counted as ``n`` samples) to the meter named ``key``."""
        self._meters[key].update(value, n)

    def _prime_meters(self) -> Metrics:
        """Return the current averages keyed by CSV-friendly column names.

        A leading ``train``/``eval``/``inf`` prefix is stripped together with the
        single separator character that follows it, and any remaining ``/`` is
        replaced with ``_``.
        """
        data = {}
        for key, meter in self._meters.items():
            if key.startswith('train'):
                key = key[len('train') + 1:]
            elif key.startswith('eval'):
                key = key[len('eval') + 1:]
            elif key.startswith('inf'):
                key = key[len('inf') + 1:]
            key = key.replace('/', '_')
            data[key] = meter.value()
        return data

    def _remove_old_entries(self, data: Metrics) -> None:
        """Rewrite the existing CSV, keeping only rows logged before ``data``.

        When resuming into an existing log directory, rows whose ``episode`` is
        at or beyond ``data['episode']`` are dropped. This keeps the file free
        of stale entries. Filtering only applies when both ``data`` and a row
        carry an ``episode`` value; otherwise the row is kept. Groups without an
        episode column therefore simply keep their history instead of crashing.

        The file is rewritten with the columns of ``data``. Old columns that are
        not part of ``data`` are dropped with a warning. Without this, the
        rewrite would raise halfway through, after the file had already been
        truncated.
        """
        fieldnames = sorted(data.keys())
        threshold = data.get('episode')
        rows: tp.List[tp.Dict[str, str]] = []
        # The csv module requires newline='' on files it reads/writes.
        with self._csv_file_name.open('r', newline='') as f:
            reader = csv.DictReader(f)
            old_fields = reader.fieldnames or []
            for row in reader:
                episode = row.get('episode')
                if (threshold is not None and episode not in (None, '')
                        and float(episode) >= threshold):
                    break
                rows.append(row)
        dropped = sorted(set(old_fields) - set(fieldnames))
        if dropped:
            pylogger.warning("Dropping columns %s from %s: not in current metrics",
                             dropped, self._csv_file_name)
        with self._csv_file_name.open('w', newline='') as f:
            writer = csv.DictWriter(f,
                                    fieldnames=fieldnames,
                                    restval=0.0,
                                    extrasaction='ignore')
            writer.writeheader()
            writer.writerows(rows)

    def _dump_to_csv(self, data: Metrics) -> None:
        """Append ``data`` as one CSV row, creating the writer on first use.

        Raises:
            ValueError: if ``data`` has keys that are not columns of the file
                (default ``extrasaction='raise'`` of :class:`csv.DictWriter`).
        """
        if self._csv_writer is None:
            should_write_header = True
            if self._csv_file_name.exists():
                # Resuming: trim stale rows; the rewrite already emits a header.
                self._remove_old_entries(data)
                should_write_header = False

            # Kept open for the object's lifetime and flushed after every row.
            self._csv_file = self._csv_file_name.open('a', newline='')
            self._csv_writer = csv.DictWriter(self._csv_file,
                                              fieldnames=sorted(data.keys()),
                                              restval=0.0)
            if should_write_header:
                self._csv_writer.writeheader()
        if self._csv_writer is None or self._csv_file is None:
            raise RuntimeError("CSV writer and file should have been instantiated")

        self._csv_writer.writerow(data)
        self._csv_file.flush()

    @staticmethod
    def _format(key: str, value: float, ty: str) -> str:
        """Render ``key: value`` according to ``ty`` ('int', 'float' or 'time').

        Raises:
            ValueError: for any other ``ty``.
        """
        if ty == 'int':
            value = int(value)
            return f'{key}: {value}'
        elif ty == 'float':
            return f'{key}: {value:.04f}'
        elif ty == 'time':
            value_ = str(datetime.timedelta(seconds=int(value)))
            return f'{key}: {value_}'
        raise ValueError(f'invalid format type: {ty}')

    def _dump_to_console(self, data: Metrics, prefix: str) -> None:
        """Print one colored, ``|``-separated summary line using ``self._formating``."""
        prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
        pieces = [f'| {prefix: <14}']
        for key, disp_key, ty in self._formating:
            value = data.get(key, 0)
            pieces.append(self._format(disp_key, value, ty))
        print(' | '.join(pieces))

    @staticmethod
    def _dump_to_wandb(data: Metrics, step) -> None:
        """Send ``data`` to wandb.

        ``step`` is currently unused: ``wandb.log`` is called without it.
        """
        wandb.log(data)

    def dump(self, step: int, prefix: str) -> None:
        """Flush averaged metrics for ``step`` and reset the meters.

        Does nothing if nothing was logged since the last dump. The ``frame``
        column is set to ``step``, and wandb keys are prefixed with
        ``prefix + '/'``.
        """
        if not self._meters:
            return
        data = self._prime_meters()
        data['frame'] = step
        if self.use_wandb:
            wandb_data = {prefix + '/' + key: val for key, val in data.items()}
            self._dump_to_wandb(data=wandb_data, step=step)
        self._dump_to_csv(data)
        # self._dump_to_console(data, prefix)
        self._meters.clear()

class InferenceLogger:
    """Log ``inf/*`` scalars to ``inf.csv`` and optionally TensorBoard, wandb and HipLog.

    Keeps its own step counter, which is advanced by :meth:`log_metrics` and
    :meth:`log_and_dump_ctx`.
    """

    def __init__(self, log_dir: Path, use_tb: bool, use_wandb: bool, use_hiplog: bool) -> None:
        self._log_dir = log_dir
        # COMMON_INF_FORMAT lists the obj/constraints/lamb columns this logger records.
        self._inf_mg = MetersGroup(log_dir / 'inf.csv',
                                   formating=COMMON_INF_FORMAT,
                                   use_wandb=use_wandb)
        self._sw: tp.Optional[SummaryWriter] = None
        self.use_hiplog = use_hiplog
        if use_hiplog:
            self.hiplog = HipLog(log_dir / "hip.log")
        if use_tb:
            self._sw = SummaryWriter(str(log_dir / 'tb'))
        self.use_wandb = use_wandb
        self.step = 0

    def _try_sw_log(self, key, value, step: tp.Optional[int] = None) -> None:
        """Write to TensorBoard if enabled, at ``step`` (default: ``self.step``)."""
        if self._sw is not None:
            self._sw.add_scalar(key, value, self.step if step is None else step)

    def log(self, key: str, value: tp.Union[float, torch.Tensor],
            step: tp.Optional[int] = None) -> None:
        """Record one ``inf/...`` scalar.

        Args:
            key: metric name, must start with ``'inf'``.
            value: a float or a single-element tensor.
            step: TensorBoard step. Defaults to the internal counter. Accepted
                so that :class:`LogAndDumpCtx` (which passes a step) works.
        """
        assert key.startswith('inf')
        if isinstance(value, torch.Tensor):
            value = value.item()
        self._try_sw_log(key, value, step)
        self._inf_mg.log(key, value)
        if self.use_hiplog:
            self.hiplog(**{key: value})

    def log_metrics(self, metrics: tp.Dict[str, float]) -> None:
        """Log every entry of ``metrics`` under ``inf/`` and advance the step counter."""
        for key, value in metrics.items():
            self.log(f'inf/{key}', value)
        self.step += 1

    def dump(self, step: tp.Optional[int] = None, ty: tp.Optional[str] = None) -> None:
        """Flush averaged inference metrics.

        Args:
            step: value for the ``frame`` column. Defaults to the internal counter.
            ty: optional group filter, as passed by :class:`LogAndDumpCtx`.
                Anything other than ``None`` or ``'inf'`` dumps nothing.
        """
        if ty is not None and ty != 'inf':
            return
        self._inf_mg.dump(self.step if step is None else step, 'inf')

    def log_and_dump_ctx(self, ty: str) -> "LogAndDumpCtx":
        """Return a context bound to the current step, then advance the counter."""
        logdump = LogAndDumpCtx(self, self.step, ty)
        self.step += 1
        return logdump
    
class ActorLogger:
    """Log ``actor/*`` scalars to ``actor.csv`` and optionally TensorBoard, wandb and HipLog.

    Keeps its own step counter, which is advanced by :meth:`log_metrics` and
    :meth:`log_and_dump_ctx`.
    """

    def __init__(self, log_dir: Path, use_tb: bool, use_wandb: bool, use_hiplog: bool) -> None:
        self._log_dir = log_dir
        self._actor_mg = MetersGroup(log_dir / 'actor.csv',
                                     formating=COMMON_EVAL_FORMAT,
                                     use_wandb=use_wandb)
        self._sw: tp.Optional[SummaryWriter] = None
        self.use_hiplog = use_hiplog
        if use_hiplog:
            self.hiplog = HipLog(log_dir / "hip.log")
        if use_tb:
            self._sw = SummaryWriter(str(log_dir / 'tb'))
        self.use_wandb = use_wandb
        self.step = 0

    def _try_sw_log(self, key, value, step: tp.Optional[int] = None) -> None:
        """Write to TensorBoard if enabled, at ``step`` (default: ``self.step``)."""
        if self._sw is not None:
            self._sw.add_scalar(key, value, self.step if step is None else step)

    def log(self, key: str, value: tp.Union[float, torch.Tensor],
            step: tp.Optional[int] = None) -> None:
        """Record one ``actor/...`` scalar.

        Args:
            key: metric name, must start with ``'actor'``.
            value: a float or a single-element tensor.
            step: TensorBoard step. Defaults to the internal counter. Accepted
                so that :class:`LogAndDumpCtx` (which passes a step) works.
        """
        assert key.startswith('actor')
        if isinstance(value, torch.Tensor):
            value = value.item()
        self._try_sw_log(key, value, step)
        self._actor_mg.log(key, value)
        if self.use_hiplog:
            self.hiplog(**{key: value})

    def log_metrics(self, metrics: tp.Dict[str, float]) -> None:
        """Log every entry of ``metrics`` under ``actor/`` and advance the step counter."""
        for key, value in metrics.items():
            self.log(f'actor/{key}', value)
        self.step += 1

    def dump(self, step: tp.Optional[int] = None, ty: tp.Optional[str] = None) -> None:
        """Flush averaged actor metrics.

        Args:
            step: value for the ``frame`` column. Defaults to the internal counter.
            ty: optional group filter, as passed by :class:`LogAndDumpCtx`.
                Anything other than ``None`` or ``'actor'`` dumps nothing.
        """
        if ty is not None and ty != 'actor':
            return
        self._actor_mg.dump(self.step if step is None else step, 'actor')

    def log_and_dump_ctx(self, ty: str) -> "LogAndDumpCtx":
        """Return a context bound to the current step, then advance the counter."""
        logdump = LogAndDumpCtx(self, self.step, ty)
        self.step += 1
        return logdump

class Logger:
    """Log ``train/*`` and ``eval/*`` scalars to CSV, and optionally TensorBoard, wandb and HipLog.

    Training and evaluation metrics are kept in separate meter groups, backed
    by ``train.csv`` and ``eval.csv`` inside ``log_dir``.
    """

    def __init__(self, log_dir: Path, use_tb: bool, use_wandb: bool, use_hiplog: bool) -> None:
        self._log_dir = log_dir

        self._train_mg = MetersGroup(log_dir / 'train.csv',
                                     formating=COMMON_TRAIN_FORMAT,
                                     use_wandb=use_wandb)
        self._eval_mg = MetersGroup(log_dir / 'eval.csv',
                                    formating=COMMON_EVAL_FORMAT,
                                    use_wandb=use_wandb)
        # NOTE: this logger has no ``inf`` meter group; inference metrics are
        # written to ``inf.csv`` by InferenceLogger instead.
        self._sw: tp.Optional[SummaryWriter] = None
        self.use_hiplog = use_hiplog
        if use_hiplog:
            self.hiplog = HipLog(log_dir / "hip.log")
        if use_tb:
            self._sw = SummaryWriter(str(log_dir / 'tb'))
        self.use_wandb = use_wandb

    def _try_sw_log(self, key, value, step) -> None:
        """Write to TensorBoard if it was enabled at construction."""
        if self._sw is not None:
            self._sw.add_scalar(key, value, step)

    def log(self, key: str, value: tp.Union[float, torch.Tensor], step: int) -> None:
        """Record one scalar; ``key`` must start with ``'train'`` or ``'eval'``."""
        assert key.startswith('train') or key.startswith('eval')
        if isinstance(value, torch.Tensor):
            value = value.item()
        self._try_sw_log(key, value, step)
        mg = self._train_mg if key.startswith('train') else self._eval_mg
        mg.log(key, value)
        if self.use_hiplog:
            self.hiplog(**{key: value})

    def log_metrics(self, metrics: tp.Dict[str, float], step: int, ty: str) -> None:
        """Log every entry of ``metrics`` as ``f'{ty}/{key}'`` at ``step``."""
        for key, value in metrics.items():
            self.log(f'{ty}/{key}', value, step)

    def dump(self, step, ty=None) -> None:
        """Flush averaged metrics for ``step``.

        Args:
            step: value recorded in the ``frame`` column.
            ty: ``'eval'`` or ``'train'`` to dump a single group. ``None`` dumps
                both (eval first). Any other value dumps nothing.

        A ``ValueError`` from one group (e.g. a metric key that is not a column
        of an existing CSV) is logged as a warning. It does not prevent the
        other group from being dumped.
        """
        for name, mg in (('eval', self._eval_mg), ('train', self._train_mg)):
            if ty is not None and ty != name:
                continue
            try:
                mg.dump(step, name)
            except ValueError as e:
                pylogger.warning("Could not dump metrics: %s", e)

    def log_and_dump_ctx(self, step: int, ty: str) -> "LogAndDumpCtx":
        """Return a context that logs under ``ty`` at ``step`` and dumps on exit."""
        return LogAndDumpCtx(self, step, ty)


class LogAndDumpCtx:
    """Context manager that binds a metric prefix and a step, and dumps on exit.

    Inside the ``with`` block the object is called as ``ctx(key, value)``.
    Each call logs ``f'{ty}/{key}'`` at the bound step through the wrapped
    logger. Leaving the block calls ``logger.dump(step, ty)``. Exceptions
    raised inside the block are not suppressed.
    """

    def __init__(self, logger: Logger, step: int, ty: str) -> None:
        self._ty = ty
        self._step = step
        self._logger = logger

    def __enter__(self) -> "LogAndDumpCtx":
        return self

    def __call__(self, key: str, value: float) -> None:
        qualified_key = f'{self._ty}/{key}'
        self._logger.log(qualified_key, value, self._step)

    def __exit__(self, *exc_info: tp.Any) -> None:
        # Runs even if the block raised; returning None lets the exception propagate.
        logger, step, ty = self._logger, self._step, self._ty
        logger.dump(step, ty)
