import numpy as np
import torch
from tqdm import tqdm, trange
import matplotlib.pyplot as plt

import itertools
import matplotlib.pyplot as plt
import argparse
import pickle
import os
import multiprocessing as mp
from functools import partial
import time
import ast
import json
import copy
from .utils import runEpisode, offlineRL, control_seed, DEFAULT_DEVICE
from .base_selection import base_selection

class direct_selection(base_selection):
    """Evaluate one explicitly specified set of visit IDs.

    The IDs are read from ``self.visit_ids`` and passed unchanged to
    ``self.iqltrain``. This class does not define ``iqltrain``; it has to be
    supplied by ``base_selection``. The attributes read here (``impute``,
    ``root``, ``env``, ``qfunction``, ``dataset``, ``visit_ids``,
    ``save_result``, ``selection_params``, ``save_root``, ``budget``,
    ``search``) are not assigned in this class either.
    # NOTE(review): presumably they are configured by base_selection or the
    # caller before run() is invoked -- confirm.
    """

    def __init__(self):
        super().__init__()
        self.selection_name = "direct"

    def _load_or_train_best_q(self):
        """Populate ``self.bestq`` from the on-disk cache, training on a miss."""
        fp = f"{self.root}/{self.env.name}/model/best"
        if os.path.exists(fp):
            # NOTE(security): pickle.load can execute arbitrary code while
            # deserialising. Only point ``self.root`` at directories whose
            # contents were written by this project.
            with open(fp, "rb") as f:
                self.bestq = pickle.load(f)
            return

        # Cache miss: train on the full dataset and persist the result so
        # later runs can reuse it.
        self.bestq = self.qfunction.train(self.dataset)
        os.makedirs(os.path.dirname(fp), exist_ok=True)
        with open(fp, "wb") as f:
            pickle.dump(self.bestq, f)

    def run(self, n_repeats=3):
        """Train/evaluate on the configured visit IDs and report the mean.

        Args:
            n_repeats: Number of independent ``iqltrain`` calls whose results
                are averaged. Defaults to 3, the previously hard-coded value.

        Raises:
            ValueError: If ``n_repeats`` is smaller than 1 (the mean of zero
                runs would be NaN).

        Side effects:
            Sets ``self.next_visit_ids``; resets ``self.train_inds_list`` and
            ``self.visited_ids`` before every repeat; sets ``self.bestq`` when
            ``self.impute`` is not None. Depending on ``self.save_result`` the
            averaged result is written via ``save_file`` or the averaged first
            value is printed.
        """
        if n_repeats < 1:
            raise ValueError(f"n_repeats must be >= 1, got {n_repeats}")

        if self.impute is not None:
            self._load_or_train_best_q()

        # A string such as "3_17_42" encodes several IDs separated by "_";
        # any other value is treated as one single ID.
        if isinstance(self.visit_ids, str):
            self.next_visit_ids = [int(tok) for tok in self.visit_ids.split("_")]
        else:
            self.next_visit_ids = [self.visit_ids]

        Js = []
        acc = []
        for _ in range(n_repeats):
            # Start every repeat from empty bookkeeping lists so that runs do
            # not build on each other's state.
            self.train_inds_list = []
            self.visited_ids = []
            # iqltrain returns a pair.
            # NOTE(review): presumably (evaluation score, imputation
            # accuracy) -- confirm against base_selection.
            J, a = self.iqltrain(self.next_visit_ids)
            Js.append(J)
            acc.append(a)
        Js = np.mean(Js)
        acc = np.mean(acc)

        if self.save_result:
            self.save_file(Js, acc)
        else:
            print(Js)

    def save_file(self, Js, acc):
        """Save ``[Js, acc]`` as ``.npy`` files in two locations.

        Args:
            Js: First value of the stored two-element array.
            acc: Second value of the stored two-element array.

        The same array is written to
        ``{save_root}/{budget}/{file_id}.npy`` and to
        ``{save_root}_{search}/{number of visit IDs}/{visit_ids}.npy``.
        Missing directories are created. ``self.next_visit_ids`` must already
        be set, which ``run`` does before calling this method.
        """
        # Use hashed_id if provided in selection_params, otherwise use visit_ids
        file_id = self.selection_params.get('hashed_id', self.selection_params.visit_ids)

        result = np.array([Js, acc])

        budget_dir = f"{self.save_root}/{self.budget}"
        os.makedirs(budget_dir, exist_ok=True)
        np.save(f"{budget_dir}/{file_id}.npy", result)

        search_path = f"{self.save_root}_{self.search}/{len(self.next_visit_ids)}/{self.visit_ids}.npy"
        os.makedirs(os.path.dirname(search_path), exist_ok=True)
        np.save(search_path, result)







