import pickle
import random

from models import random_var, shared_utils, reverse_grad
from env_rob_benchmarking import eval_agent_on_mdp, load_agent_api
from overcooked_ai_py.mdp.overcooked_mdp import OvercookedGridworld
from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv

import os
import sys
import torch
import numpy as np
from tqdm import tqdm

from train_agent.benchmarking1 import ApagAgentNewVersion


# Defaults used when collecting rollouts.
# NOTE(review): "horizon" is presumably the episode length handed to the
# environment -- confirm against OvercookedEnv.
DEFAULT_HORIZON = 800
DEFAULT_N_GAMES = 10

# Default ``data_params`` mapping. Processor.prepare_data reads the
# 'env_params' and 'n_games' keys from a mapping of this layout.
DEFAULT_DATA_PARAMS = {
    'env_params': {"horizon": DEFAULT_HORIZON},
    'n_games': DEFAULT_N_GAMES,
}


class Processor:
    """Collect rollouts of a target agent and pick perturbed states to attack it.

    The attributes each method reads imply the following call order:

    1. ``prepare_data`` fills ``data`` and ``stat_dict``.
    2. ``process_data`` fills ``obss``, ``processed_grad`` and
       ``processed_act_probs`` from ``data``.
    3. ``get_attack`` uses all of the above to select states.
    """

    # Every attack name that get_attack handles.
    allowed_method = ['random', 'random+', 'stay_targeted', 'no_target']

    def __init__(self, target_agent: "ApagAgentNewVersion", target_env: "OvercookedEnv", meta_data, data_params):
        """Store the target agent/environment and initialise empty result slots.

        Args:
            target_agent: agent under attack; its ``actor_critic`` attribute is
                kept as ``target_policy``.
            target_env: environment of the agent; its ``mdp`` attribute and its
                ``lossless_state_encoding_mdp`` callable are kept as well.
            meta_data: stored as-is; not read by any method of this class.
            data_params: mapping with the keys ``'env_params'`` and
                ``'n_games'`` (read by ``prepare_data``).
        """
        self.target_agent = target_agent
        self.target_policy = target_agent.actor_critic
        self.target_env = target_env
        self.target_mdp = target_env.mdp
        self.meta_data = meta_data
        self.data_params = data_params

        # Callable that turns a state into the observation(s) fed to the policy.
        self.state_fn = self.target_env.lossless_state_encoding_mdp

        self.data = None                  # list of rollouts, set by prepare_data
        self.processed_grad = None        # array of per-observation gradients
        self.stat_dict = None             # result of shared_utils.get_feasible_stat
        self.obss = None                  # flattened observations
        self.processed_act_probs = None   # array of per-observation action probs

    def process_data(self):
        """Encode every collected state and compute its gradient and action probs.

        Reads ``self.data`` (so ``prepare_data`` must have run) and writes
        ``self.obss``, ``self.processed_grad`` and ``self.processed_act_probs``.
        """
        print("processing data")
        all_states = [r['ep_states'] for r in self.data]
        # Two successive concatenations flatten the two outer nesting levels of
        # the collected states.
        # NOTE(review): presumably rollout -> episode -> step; confirm against
        # the structure returned by eval_agent_on_mdp.
        states = np.concatenate(all_states)
        states = np.concatenate(states)
        # state_fn yields a sequence per state; concatenating merges them into
        # one flat run of observations.
        obss = np.concatenate([self.state_fn(state) for state in states])
        self.obss = obss

        all_grads = []
        all_act_probs = []
        for obs in tqdm(obss):
            grad, act_prob = shared_utils.get_reverse_grad(self.target_policy, obs)
            all_grads.append(grad)
            all_act_probs.append(act_prob)
        self.processed_grad = np.array(all_grads)
        self.processed_act_probs = np.array(all_act_probs)

    def prepare_data(self, mode='sp'):
        """Collect rollouts and compute the feasible-state statistics.

        Args:
            mode: only ``'sp'`` collects anything: the target agent is evaluated
                paired with itself. Any other value leaves ``self.data`` as an
                empty list.

        Writes ``self.data`` and ``self.stat_dict``.
        """
        print("collecting data")
        env_params = self.data_params['env_params']
        n_games = self.data_params['n_games']
        data_bases = []
        if mode == 'sp':
            target_agent = self.target_agent
            # Only the rollouts are kept; the first two returned values are
            # discarded.
            _, _, rollouts = eval_agent_on_mdp(target_agent, target_agent, self.target_mdp, env_params, n_games)
            data_bases.append(rollouts)
        self.data = data_bases

        print("calculating choices")
        self.stat_dict = shared_utils.get_feasible_stat(self.target_mdp)

    def _filter_candidates(self, all_twisted_states, choices_diffs, f_thres):
        """Keep only the candidates selected by ``shared_utils.state_filter``.

        Returns:
            ``(filtered_states, filtered_diffs)``: the entries of the two input
            lists at the indices returned by the filter, in that order.
        """
        good_idx = shared_utils.state_filter(all_twisted_states, choices_diffs, self.obss, f_thres)
        print(f"filtered choices = {len(good_idx)} , with threshold = {f_thres}")
        filtered_states = [all_twisted_states[idx] for idx in good_idx]
        filtered_diffs = [choices_diffs[idx] for idx in good_idx]
        return filtered_states, filtered_diffs

    def _dump_debug(self, attack_method, all_twisted_states, choices_diffs, rks):
        """Pickle debugging data to ``<attack_method>_debug.pkl`` in the cwd."""
        # NOTE: the states/diffs dumped here are the unfiltered lists, whereas
        # ``rks`` indexes the filtered ones. The payload layout is kept as it
        # was so existing readers of these dumps keep working.
        dbg_data = [self.data, self.processed_grad, all_twisted_states, choices_diffs, rks]
        # Context manager guarantees the file is flushed and closed.
        with open(f'{attack_method}_debug.pkl', 'wb') as dbg_file:
            pickle.dump(dbg_data, dbg_file)

    def get_attack(self, attack_method='stay_targeted', epi=2, top_k=10, verbose=False, debug=False, f_thres=0.03):
        """Select perturbed states with the requested attack method.

        Args:
            attack_method: one of ``'random'`` (sample from all candidates),
                ``'random+'`` (sample from filtered candidates),
                ``'stay_targeted'`` or ``'no_target'`` (rank filtered
                candidates with ``reverse_grad.stay_target`` /
                ``reverse_grad.no_target``).
            epi: forwarded to ``shared_utils.feasible_set_to_choice``.
            top_k: number of states to sample, or the value forwarded to the
                ``reverse_grad`` ranking function.
            verbose: accepted for interface compatibility; currently unused.
            debug: for the two gradient-based methods, also pickle debugging
                data to ``<attack_method>_debug.pkl``.
            f_thres: threshold forwarded to ``shared_utils.state_filter``.

        Returns:
            List of the picked states.

        Raises:
            NotImplementedError: if ``attack_method`` is not supported.
            ValueError: from ``random.sample`` when ``top_k`` exceeds the number
                of candidates available to a random method.
        """
        print(f"performing attack {attack_method} with epi={epi}, top {top_k}")
        if attack_method not in Processor.allowed_method:
            print(f'{attack_method} not implemented, now {Processor.allowed_method}')
            raise NotImplementedError(
                f"attack method {attack_method!r} is not implemented; "
                f"choose one of {Processor.allowed_method}"
            )

        all_twisted_states, choices_diffs = shared_utils.feasible_set_to_choice(
            self.target_env, self.stat_dict, epi)

        if attack_method == 'random':
            print(f"raw choices = {len(all_twisted_states)}")
            return random.sample(all_twisted_states, top_k)

        if attack_method == 'random+':
            filtered_twisted_states, _ = self._filter_candidates(all_twisted_states, choices_diffs, f_thres)
            return random.sample(filtered_twisted_states, top_k)

        # Gradient-based methods: filter first, then rank the survivors.
        print(f"raw choices = {len(choices_diffs)} ")
        filtered_twisted_states, filtered_choices_diffs = self._filter_candidates(
            all_twisted_states, choices_diffs, f_thres)

        if attack_method == 'stay_targeted':
            rks = reverse_grad.stay_target(self.processed_grad, filtered_choices_diffs, top_k)
        elif attack_method == 'no_target':
            rks = reverse_grad.no_target(self.processed_grad, filtered_choices_diffs,
                                         self.processed_act_probs, top_k)
        else:
            # Only reachable if allowed_method is extended without adding a
            # matching branch above.
            raise NotImplementedError(f"attack method {attack_method!r} has no implementation")

        # The ranking functions return indices into the filtered lists.
        picked_states = [filtered_twisted_states[idx] for idx in rks]
        if debug:
            self._dump_debug(attack_method, all_twisted_states, choices_diffs, rks)
        return picked_states
