import os

os.environ["OMP_NUM_THREADS"] = "1"

import subprocess
import tempfile
import time
import shutil
import random
import yappi
import hydra
import json
import torch.multiprocessing as mp
import numpy as np
import matplotlib.pyplot as plt
# from gametrackr.core import TensorBoardLogger, get_class
# from gametrackr.dbs.ondisk.python_object import DiskPythonObjectDB
# from gametrackr.dbs.pytorch import PytorchOnDiskEpisodesDB
# from omegaconf import DictConfig
# from omegaconf import open_dict

# from scripts.utils.application import start_server_with_model_idx
# from scripts.utils.godot_evaluation import single_evaluation

# import tempfile
# import hydra
# import numpy as np
import torch
import uvicorn
import io
import base64
# import torch.multiprocessing as mp
# import time
# import random
# import gym
# import logging
import copy
from PIL import Image
from pydantic import BaseModel
from fastapi import FastAPI, APIRouter, BackgroundTasks, WebSocket, WebSocketDisconnect
from omegaconf import DictConfig, open_dict, ListConfig
from torchvision import transforms

# from sgcrl.utils.evaluation import evaluation_loop
from sgcrl.utils.logger import TensorBoardLogger
from sgcrl.utils.imports import get_class
from sgcrl.data.dbs.on_disk import DiskPythonObjectDB, PytorchOnDiskEpisodesDB
# from sgcrl.utils.d4rl import build_d4rl_dataset_v2


# torchvision transform used by event_to_pytorch to turn decoded PIL images
# (base64 strings stored under keys starting with "image") into tensors.
to_tensor = transforms.ToTensor()

def convert_dict_to_tensor(event):
    """Convert a serialized ``{"dtype", "data", "shape"}`` dict into a torch tensor.

    Args:
        event: dict with keys ``"dtype"`` (one of ``"int"``, ``"float"``,
            ``"double"``, ``"bool"``), ``"data"`` (anything ``torch.tensor``
            accepts) and ``"shape"`` (sequence of ints).

    Returns:
        A ``long`` (for ``"int"``), ``float32`` (for ``"float"`` and
        ``"double"``) or ``bool`` tensor. It is reshaped to ``shape`` only when
        ``shape`` has more than one dimension; otherwise it keeps the layout
        produced by ``torch.tensor``.

    Raises:
        ValueError: if ``dtype`` is not one of the supported names.
    """
    dtype_name = event["dtype"]
    if dtype_name == "int":
        t = torch.tensor(event["data"]).long()
    elif dtype_name in ("float", "double"):
        # "double" is converted to float32 as well.
        t = torch.tensor(event["data"]).float()
    elif dtype_name == "bool":
        t = torch.tensor(event["data"]).bool()
    else:
        # Without this branch ``t`` would be unbound and the failure would
        # surface as an opaque UnboundLocalError.
        raise ValueError(f"Unsupported tensor dtype {dtype_name!r}")
    shape = list(event["shape"])
    if len(shape) > 1:
        t = t.reshape(*shape)
    return t

def event_to_pytorch(event):
    """Flatten a (possibly nested) event dict into a flat dict of torch tensors.

    Nested dicts are flattened with ``"/"``-joined keys (``{"a": {"b": 1}}``
    becomes ``{"a/b": ...}``). The exception is dicts carrying ``"shape"``,
    ``"data"`` and ``"dtype"``, which are decoded with
    :func:`convert_dict_to_tensor`.

    Leaf conversion rules:
      * ``torch.Tensor`` -> kept as-is.
      * ``int`` (this includes ``bool``) -> ``long`` tensor.
      * ``float`` -> ``float`` tensor.
      * ``np.ndarray`` -> tensor via ``torch.from_numpy``.
      * ``str`` whose key starts with ``"image"`` -> base64-decoded image
        converted with ``to_tensor``.
      * anything else (other strings, ``None``, lists, ...) -> dropped.
    """
    results = {}
    for k, v in event.items():
        if isinstance(v, dict):
            if "shape" in v and "data" in v and "dtype" in v:
                results[k] = convert_dict_to_tensor(v)
            else:
                for sub_key, sub_value in event_to_pytorch(v).items():
                    results[k + "/" + sub_key] = sub_value
        elif isinstance(v, torch.Tensor):
            # Already a tensor (e.g. events received through torch.load on the
            # websocket): keep it instead of silently discarding it.
            results[k] = v
        elif isinstance(v, int):
            results[k] = torch.tensor(v).long()
        elif isinstance(v, float):
            results[k] = torch.tensor(v).float()
        elif isinstance(v, np.ndarray):
            results[k] = torch.from_numpy(v)
        elif isinstance(v, str) and k.startswith("image"):
            decoded_bytes = base64.b64decode(v)
            # The context manager releases the PIL image once it is converted.
            with Image.open(io.BytesIO(decoded_bytes)) as image:
                results[k] = to_tensor(image)
    return results

