import numpy as np
import json
import pandas as pd
import torch
from tqdm import tqdm


# Root directory of the datasets. tabular_Q_belman concatenates it directly as
# f"{data_folder}{dataset}/env/{env_id}/{goal}/data_with_reward.csv", so the
# trailing slash is required.
data_folder = "data/datasets/"

class Mapping:
    """Map objects to their index in the list the mapping was built from.

    Objects are keyed by the string returned by :meth:`hash`, so unhashable
    values such as dicts and lists can be looked up by content. If several
    objects produce the same key, the index of the last one wins.
    """

    def __init__(self, list_object):
        # key -> position in list_object; later duplicates overwrite earlier ones.
        self.hash_table = {
            self.hash(item): index for index, item in enumerate(list_object)
        }

    def hash(self, object):
        """Return the string key used to index *object*.

        - dict: JSON with sorted keys, so key order does not matter.
        - list: JSON as-is, so element order matters.
        - str: used verbatim.
        - int (Python or NumPy): its decimal string, so ``np.int64(3)`` and
          ``3`` share a key.
        - float (Python or NumPy): ``repr`` of the Python float, e.g. "1.5"
          or "nan".
        - anything else: ``repr(object)``.

        Previously every unsupported value mapped to ``None`` and therefore to
        one shared key, which made distinct values indistinguishable.
        """
        if isinstance(object, dict):
            return json.dumps(object, sort_keys=True)
        if isinstance(object, list):
            return json.dumps(object)
        if isinstance(object, str):
            return object
        if isinstance(object, (int, np.integer)):
            return str(object)
        if isinstance(object, (float, np.floating)):
            # Normalise NumPy floats so they match equal Python floats.
            return repr(float(object))
        return repr(object)

    def __call__(self, object):
        """Return the index of *object*.

        Raises:
            Exception: "object not in hash table" when *object*'s key is
                unknown; the underlying KeyError is chained as ``__cause__``.
        """
        key = self.hash(object)
        try:
            return self.hash_table[key]
        except KeyError as err:
            raise Exception("object not in hash table") from err


def _estimate_model(
    states,
    actions,
    next_states,
    rewards,
    mapping_state,
    mapping_action,
    size_s,
    size_a,
    device,
    dtype,
):
    """Build the empirical reward matrix R (S x A) and transition tensor P (S x A x S).

    R[s, a] is the mean reward over the rows observed for (s, a). P[s, a, s']
    is the fraction of those rows whose next state is s', so each visited
    (s, a) row of P sums to 1 and unvisited rows stay all-zero.
    """
    s_idx, a_idx, ns_idx = [], [], []
    for s, a, ns in tqdm(
        zip(states, actions, next_states), total=len(states), desc="filling R and P"
    ):
        s_idx.append(mapping_state(s))
        a_idx.append(mapping_action(a))
        ns_idx.append(mapping_state(ns))
    s_idx = torch.tensor(s_idx, dtype=torch.long, device=device)
    a_idx = torch.tensor(a_idx, dtype=torch.long, device=device)
    ns_idx = torch.tensor(ns_idx, dtype=torch.long, device=device)
    reward_t = torch.tensor(rewards, dtype=torch.float32, device=device)

    # Accumulate counts and sums in float32; float16 loses precision quickly.
    visits = torch.zeros((size_s, size_a), dtype=torch.float32, device=device)
    visits.index_put_((s_idx, a_idx), torch.ones_like(reward_t), accumulate=True)
    reward_sum = torch.zeros_like(visits)
    reward_sum.index_put_((s_idx, a_idx), reward_t, accumulate=True)
    R = (reward_sum / visits.clamp(min=1)).to(dtype)

    P = torch.zeros((size_s, size_a, size_s), dtype=dtype, device=device)
    if s_idx.numel() > 0:
        triples, counts = torch.unique(
            torch.stack((s_idx, a_idx, ns_idx), dim=1), dim=0, return_counts=True
        )
        s_u, a_u, ns_u = triples.unbind(dim=1)
        # Normalise so P is an expectation over next states rather than a sum.
        P[s_u, a_u, ns_u] = (counts.to(torch.float32) / visits[s_u, a_u]).to(dtype)
    return R, P


