from dataclasses import asdict
import json
import pathlib
from pathlib import Path
import random
from collections import deque
from collections.abc import Callable
from typing import Any
import uuid

import gymnasium as gym
import numpy as np
import torch as th
from tqdm import tqdm
import sys
import json

from regawa import GNNParams, GroundValue
from regawa.gnn import ActionMode
from regawa.gnn.data import heterostatedata_from_obslist
from regawa.gnn.gnn_agent import (
    AgentConfig,
    GraphAgent,
    heterostatedata_to_tensors,
)
from regawa.model.base_grounded_model import BaseGroundedModel
from regawa.model.base_model import BaseModel
from regawa.model.utils import max_arity
from regawa.rddl import register_env
from regawa.rddl.rddl_utils import rddl_ground_to_tuple
from regawa.rl.util import calc_loss, update
from regawa.wrappers.graph_utils import create_graphs_func, create_obs_dict_func
from regawa.wrappers.remove_false_wrapper import remove_false
from regawa.wrappers.render_utils import create_render_graph
from regawa.wrappers.utils import from_dict_action, object_list

from rddleval.eval_data import EvalEntry
from rddleval.scripts.evaluate import evaluate_instance

# Raw observation from a JSON recording: ground-fluent string -> value.
RecordingObs = dict[str, Any]
# Raw action from a JSON recording: ground-action string -> value.
RecordingAction = dict[str, int]
# One recorded step: a JSON object read through the "state" key
# (a RecordingObs) and the "actions" key (a RecordingAction). Keys are
# strings; a dict cannot be used as a dict key.
RecordingEntry = dict[str, Any]
Recording = list[RecordingEntry]

# Tuple-keyed forms used after converting the string keys.
GroundAction = dict[GroundValue, bool]
GroundObs = dict[GroundValue, Any]

# Indexed action: position 0 is the action-fluent index and position 1 is
# the object index (see from_index_action).
IndexedAction = tuple[int, ...]


@th.inference_mode()
def save_sorted_losses(
    model: BaseModel,
    agent: GraphAgent,
    expert_actions: list[GroundAction],
    indexed_expert_obs: list[GroundObs],
    expert_obs: list[GroundObs],
    device: str = "cpu",
    create_graphs=None,
):
    """Score the agent on every expert step and return the steps sorted by loss.

    Despite the name, nothing is written to disk. The per-step diagnostics
    are returned so the caller can persist or inspect them.

    Args:
        model: Lifted model providing ``action_fluents`` and ``fluent_param``.
        agent: Agent to evaluate; sampled with ``deterministic=True``.
        expert_actions: Expert action dict per step. Its first key is reported.
        indexed_expert_obs: Per-step observation data as produced by
            ``to_obsdata`` and fed to ``heterostatedata_from_obslist``.
            NOTE(review): annotated as GroundObs, but these are the
            ``create_obs_dict`` outputs. Confirm.
        expert_obs: Per-step tuple-keyed observations. They are used to
            build graph labels and object names.
        device: Torch device for the per-step state tensors.
        create_graphs: Graph builder as returned by ``create_graphs_func``.
            It is built from ``model`` when omitted.

    Returns:
        One dict per step with the model action, the expert action, the
        loss, action and object probabilities above 0.001, and the step
        index. The list is ordered highest loss first.
    """
    if create_graphs is None:
        # to_graph requires a graph builder; build it once, not per step.
        create_graphs = create_graphs_func(model)

    # Under inference mode the weights cannot change inside the loop, so the
    # regularisation norms are computed a single time.
    l2_norms = [th.sum(th.square(w)) for w in agent.parameters()]

    loss_per_obs = []
    for step, (expert_a, obs_data, obs) in enumerate(
        zip(expert_actions, indexed_expert_obs, expert_obs)
    ):
        s = heterostatedata_to_tensors(
            heterostatedata_from_obslist([obs_data]), device=device
        )
        g = to_graph(obs, model, create_graphs)
        actions, logprob, _, _, p_a, p_n__a = agent.sample(s, deterministic=True)
        loss = calc_loss(l2_norms, logprob).item()

        obj_names = [obj.name for obj in object_list(obs.keys(), model.fluent_param)]
        model_a = from_index_action(actions[0], lambda x: obj_names[x], model)

        factor_weights = p_n__a.T[actions[:, 0]].detach().squeeze().cpu().numpy()

        # Only keep probabilities large enough to be worth reporting.
        weight_by_factor = {
            k: f"{float(v):0.3f}"
            for k, v in zip(g.factor_labels, factor_weights)
            if v > 0.001
        }
        weight_by_action = {
            k: f"{float(v):0.3f}"
            for k, v in zip(
                model.action_fluents,
                p_a.detach().squeeze().cpu().numpy(),
            )
            if v > 0.001
        }

        loss_per_obs.append(
            dict(
                model_action=model_a,
                expert_action=list(expert_a.keys())[0],
                loss=loss,
                action_probs=weight_by_action,
                object_probs=weight_by_factor,
                step=step,
            )
        )

    return sorted(loss_per_obs, key=lambda x: x["loss"], reverse=True)