# Routes below are registered on this router; start_server_with_model_idx()
# mounts it on the FastAPI app.
router = APIRouter()
# Set by start_server_with_model_idx() to the GodotApplication handling requests.
application_manager = None
# If True, the /event websocket exchanges torch.save/torch.load bytes instead of JSON text.
websocket_pytorch = False
# Converter applied by the /event websocket to incoming events (defaults to event_to_pytorch).
convert_event_to_pytorch = event_to_pytorch
# If True, route handlers print diagnostic messages.
debug = False

class GameStartEvent(BaseModel):
    # Request body of POST /start_game. (This class used to be defined twice
    # with identical bodies; the duplicate was removed.)
    game_name: str
    infos: dict
    # start_game overwrites this on a copy of the request with time.time() as a string.
    received_timestamp: str = None

class GameStartEventID(BaseModel):
    # Game-start payload that additionally carries an application id.
    # NOTE(review): no route visible in this file uses this model.
    game_name: str
    infos: dict
    received_timestamp: str = None
    application_id: str

class GameEndEvent(BaseModel):
    # Request body of POST /end_game (the handler does not read its fields).
    application_id: str
    received_timestamp: str = None

class SessionStartEvent(BaseModel):
    # Request body of POST /start_session; its fields are forwarded to
    # application_manager.session_start(application_id, session_name, infos).
    application_id: str
    session_name: str
    infos: dict
    received_timestamp: str = None

class SessionStartEventID(BaseModel):
    # NOTE(review): not used by any route visible in this file. Its fields look
    # like the session-start request plus the ids/configuration returned by
    # GodotApplication.session_start; confirm the intended use.
    application_id: str
    session_id: str
    session_name: str
    infos: dict
    received_timestamp: str = None
    configuration: dict

class SessionEndEvent(BaseModel):
    # Request body of POST /end_session; identifies the session to finalize.
    application_id: str
    session_id: str
    received_timestamp: str = None

class TrackingEvent(BaseModel):
    # NOTE(review): the /event websocket reads the same four keys from raw
    # JSON/torch payloads but does not validate them through this model.
    application_id: str
    session_id: str
    event_name: str
    event: dict

class ServingEvent(BaseModel):
    # NOTE(review): not referenced by any route visible in this file.
    application_id: str
    session_id: str
    decorator_name: str
    inputs: dict
    inputs: dict

def serialize(event):
    """Encode ``event`` as a JSON string (counterpart of ``deserialize``)."""
    encoded = json.dumps(event)
    return encoded


def deserialize(msg):
    """Decode the JSON text ``msg`` into Python objects."""
    return json.loads(msg)

@router.post("/test")
async def test():
    # Minimal endpoint: logs a marker and returns a fixed message.
    print('YES')
    # A dict, not the former set literal {"msg", "working"}, which FastAPI
    # would have serialized as an unordered JSON list.
    return {"msg": "working"}

@router.post("/start_game")
async def start_game(event: GameStartEvent):
    # Register a new game with the application manager and return its reply
    # (GodotApplication replies with {"application_id": ...}).
    received_at = str(time.time())

    # Stamp a copy so the validated request object itself is left untouched.
    stamped = copy.deepcopy(event)
    stamped.received_timestamp = received_at
    reply = application_manager.application_start(stamped.game_name, stamped.infos)

    if debug:
        print("START NEW GAME: ", reply)
    return reply


@router.post("/end_game")
async def end_game(event: GameEndEvent):
    # The payload is accepted but not used: this handler only acknowledges.
    acknowledgement = {"msg": "game_ended"}
    return acknowledgement


@router.post("/end_session")
async def end_session(event: SessionEndEvent, background_tasks: BackgroundTasks):
    # Finalizing a session (for GodotApplication: writing its episodes) is
    # deferred to a FastAPI background task instead of running inline.
    if debug:
        print("END SESSION: ", event)
    finalize = application_manager.session_end
    background_tasks.add_task(finalize, event.application_id, event.session_id)
    return {"msg": "session_ended"}