def tabular_Q_belman(
    dataset,
    goal,
    env_id,
    gamma=1,
    alpha=0.1,
    max_epochs=10,
    Q=None,
    mapping_state=None,
    mapping_action=None,
    verbose=False,
    device="cuda",
    dtype=torch.float16,
    tol=1e-6,
):
    """Run tabular Q-value iteration on a logged transition dataset.

    Reads ``{data_folder}{dataset}/env/{env_id}/{goal}/data_with_reward.csv``
    (columns ``obs``, ``action``, ``next_obs``, ``reward``) and drops duplicate
    rows. It then builds the empirical R and P and repeatedly applies

        Q <- (1 - alpha) * Q + alpha * (R + gamma * P @ max_a Q)

    Iteration stops early after 11 consecutive epochs with exactly zero change.

    Args:
        dataset, goal, env_id: path components of the CSV to load.
        gamma: discount factor.
        alpha: step size of the relaxed Bellman update.
        max_epochs: maximum number of update sweeps.
        Q: optional initial Q-table of shape (len(list_states), len(list_actions)),
            indexed in the order of the returned lists. Defaults to R.
        mapping_state, mapping_action: ignored because they are rebuilt from
            the data. They are kept for backward compatibility.
        verbose: print the per-epoch change and mean of Q.
        device: torch device. A trailing ':' (e.g. the former default "cuda:")
            is stripped.
        dtype: dtype of R, P and Q. float16 keeps the dense P small, but it
            overflows above ~65504 and limits how small a change can be.
        tol: largest final change in Q still accepted as converged.

    Returns:
        (Q on CPU, list_states, list_actions), where Q[i, j] is the value of
        list_states[i] under action list_actions[j]. Both lists are in
        first-seen order, so they are reproducible across runs.

    Raises:
        Exception: "R is null" if all rewards are zero; "Nan value" or
            "Inf value" if Q stops being finite; "Q did not converge" if the
            last change exceeds ``tol`` (including when max_epochs < 1).
        ValueError: if a provided Q does not have shape (S, A).
    """
    if isinstance(device, str):
        # "cuda:" is rejected by torch as an invalid device string.
        device = device.rstrip(":")

    df_cluster = pd.read_csv(
        f"{data_folder}{dataset}/env/{env_id}/{goal}/data_with_reward.csv"
    )
    df_cluster.drop_duplicates(inplace=True)

    states = df_cluster["obs"].to_list()
    next_states = df_cluster["next_obs"].to_list()
    actions = [int(a) for a in df_cluster["action"].to_list()]
    rewards = df_cluster["reward"].to_list()

    # dict.fromkeys keeps first-seen order. set() order depends on the string
    # hash seed, which would misalign a Q passed in from an earlier run.
    list_states = list(dict.fromkeys(states + next_states))
    # Built from the int-converted actions so the lookups below always match.
    list_actions = list(dict.fromkeys(actions))

    mapping_state = Mapping(list_states)
    mapping_action = Mapping(list_actions)
    size_s = len(list_states)
    size_a = len(list_actions)

    R, P = _estimate_model(
        states,
        actions,
        next_states,
        rewards,
        mapping_state,
        mapping_action,
        size_s,
        size_a,
        device,
        dtype,
    )
    if R.abs().sum().item() == 0:
        raise Exception("R is null")

    if Q is None:
        Q = R.clone()
    else:
        Q = torch.as_tensor(Q, dtype=dtype, device=device)
        if Q.shape != R.shape:
            raise ValueError(
                f"Q has shape {tuple(Q.shape)}, expected {tuple(R.shape)}"
            )

    patience = 0
    delta = float("inf")  # stays inf (-> "did not converge") if no epoch runs
    for i in tqdm(range(max_epochs), desc="Epochs Q iteration"):
        old_Q = Q  # Q is rebound below, never mutated, so no copy is needed
        Q = (1 - alpha) * Q + alpha * (
            R + gamma * torch.matmul(P, Q.max(dim=1).values)
        )
        delta = (Q - old_Q).abs().max().item()
        if verbose:
            print("epochs : ", i, "delta : ", delta, "Q mean :", torch.mean(Q))
        if torch.isnan(Q).any():
            raise Exception("Nan value")
        if torch.isinf(Q).any():
            # e.g. float16 overflow; previously this went undetected.
            raise Exception("Inf value")
        if delta == 0:
            patience += 1
            if patience > 10:
                break
        else:
            patience = 0

    if delta > tol:
        raise Exception("Q did not converge")
    return Q.cpu(), list_states, list_actions