def ground_to_tuple(s: str) -> GroundValue:
    """Turn an RDDL ground-fluent string into its tuple form.

    Thin wrapper that delegates to ``rddl_ground_to_tuple``.
    """
    as_tuple = rddl_ground_to_tuple(s)
    return as_tuple


def convert_state_to_tuples(
    d: RecordingObs, converter_func: Callable[[str], GroundValue]
) -> dict[GroundValue, Any]:
    """Re-key a string-keyed observation with ``converter_func``.

    Values are carried over unchanged. Each key is passed through
    ``converter_func`` once, in the dict's iteration order.
    """
    converted: dict[GroundValue, Any] = {}
    for raw_key, value in d.items():
        converted[converter_func(raw_key)] = value
    return converted


def convert_actions_to_tuples(
    d: RecordingAction, converter_func: Callable[[str], GroundValue]
) -> dict[GroundValue, bool]:
    """Re-key a recorded action dict with ``converter_func``.

    An empty (or otherwise falsy) action is replaced by the placeholder
    ``{("None", "None"): True}``. Presumably this stands for "no action";
    confirm against the model's action fluents.
    """
    if not d:
        return {("None", "None"): True}
    return convert_state_to_tuples(d, converter_func)


def ensure_tuple(x: tuple[str, ...]) -> tuple[str, ...]:
    """Pad a one-element ground tuple with ``"None"`` as the second element.

    Tuples of any other length are returned unchanged.
    """
    if len(x) != 1:
        return x
    return x + ("None",)


def from_index_action(
    action: IndexedAction, idx_to_obj: Callable[[int], str], model: BaseModel
) -> tuple[str, str]:
    """Map an indexed action back to ``(action_fluent_name, object_name)``.

    ``action[0]`` indexes ``model.action_fluents``. ``action[1]`` is resolved
    to an object name through ``idx_to_obj``.
    """
    fluent_name = model.action_fluents[action[0]]
    object_name = idx_to_obj(action[1])
    return fluent_name, object_name


def to_indexed_action(
    action: GroundAction, obj_to_idx: Callable[[str], int], model: BaseModel
) -> IndexedAction:
    """Convert a ground action dict into its indexed form.

    Only the first key of ``action`` is used. An empty dict falls back to
    ``("None", "None")``. Fluent names are indexed by their position in
    ``model.action_fluents``; objects are indexed through ``obj_to_idx``.
    """
    if action:
        chosen = next(iter(action))
    else:
        chosen = ("None", "None")

    def fluent_to_idx(fluent):
        return model.action_fluents.index(fluent)

    return from_dict_action(chosen, fluent_to_idx, obj_to_idx)


# Module-level counter intended for numbering saved debug graph renders.
# No active code in this file reads or increments it.
render_index = 0


def get_actions(data: Recording):
    """Collect the ``"actions"`` field of every entry, grouped per episode.

    NOTE(review): despite the ``Recording`` annotation, ``data`` is iterated
    as a sequence of episodes, each being a sequence of entries.
    """
    per_episode = []
    for episode in data:
        per_episode.append([entry["actions"] for entry in episode])
    return per_episode


def get_obs(data: Recording):
    """Collect the ``"state"`` field of every entry, grouped per episode.

    NOTE(review): despite the ``Recording`` annotation, ``data`` is iterated
    as a sequence of episodes, each being a sequence of entries.
    """
    per_episode = []
    for episode in data:
        per_episode.append([entry["state"] for entry in episode])
    return per_episode


def to_graph(obs: GroundObs, model: BaseModel, create_graphs):
    """Build a renderable graph for a single observation.

    ``create_graphs(obs)`` must return a pair whose first element has
    ``boolean`` and ``numeric`` parts. Those parts are handed to
    ``create_render_graph``. ``model`` is accepted but not used.
    """
    graphs, _unused = create_graphs(obs)
    boolean_part = graphs.boolean
    numeric_part = graphs.numeric
    return create_render_graph(boolean_part, numeric_part)


def to_obsdata(
    obs: GroundObs,
    action: GroundAction,
    model: BaseModel,
    create_graphs,
    create_obs_dict,
):
    """Turn one (observation, action) pair into agent-ready training data.

    The observation is converted to graphs with ``create_graphs`` and then
    to an observation dict with ``create_obs_dict``. The action is indexed
    against the boolean graph's ``factors`` list.

    Returns:
        ``(obs_dict, indexed_action)``.
    """
    graphs, _unused = create_graphs(obs)
    obs_dict = create_obs_dict(graphs)

    def obj_to_idx(obj):
        # Objects are identified by their position in the boolean factors.
        return graphs.boolean.factors.index(obj)

    indexed_action = to_indexed_action(action, obj_to_idx, model)
    return obs_dict, indexed_action