@router.post("/start_session")
async def start_session(event: SessionStartEvent):
    # Open a session and return the manager's reply (GodotApplication returns
    # the application/session ids and the session configuration).
    reply = application_manager.session_start(
        event.application_id,
        event.session_name,
        event.infos,
    )
    if debug:
        print("STARTING SESSION ", reply)
    return reply


@router.websocket("/event")
async def data(websocket: WebSocket):
    """Event loop for one websocket client.

    Every incoming message holds ``application_id``, ``session_id``,
    ``event_name`` and ``event``. The event is converted with
    ``convert_event_to_pytorch`` and handed to ``application_manager._serve``.

    If something is served (e.g. an action), it is sent back and also stored in
    the frame under ``event_name``. Otherwise an empty reply (``{}``) is sent.
    Either way, one frame is pushed with ``_push_frame`` per handled message.
    """
    await websocket.accept()
    if debug:
        print("[WebSocket] /event Connection accepted")
    try:
        while True:
            if websocket_pytorch:
                raw = await websocket.receive_bytes()
                # SECURITY NOTE(review): torch.load unpickles client-supplied
                # bytes, which can execute arbitrary code. Only enable
                # websocket_pytorch for trusted clients, or pass
                # weights_only=True if the installed torch supports it.
                message = torch.load(io.BytesIO(raw))
            else:
                raw = await websocket.receive_text()
                message = deserialize(raw)
            application_id = message.pop("application_id")
            session_id = message.pop("session_id")
            name = message["event_name"]
            event = message["event"]
            tevent = convert_event_to_pytorch(event)

            served = application_manager._serve(
                application_id, session_id, name, tevent
            )
            # _serve may return None (nothing served), a bare action dict (no
            # model loaded) or an (action, subgoal_reward) tuple. Unpacking the
            # first two forms directly raised and dropped the connection.
            if isinstance(served, tuple):
                new_event, subgoal_reward = served
            else:
                new_event, subgoal_reward = served, None

            if new_event is not None:
                # Serving: reply with the answer, then record it in the frame.
                if websocket_pytorch:
                    buffer = io.BytesIO()
                    torch.save(new_event, buffer)
                    await websocket.send_bytes(buffer.getvalue())
                else:
                    await websocket.send_text(serialize(new_event))
                event[name] = new_event
                if subgoal_reward is not None:
                    event["subgoal_reward"] = subgoal_reward
                # Same converter as for the incoming event, so all frames are
                # built consistently.
                tevent = convert_event_to_pytorch(event)
                application_manager._push_frame(
                    application_id, session_id, name, tevent
                )
            else:
                application_manager._push_frame(
                    application_id, session_id, name, tevent
                )
                if websocket_pytorch:
                    buffer = io.BytesIO()
                    torch.save({}, buffer)
                    await websocket.send_bytes(buffer.getvalue())
                else:
                    await websocket.send_text("{}")

    except WebSocketDisconnect:
        if debug:
            print("Client disconnected")



def new_background(map_img_path, rotate, x_min, x_max, y_min, y_max, flip_x):
    """Create a new figure and return its axes, optionally drawing a map image.

    When ``map_img_path`` is truthy, the image is loaded with ``plt.imread``.
    It is then rotated with ``np.rot90(img, rotate)`` (only if ``rotate`` is
    non-zero) and flipped along axis 1 (only if ``flip_x``). Finally it is
    drawn with ``extent=[x_min, x_max, y_min, y_max]``.
    """
    _, axes = plt.subplots()
    if not map_img_path:
        return axes

    background = plt.imread(map_img_path)
    if rotate:
        background = np.rot90(background, rotate)
    if flip_x:
        background = np.flip(background, axis=1)
    axes.imshow(background, extent=[x_min, x_max, y_min, y_max])
    return axes

def _plot_start_and_goal(start_position, goal_position, goal_color, edge_color, add_start, first):
    """Draw one episode's goal (star) and, if ``add_start``, its start (dot) on the current axes.

    Components 0 and 2 of each position are used as plot x and y. Labels are
    attached only when ``first`` is True.
    """
    start_label = 'Start Position' if first else None
    goal_label = 'Goal Position' if first else None
    if add_start:
        plt.plot(start_position[0], start_position[2], "o", color=edge_color, label=start_label)
        plt.scatter(goal_position[0], goal_position[2], marker="*", color=goal_color,
                    edgecolors=edge_color, label=goal_label, s=50)
    else:
        plt.scatter(goal_position[0], goal_position[2], marker="*", color=goal_color,
                    label=goal_label, s=50)


