import torch
import os
import numpy as np
import random
from tensorboardX import SummaryWriter
from einops import repeat
from contextlib import contextmanager
import time
import matplotlib.pyplot as plt
from comet_ml import Experiment
import shutil
from collections import defaultdict

def seed_np_torch(seed=20010105):
    """Seed Python, NumPy and PyTorch (CPU + CUDA) RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    # Note: changing PYTHONHASHSEED at runtime does not affect hashing in the
    # current interpreter; it is only inherited by child processes.
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    # Some cuDNN kernels stay nondeterministic even with a fixed seed unless
    # deterministic mode is forced and autotuning (benchmark) is disabled.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


class Logger():
    """Log scalars, images, videos and histograms to TensorBoard, optionally mirroring scalars to Comet.

    Each tag has its own step counter, incremented on every ``log`` call, so the
    first value logged for a tag is written at step 1.

    Args:
        run_name: Run name; TensorBoard logs go to ``runs/<run_name>``. Any existing
            directory at that path is deleted first.
        use_comet: When True, scalar metrics are also sent to a Comet experiment.
            The API key is read from the ``COMET_API_KEY`` environment variable.

    Raises:
        ValueError: If ``run_name`` does not resolve to a sub-directory of ``runs/``
            (e.g. ``""``, ``"."`` or ``"../x"``). Without this check the ``rmtree``
            below would wipe every previous run or a directory outside ``runs/``.
    """

    def __init__(self, run_name, use_comet=False) -> None:
        log_dir = f"runs/{run_name}"
        # Refuse anything that would make rmtree delete runs/ itself or escape it.
        if not os.path.normpath(log_dir).startswith("runs" + os.sep):
            raise ValueError(f"run_name must name a sub-directory of runs/, got {run_name!r}")
        # Clear any previous TensorBoard log with the same name.
        # TODO: give a hint for new users that old logs are deleted.
        shutil.rmtree(log_dir, ignore_errors=True)
        self.use_comet = use_comet

        self.writer = SummaryWriter(logdir=log_dir, flush_secs=1)  # TensorBoard writer

        if self.use_comet:
            # HKRL_LOCAL_DEVICE marks a local development machine; without it,
            # assume a remote (headless) training server.
            remote_train = "HKRL_LOCAL_DEVICE" not in os.environ

            self.experiment = Experiment(
                # Never commit secrets to source; take the key from the environment.
                api_key=os.environ.get("COMET_API_KEY"),
                project_name="HKRL" if remote_train else "HKRL-dev",
                workspace="wp",
                auto_metric_logging=True,
                log_git_metadata=False,
                log_git_patch=False,
                log_env_details=False,
            )
            self.experiment.set_name(run_name)

        self.tag_step = defaultdict(int)  # tag -> number of log() calls so far

    def log(self, tag: str, value):
        """Write ``value`` under ``tag``; the writer method is chosen by substrings of ``tag``.

        ``"video"`` -> ``add_video`` (15 fps), ``"images"`` -> ``add_images``,
        ``"hist"`` -> ``add_histogram``, anything else -> ``add_scalar``.
        Passing ``value=None`` skips the write but still advances the tag's step.

        Raises:
            NotImplementedError: For non-scalar tags when Comet is enabled (raised
                after the TensorBoard write).
        """
        self.tag_step[tag] += 1
        step = self.tag_step[tag]
        if value is None:  # None marks an illegal value: skip logging, but the step still advances
            return
        if "video" in tag:
            self.writer.add_video(tag, value, step, fps=15)
            if self.use_comet:
                raise NotImplementedError("not implemented yet: Comet video logging")
        elif "images" in tag:
            self.writer.add_images(tag, value, step)
            if self.use_comet:
                raise NotImplementedError("not implemented yet: Comet image logging")
        elif "hist" in tag:
            self.writer.add_histogram(tag, value, step)
            if self.use_comet:
                raise NotImplementedError("not implemented yet: Comet histogram logging")
        else:
            self.writer.add_scalar(tag, value, step)
            if self.use_comet:
                self.experiment.log_metric(tag, value, step=step)

class TensorboardLogger():
    """Minimal TensorBoard logger with an independent, 0-based step counter per tag.

    Args:
        path: Directory handed to ``SummaryWriter`` as ``logdir``.
    """

    def __init__(self, path) -> None:
        self.writer = SummaryWriter(logdir=path, flush_secs=1)
        self.tag_step = {}  # tag -> step used by the most recent log() call (first call is step 0)

    def log(self, tag: str, value):
        """Write ``value`` under ``tag``; the writer method is chosen by substrings of ``tag``.

        ``"video"`` -> ``add_video`` (15 fps), ``"images"`` -> ``add_images``,
        ``"hist"`` -> ``add_histogram``, anything else -> ``add_scalar``.
        Passing ``value=None`` skips the write but still advances the tag's step,
        which matches ``Logger.log`` in this module.
        """
        step = self.tag_step.get(tag, -1) + 1
        self.tag_step[tag] = step
        if value is None:  # skip illegal values without desynchronising the step count
            return
        if "video" in tag:
            self.writer.add_video(tag, value, step, fps=15)
        elif "images" in tag:
            self.writer.add_images(tag, value, step)
        elif "hist" in tag:
            self.writer.add_histogram(tag, value, step)
        else:
            self.writer.add_scalar(tag, value, step)



def show_mask(mask, ax, random_color=False):
    """Overlay ``mask`` on ``ax`` as an RGBA image.

    The mask's last two dimensions give the image height and width. The colour is
    a fixed translucent blue (alpha 0.4), or a random RGB colour with alpha 0.6
    when ``random_color`` is True.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        if random_color
        else np.array([30/255, 144/255, 255/255, 0.4])
    )
    height, width = mask.shape[-2:]
    # Broadcast the (H, W, 1) mask against a (1, 1, 4) colour -> (H, W, 4) image.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
    

def show_points(coords, labels, ax, marker_size=375):
    """Scatter labelled points on ``ax`` as star markers.

    Points with label 1 are drawn in green and points with label 0 in red; the
    green series is plotted first. ``coords[:, 0]`` and ``coords[:, 1]`` are used
    as the x and y values.
    """
    shared_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        chosen = coords[labels == wanted_label]
        ax.scatter(chosen[:, 0], chosen[:, 1], color=colour, **shared_style)
    

def show_box(box, ax):
    """Draw ``box`` = (x0, y0, x1, y1) on ``ax`` as an unfilled green rectangle.

    The rectangle is anchored at (x0, y0) with width x1 - x0 and height y1 - y0.
    """
    x_start, y_start, x_end, y_end = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_start, y_start),
        x_end - x_start,
        y_end - y_start,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)