def get_agent(model: BaseModel, embedding_dim: int, device: str = "cpu"):
    """Create a GraphAgent sized for ``model`` and move it to ``device``.

    The GNN is configured with 4 layers, a Tanh activation, "max"
    aggregation and ``ActionMode.ACTION_THEN_NODE``. The agent config sets
    ``remove_false_fluents=True``.
    """
    type_count = model.num_types
    fluent_count = model.num_fluents
    action_count = model.num_actions
    largest_arity = max_arity(model)

    hyper_params = GNNParams(
        layers=4,
        embedding_dim=embedding_dim,
        activation=th.nn.Tanh(),
        aggregation="max",
        action_mode=ActionMode.ACTION_THEN_NODE,
    )
    agent_config = AgentConfig(
        type_count,
        fluent_count,
        action_count,
        hyper_params=hyper_params,
        arity=largest_arity,
        remove_false_fluents=True,
    )
    return GraphAgent(agent_config, None, device).to(device)


def get_rddl_data(
    data: Recording,
    model: BaseModel,
    grounded_model: BaseGroundedModel,
    converter_func: Callable[[str], GroundValue] = ground_to_tuple,
):
    """Convert recorded episodes into indexed (observation, action) data.

    Each entry's ``"state"`` is re-keyed with ``converter_func`` and passed
    through ``remove_false``. Its ``"actions"`` are re-keyed, padded with
    ``ensure_tuple``, and indexed via ``to_obsdata``. The entry format is the
    same one ``test_saved_data`` reads.

    Args:
        data: Recorded episodes; iterated as a sequence of episodes, each a
            sequence of entries with ``"state"`` and ``"actions"`` dicts.
        model: Lifted model used to build graphs and index actions.
        grounded_model: Not used; kept for interface compatibility.
        converter_func: Maps a ground-fluent string key to a tuple.
            It defaults to ``ground_to_tuple`` (RDDL naming).
            NOTE(review): recordings read by ``test_saved_data`` use
            ``"__"``-separated keys; pass a matching converter for those.

    Returns:
        ``zip`` over the collected ``(obs_dict, indexed_action)`` pairs.
        Unpacking it gives a tuple of observations and a tuple of actions.
    """
    create_graphs = create_graphs_func(model)
    create_obs_dict = create_obs_dict_func(model)

    rollout = []
    for episode in data:
        for entry in episode:
            obs = remove_false(convert_state_to_tuples(entry["state"], converter_func))
            raw_action = convert_actions_to_tuples(entry["actions"], converter_func)
            action = {ensure_tuple(k): v for k, v in raw_action.items()}
            rollout.append(
                to_obsdata(obs, action, model, create_graphs, create_obs_dict)
            )
    return zip(*rollout)


def _load_expert_data(
    datafile: Path,
) -> tuple[list[GroundObs], list[GroundAction]]:
    """Load a JSON recording and convert it to tuple-keyed observations/actions.

    The file holds a list of entries. Each entry has a ``"state"`` dict and
    an ``"actions"`` dict, keyed by ground fluents written as ``"__"``-joined
    strings. Observations go through ``remove_false``. Empty action dicts
    become the ``("None", "None")`` placeholder. Exactly one action dict is
    produced per entry, so observations and actions stay aligned
    step-for-step.
    """
    with open(datafile, "r", encoding="utf-8") as f:
        expert_data = json.load(f)

    def to_tuple(key: str) -> GroundValue:
        return tuple(key.split("__"))

    expert_obs = [
        remove_false(convert_state_to_tuples(entry["state"], to_tuple))
        for entry in expert_data
    ]
    expert_actions = [
        {
            ensure_tuple(k): v
            for k, v in convert_actions_to_tuples(entry["actions"], to_tuple).items()
        }
        for entry in expert_data
    ]
    return expert_obs, expert_actions