def _finish_visu_figure(ax, save_path, tag, plot_idx):
    """Label ``ax``, then save it as ``test_goals_<tag>_<plot_idx>`` (or show it) and close it."""
    ax.set_xlabel('X Position')
    ax.set_ylabel('Y Position')
    ax.set_title('Start and Goal Positions')
    if save_path:
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, "test_goals_" + tag + "_" + str(plot_idx)))
    else:
        plt.show()
    # Close the figure so repeated evaluations do not accumulate open figures.
    plt.close(ax.figure)


def visu_eval(evaluation_db, map_img_path, x_min=-60, x_max=60, y_min=-60, y_max=60, rotate=2, flip_x=True,
              add_traj=False, fail_only=False, add_start=False, save_path=None, tag=""):
    """Plot evaluation goals (and optionally trajectories) over a map background.

    Positions are read from ``sensor/position`` and
    ``sensor/absolute_goal_position``, using components 0 and 2 as plot x and
    y. A goal is drawn green when the last ``sensor/touch_goal`` entry is
    all-true, and red otherwise.

    Args:
        evaluation_db: object exposing ``get_ids()`` and ``db[episode_id]``,
            which returns a mapping of per-step sequences (presumably tensors).
        map_img_path, x_min, x_max, y_min, y_max, rotate, flip_x: forwarded to
            :func:`new_background`.
        add_traj: also draw trajectories, at most 5 per figure.
        fail_only: skip the all-episodes goal overview and draw trajectories of
            failed episodes only.
        add_start: also draw each episode's start position.
        save_path: directory where figures are saved as
            ``test_goals_<tag>_<idx>``; when falsy, figures are shown instead.
        tag: inserted into the saved file names.
    """
    episodes = evaluation_db.get_ids()
    # One distinct color per episode for the overview plot.
    player_colors = plt.cm.tab10(np.linspace(0, 1, len(episodes)))
    ax = new_background(map_img_path, rotate, x_min, x_max, y_min, y_max, flip_x)

    # Overview: goal (and optionally start) of every episode.
    if not fail_only:
        for i, episode_id in enumerate(episodes):
            episode_data = evaluation_db[episode_id]
            g_color = "green" if episode_data["sensor/touch_goal"][-1].all() else "red"
            _plot_start_and_goal(
                episode_data["sensor/position"][0],
                episode_data["sensor/absolute_goal_position"][0],
                g_color, player_colors[i], add_start, first=(i == 0),
            )

    plot_idx = 0
    nb_trajs = 0
    if add_traj:
        colors = ["red", "blue", "green", "orange", "black", "yellow", "pink"]
        for i, episode_id in enumerate(episodes):
            episode_data = evaluation_db[episode_id]
            g_color = "green" if episode_data["sensor/touch_goal"][-1].all() else "red"
            if fail_only and g_color == "green":
                continue
            color = colors[nb_trajs]  # nb_trajs < 5 <= len(colors)
            positions = episode_data["sensor/position"]
            _plot_start_and_goal(
                positions[0],
                episode_data["sensor/absolute_goal_position"][0],
                g_color, color, add_start, first=(i == 0),
            )
            ax.plot([pos[0] for pos in positions], [pos[2] for pos in positions],
                    alpha=0.5, color=color)
            nb_trajs += 1
            if nb_trajs >= 5:
                # Save and start a new plot to avoid cluttered visualizations.
                _finish_visu_figure(ax, save_path, tag, plot_idx)
                ax = new_background(map_img_path, rotate, x_min, x_max, y_min, y_max, flip_x)
                plot_idx += 1
                nb_trajs = 0

    if plot_idx > 0 and nb_trajs == 0:
        # The last flush left an empty background behind: discard it rather
        # than saving a map with nothing on it.
        plt.close(ax.figure)
    else:
        _finish_visu_figure(ax, save_path, tag, plot_idx)


def _wait_for_episodes(db, cfg):
    """Return ``db.get_ids()``, polling every second until ``cfg.n_episodes`` ids exist.

    No waiting happens when ``cfg.n_episodes`` is None.
    """
    ids = db.get_ids()
    while cfg.n_episodes is not None and len(ids) < cfg.n_episodes:
        time.sleep(1)
        ids = db.get_ids()
        print(f'Size of generated episodes db: {len(ids)}')
    return ids


def _evaluate_episodes(db, ids, logger, epoch, keep_episodes):
    """Pop the episodes ``ids`` from ``db``, log metrics to ``logger`` at step ``epoch``.

    Returns:
        ``(metrics, episodes)``. ``episodes`` holds deep copies of the popped
        episodes when ``keep_episodes`` is True, and is empty otherwise.
    """
    lengths = []
    n_ok = 0
    rewards = []
    subgoal_rewards = []
    has_reward = False
    has_subgoal_reward = False
    episodes = []
    metrics = {'model_idx': epoch}

    for _id in ids:
        episode = db.pop(_id)
        if keep_episodes:
            episodes.append(copy.deepcopy(episode))
        # Episode length = leading dimension of the first stored entry.
        first_key = next(iter(episode))
        lengths.append(episode[first_key].size()[0])
        # Success = every component of the last touch_goal entry is true.
        if episode["sensor/touch_goal"][-1].all():
            n_ok += 1
        if "reward" in episode:
            has_reward = True
            rewards.append(episode["reward"].sum().item())
        if 'subgoal_reward' in episode:
            has_subgoal_reward = True
            # .item(): store plain numbers, like rewards, rather than 0-d tensors.
            subgoal_rewards.append(episode['subgoal_reward'].max().item())

    # With no episodes (e.g. cfg.n_episodes is None and the DB is empty) the
    # rates are undefined: report NaN instead of dividing by zero.
    success_rate = n_ok / len(lengths) if lengths else float("nan")
    avg_length = np.mean(lengths) if lengths else float("nan")

    logger.add_scalar("avg_length", avg_length, epoch)
    if has_reward:
        avg_reward = np.mean(rewards)
        logger.add_scalar("avg_reward", avg_reward, epoch)
        metrics["avg_reward"] = avg_reward
        print("success_rate", success_rate)
    if has_subgoal_reward:
        subgoal_rate = np.mean(subgoal_rewards)
        logger.add_scalar("success_rate_subgoals", subgoal_rate, epoch)
        metrics['success_rate_subgoals'] = subgoal_rate
        print(f'Success rate {success_rate:.3f}, Success rate subgoals {subgoal_rate:.3f}')
    logger.add_scalar("success_rate", success_rate, epoch)
    metrics["success_rate"] = success_rate
    logger.add_scalar("on_n_episodes", len(ids), epoch)
    metrics["on_n_episodes"] = len(ids)
    return metrics, episodes


def single_evaluation(db, cfg, logger, epoch, *, settle_delay=5):
    """Evaluate the episodes in ``db`` and return the metrics dict.

    The function:

    1. Waits until ``cfg.n_episodes`` episodes exist (if not None).
    2. Pauses ``settle_delay`` seconds.
    3. Reads every episode with ``db.pop``.
    4. Logs ``avg_length``, ``avg_reward`` (if rewards exist),
       ``success_rate_subgoals`` (if subgoal rewards exist), ``success_rate``
       and ``on_n_episodes`` at step ``epoch``.

    Args:
        settle_delay: seconds to sleep before reading the episodes
            (defaults to the historical 5 s).
    """
    ids = _wait_for_episodes(db, cfg)
    # Visualizations (visu_eval) would have to run here, before the DB is emptied.
    # Fixed pause before popping; presumably lets in-flight episode writes finish.
    time.sleep(settle_delay)
    metrics, _ = _evaluate_episodes(db, ids, logger, epoch, keep_episodes=False)
    return metrics


def no_delete_single_evaluation(db, cfg, logger, epoch, *, settle_delay=5):
    """Like :func:`single_evaluation`, but also return deep copies of the episodes.

    Returns:
        ``(metrics, episodes)``.

    NOTE(review): despite the name, episodes are still read with ``db.pop``.
    Only the returned copies survive; confirm whether ``pop`` deletes from the DB.
    """
    ids = _wait_for_episodes(db, cfg)
    time.sleep(settle_delay)
    return _evaluate_episodes(db, ids, logger, epoch, keep_episodes=True)