def test_saved_data(domain: str, data_path: str, batch_id: str):
    """Fit a GraphAgent to recorded expert data, save it, and evaluate it.

    Side effects:
        Creates ``imitation_output/<batch_id>/<run_id>/`` and writes the
        agent weights and the training config there.

    Args:
        domain: Domain name passed to the environment and to
            ``evaluate_instance``.
        data_path: Path (``~`` is expanded) to the JSON expert recording.
        batch_id: Sub-directory of ``imitation_output`` grouping runs.

    Returns:
        ``(stats, agent_path, config_path)``. ``stats`` is an ``EvalEntry``
        with the returns collected on instances 1-10.

    Raises:
        ValueError: If the recording contains no steps.
    """
    datafile = Path(data_path).expanduser()
    seed = 1
    device = "cuda:0" if th.cuda.is_available() else "cpu"

    env_id = register_env()
    run_id = str(uuid.uuid4())

    learning_rate = 1e-3
    # After lr_decay_step updates the learning rate is lowered to this value.
    decayed_learning_rate = 1e-4
    lr_decay_step = 2000
    wd = 0.0
    steps = 5000
    embedding_dim = 16

    config = {
        "learning_rate": learning_rate,
        "weight_decay": wd,
        "steps": steps,
        "seed": seed,
        "data_path": data_path,
        "embedding_dim": embedding_dim,
    }

    output_dir = pathlib.Path("imitation_output")
    output_dir.mkdir(exist_ok=True)
    batch_dir = output_dir / batch_id
    batch_dir.mkdir(exist_ok=True)
    run_dir = batch_dir / run_id
    run_dir.mkdir()

    # This env is only used to obtain the lifted model. Evaluation below is
    # handed env_id rather than this instance.
    env: gym.Env = gym.make(env_id, domain=domain, instance=1, remove_false=True)
    model: BaseModel = env.unwrapped.model

    np.random.seed(seed)
    th.manual_seed(seed)
    random.seed(seed)

    agent = get_agent(model, embedding_dim, device)
    optimizer = th.optim.AdamW(
        agent.parameters(), lr=learning_rate, amsgrad=True, weight_decay=wd
    )

    expert_obs, expert_actions = _load_expert_data(datafile)
    if not expert_obs:
        raise ValueError(f"no expert steps found in {datafile}")

    create_graphs = create_graphs_func(model)
    create_obs_dict = create_obs_dict_func(model)

    indexed_expert_obs, indexed_expert_action = zip(
        *[
            to_obsdata(o, a, model, create_graphs, create_obs_dict)
            for o, a in zip(expert_obs, expert_actions)
        ]
    )

    d = heterostatedata_to_tensors(
        heterostatedata_from_obslist(indexed_expert_obs), device=device
    )
    indexed_expert_action = th.as_tensor(
        indexed_expert_action, dtype=th.int64, device=device
    )

    avg_loss = 0.0
    avg_grad_norm = 0.0
    with tqdm(total=steps) as pbar:
        for step in range(steps):
            # Lower the learning rate once, after lr_decay_step updates.
            if step == lr_decay_step:
                for param_group in optimizer.param_groups:
                    param_group["lr"] = decayed_learning_rate

            loss, grad_norm, _ = update(
                agent, optimizer, indexed_expert_action, d, max_grad_norm=0.5
            )
            pbar.update(1)
            # Exponential moving averages (factor 0.5) for the progress bar.
            avg_loss = avg_loss + (loss - avg_loss) / 2
            avg_grad_norm = avg_grad_norm + (grad_norm - avg_grad_norm) / 2
            pbar.set_description(
                f"Loss: {avg_loss:.3f}, Grad Norm: {avg_grad_norm:.3f}"
            )

    agent_path = str(run_dir / f"model_{run_id}.pth")
    config_path = str(run_dir / f"config_{run_id}.json")
    agent.save_agent(agent_path)
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f)

    instances = range(1, 11)
    instance_returns = []
    for instance in tqdm(instances, total=len(instances)):
        _, h = evaluate_instance(
            env_id, domain, instance, agent, True, 100, verbose=False
        )
        instance_returns.append(list(h))

    env.close()

    stats = EvalEntry(
        batch_id, run_id=run_id, domain=domain, instance_returns=instance_returns
    )
    return stats, agent_path, config_path


def main():
    """Command-line entry point.

    Usage::

        python <script> DATA_PATH DOMAIN BATCH_ID TRAIN_INSTANCES_JSON

    Runs ``test_saved_data`` and prints one JSON object to stdout. The object
    combines the evaluation entry, the saved agent and config paths, and the
    JSON-decoded training instances. The instances are only echoed back.
    """
    if len(sys.argv) < 5:
        sys.exit(
            f"usage: {sys.argv[0]} DATA_PATH DOMAIN BATCH_ID TRAIN_INSTANCES_JSON"
        )
    data_path, domain, batch_id, raw_instances = sys.argv[1:5]
    # Decode before the long training run so malformed JSON fails fast.
    instances = json.loads(raw_instances)

    stats, agent_path, config_path = test_saved_data(domain, data_path, batch_id)
    to_print = asdict(stats)
    to_print["agent_path"] = agent_path
    to_print["config_path"] = config_path
    to_print["train_instances"] = instances

    print(json.dumps(to_print))


if __name__ == "__main__":
    main()