def convert(d):
    """Recursively turn OmegaConf containers (and plain lists/dicts) into builtin lists/dicts.

    Values that are neither list-like nor dict-like are returned unchanged.
    """
    if isinstance(d, (ListConfig, list)):
        return list(map(convert, d))
    if isinstance(d, (DictConfig, dict)):
        return {key: convert(value) for key, value in d.items()}
    return d
    
class Application:
    """Base structure for a Python object that interacts with a drainC client.

    Subclasses implement the lifecycle hooks below. The base versions only
    raise ``NotImplementedError``.
    """

    def __init__(self):
        pass

    def application_start(self, name, infos):
        """Register a new application; the return value is sent back to the client."""
        raise NotImplementedError

    def session_start(self, application_id, session_name, session_infos):
        """Open a session in ``application_id``; the return value is sent back to the client."""
        raise NotImplementedError

    def session_end(self, application_id, session_id):
        """Close the given session."""
        raise NotImplementedError

    def application_end(self, application_id=None):
        """Shut down an application.

        ``application_id`` is optional so this signature matches
        ``GodotApplication.application_end(application_id)`` while remaining
        callable without arguments, as before.
        """
        raise NotImplementedError


class GodotApplication(Application):
    """Application that serves model actions to Godot clients and records episodes.

    Args:
        episodes_db_per_session_name: mapping session_name -> episodes DB
            exposing ``write(application_id, session_id, event_name, episode)``.
            Sessions whose name is absent are served but not saved.
        model_db: DB exposing ``size("model")``, ``get_last("model")`` and
            ``get("model", idx)``.
        sessions_configurations: mapping session_name -> configuration
            returned by :meth:`session_start`. Actions are served only when
            ``["decorators"]["action"]["type"]`` starts with ``"serve"``.
        bot_arguments: extra keyword arguments forwarded to the model's
            ``_action``. ``None`` (the default) means no extra arguments.
        model_idx: index of the model to load from ``model_db``; ``None``
            loads the latest one.
    """

    def __init__(
            self,
            episodes_db_per_session_name,
            model_db,
            sessions_configurations,
            bot_arguments=None,
            model_idx=None,
    ):
        super().__init__()

        self._model_db = model_db
        self._model_idx = model_idx
        # We have a DB per session_name (player_0, player_1, ...) to say where we
        # want to save episodes. It is used to distinguish training from evaluation.
        self._episodes_db_per_session_name = episodes_db_per_session_name

        self._current_idx = None

        # State keyed by (application_id, session_id, event_name): event
        # counters, loaded models and buffered frames.
        self._n_events = {}
        self._models = {}
        self._frames = {}

        # The configuration of the sessions (one per session_name).
        self._configurations = sessions_configurations
        # (application_id, session_id) -> session_name
        self._session_names = {}
        # None default avoids a dict shared between instances; None and {} both
        # mean "no extra arguments".
        self._bot_arguments = {} if bot_arguments is None else bot_arguments

    def application_start(self, name, infos):
        """Return a fresh, time-based application id."""
        _id = "AID_" + str(time.time())
        return {"application_id": _id}

    def session_start(self, application_id, session_name, session_infos):
        """Register a new session and return its ids and configuration."""
        session_id = "SID_" + str(time.time())
        self._session_names[(application_id, session_id)] = session_name
        configuration = self._configurations[session_name]
        return {
            "application_id": application_id,
            "session_id": session_id,
            "configuration": configuration,
        }

    def _save_episode(self, a, s, e, frame):
        """Stack the buffered frames of key (a, s, e) into one episode and write it.

        Each key of the first frame becomes a tensor of shape ``(T, ...)``. The
        episode goes to the DB registered for the session's name, if any. The
        ``frame`` argument is unused (kept for call compatibility).
        """
        assert e == "action"
        session_name = self._session_names[(a, s)]
        if session_name in self._episodes_db_per_session_name:
            frames = self._frames[(a, s, e)]
            episode = {
                k: torch.cat([f[k].unsqueeze(0) for f in frames], dim=0)
                for k in frames[0]
            }
            self._episodes_db_per_session_name[session_name].write(a, s, e, episode)
        # NOTE(review): profiling leftover. This stops yappi (started in
        # start_server_with_model_idx) and rewrites func.prof on every call.
        yappi.stop()
        yappi.get_func_stats().save('func.prof', type='pstat')

    def session_end(self, application_id, session_id):
        """Save and drop every buffered episode belonging to the given session."""
        keys = list(self._frames.keys())
        for a, s, e in keys:
            if (a == application_id) and (s == session_id):
                self._save_episode(a, s, e, self._frames[(a, s, e)])
                del self._frames[(a, s, e)]
                if (a, s, e) in self._n_events:
                    del self._n_events[(a, s, e)]
                if (a, s, e) in self._models:
                    del self._models[(a, s, e)]

    def _push_frame(self, application_id, session_id, episode_id, frame):
        """Append ``frame`` to the buffer of (application_id, session_id, episode_id)."""
        self._frames.setdefault((application_id, session_id, episode_id), []).append(frame)

    def application_end(self, application_id=None):
        """Nothing to clean up; ``application_id`` is optional, as in the base class."""
        pass

    @staticmethod
    def _default_action():
        """Fallback action with all movement and rotation set to zero."""
        return {
            "move_right": 0.0,
            "move_left": 0.0,
            "move_forwards": 0.0,
            "move_backwards": 0.0,
            "run": True,
            "jump": False,
            "rotation": 0.0,
        }

    def _serve(self, application_id, session_id, decorator_name, event):
        """Compute the answer to an incoming event, if any.

        Only ``"action"`` events are answered, and only for sessions whose
        configured decorator type starts with ``"serve"``. On the first such
        event of a (application, session, decorator) key, the model is loaded
        from ``model_db`` (the latest one, or ``model_idx``).

        Returns:
            One of:

            * ``None`` when nothing is served;
            * the bare fallback action dict when no model is loaded;
            * ``(action, subgoal_flag)`` otherwise, with
              ``subgoal_flag = int(model.current_phase == 2)``.
        """
        if decorator_name != "action":
            return None
        session_name = self._session_names[(application_id, session_id)]
        decorator_type = self._configurations[session_name]["decorators"][decorator_name]["type"]
        if not decorator_type.startswith("serve"):
            return None

        key = (application_id, session_id, decorator_name)
        if key not in self._n_events:
            # First event for this key: start the counter and try to load a model.
            self._n_events[key] = 0
            if self._model_db.size("model") > 0:
                try:
                    if self._model_idx is None:
                        _model = self._model_db.get_last("model")
                    else:
                        _model = self._model_db.get("model", self._model_idx)
                    _model.reset(seed=453)
                    _model.eval()
                    self._models[key] = _model
                except Exception as e:
                    # Best effort: keep serving the fallback action.
                    print(e)
        else:
            self._n_events[key] += 1

        if key not in self._models:
            return self._default_action()

        model = self._models[key]
        try:
            action = model._action(event, **self._bot_arguments)
        except Exception as e:
            # Previously ``action`` was left unbound here, so the printed error
            # turned into an UnboundLocalError. Fall back to the neutral action.
            print(e)
            action = self._default_action()
        return action, int(model.current_phase == 2)


def start_server_with_model_idx(dbs, application_configuration, model_db, model_idx):
    """Build a GodotApplication from ``application_configuration`` and serve it with uvicorn.

    Args:
        dbs: mapping session_name -> episodes DB (see GodotApplication).
        application_configuration: OmegaConf or plain mapping with
            ``tracking_configuration``, ``host``, ``port`` and, optionally,
            ``bot_arguments``.
        model_db: DB the served models are loaded from.
        model_idx: index of the model to serve; ``None`` serves the latest one.

    Sets the module-level ``application_manager`` used by the route handlers,
    then runs ``uvicorn.run``.
    """
    import logging
    yappi.start()
    global application_manager
    application_configuration = convert(application_configuration)
    app = FastAPI()
    print("======================================================")
    print("Configuration = ", application_configuration["tracking_configuration"])
    print("======================================================")
    ba = {}
    if "bot_arguments" in application_configuration and application_configuration["bot_arguments"] is not None:
        ba = application_configuration["bot_arguments"]
    application_manager = GodotApplication(
        dbs,
        model_db,
        application_configuration["tracking_configuration"],
        bot_arguments=ba,
        model_idx=model_idx,
    )
    app.include_router(router)
    # Removes the "uvicorn" logger from uvicorn's global default logging config,
    # as before. pop() (instead of del) keeps a second call in the same process
    # from raising KeyError.
    uvicorn.config.LOGGING_CONFIG["loggers"].pop("uvicorn", None)
    logging.getLogger("uvicorn").setLevel(logging.CRITICAL)

    uvicorn.run(
        app,
        host=application_configuration["host"],
        port=application_configuration["port"],
        log_config=None,
    )
