import pickle
import random
import os
import sys
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

from overcooked_ai_py.mdp.overcooked_mdp import OvercookedGridworld
from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
from overcooked_ai_py.agents.benchmarking import AgentEvaluator
from overcooked_ai_py.agents.agent import AgentPair

from models import random_var, shared_utils, reverse_grad
from env_rob_benchmarking import eval_agent_on_mdp, load_agent_api
from a2c_ppo_acktr.model import Policy


def distillation(y, teacher_scores, labels, T, alpha):
    """Return the knowledge-distillation loss for a batch of student logits.

    The loss is ``alpha * KL + (1 - alpha) * CE`` where

    * ``KL`` is the KL divergence between the temperature-softened teacher
      distribution and the temperature-softened student distribution, summed
      over all elements, scaled by ``T ** 2`` and divided by ``y.shape[0]``;
    * ``CE`` is the ordinary cross-entropy of the un-softened student logits
      against ``labels``.

    Args:
        y: Student logits; softmax is taken over dim 1 and ``y.shape[0]`` is
            used as the batch size.
        teacher_scores: Teacher logits, softened with the same temperature.
        labels: Targets passed to ``F.cross_entropy`` together with ``y``.
        T: Softmax temperature applied to both student and teacher logits.
        alpha: Weight of the KL term; the CE term is weighted ``1 - alpha``.

    Returns:
        A scalar tensor holding the combined loss.
    """
    student_log_probs = F.log_softmax(y / T, dim=1)
    teacher_probs = F.softmax(teacher_scores / T, dim=1)
    # reduction='sum' is the supported spelling of the deprecated
    # size_average=False; the batch normalisation is done explicitly below.
    l_kl = F.kl_div(student_log_probs, teacher_probs, reduction='sum') * (T**2) / y.shape[0]
    l_ce = F.cross_entropy(y, labels)
    return l_kl * alpha + l_ce * (1. - alpha)


def get_disturbed_obss(start_states, ae: AgentEvaluator):
    """Return, per start state, its encoded observation minus the standard one.

    Every state is encoded with ``ae.env.lossless_state_encoding_mdp`` and only
    element ``[0]`` of the encoding is used (presumably the first player's
    view -- confirm against the environment). The same is done for the MDP's
    standard start state, which serves as the reference that is subtracted.

    Args:
        start_states: Iterable of states accepted by the environment encoder.
        ae: Evaluator whose ``env`` supplies both the encoder and the MDP.

    Returns:
        A NumPy array built from the list of per-state differences, in the
        same order as ``start_states``.
    """
    env = ae.env
    reference_state = env.mdp.get_standard_start_state()
    reference_obs = env.lossless_state_encoding_mdp(reference_state)[0]
    deltas = [
        env.lossless_state_encoding_mdp(state)[0] - reference_obs
        for state in start_states
    ]
    return np.array(deltas)


def process_data_for_kd(rollouts, target_policy: Policy, ae: AgentEvaluator, T=2, slack_T=4, bs=8000, device='cuda:0', verbose=False):
    """Turn rollouts into (observation, teacher output) tensors for distillation.

    All episode states in ``rollouts['ep_states']`` are flattened, encoded with
    ``ae.env.lossless_state_encoding_mdp`` and concatenated along axis 0 (the
    encoder presumably yields one observation per player -- confirm). The
    target policy is then run over the observations in chunks of ``bs`` rows.

    Every observation is processed, including the final chunk when the number
    of observations is not a multiple of ``bs``, so all four returned tensors
    share the same leading dimension.

    Args:
        rollouts: Mapping with an ``'ep_states'`` entry of per-episode states.
        target_policy: Teacher policy; switched to eval mode by this call.
        ae: Evaluator whose ``env`` supplies the state encoder.
        T: Temperature for the main softened action distribution.
        slack_T: Temperature for the second ("slack") action distribution.
        bs: Maximum number of observations per forward pass.
        device: Torch device the observation batches are moved to.
        verbose: If true, print a progress message.

    Returns:
        Tuple ``(obss, act_probs, slack_act_probs, values)`` where ``obss`` is
        a tensor of the encoded observations and the other three are the CPU
        tensors produced by the teacher for those observations.
    """
    if verbose:
        print("processing data")
    states = np.concatenate(rollouts['ep_states'])
    state_fn = ae.env.lossless_state_encoding_mdp
    obss = np.concatenate([state_fn(state) for state in states])

    n_obs = obss.shape[0]
    target_policy.eval()

    values = []
    act_probs = []
    slack_act_probs = []
    # Inference only: no autograd graph is needed for teacher outputs.
    with torch.no_grad():
        # Stepping by bs covers the trailing partial batch as well; slicing
        # past the end of the array simply yields the remaining rows.
        for start in range(0, n_obs, bs):
            in_obss = obss[start:start + bs]
            obs_batch = torch.tensor(in_obss, dtype=torch.float).to(device)
            value, actor_features, _ = target_policy.base(obs_batch, None, None)
            logits = target_policy.dist.linear(actor_features)
            t_probs = F.softmax(logits / T, dim=1)
            slack_t_probs = F.softmax(logits / slack_T, dim=1)

            values.append(value.cpu().detach())
            act_probs.append(t_probs.cpu().detach())
            slack_act_probs.append(slack_t_probs.cpu().detach())

    values = torch.cat(values)
    act_probs = torch.cat(act_probs)
    slack_act_probs = torch.cat(slack_act_probs)
    obss = torch.tensor(obss)
    return obss, act_probs, slack_act_probs, values


def prepare_data(n_games, target_agent, ae, mode='sp', verbose=False):
    """Collect rollouts of ``target_agent`` for later distillation.

    Only self-play (``mode='sp'``) is supported: the agent is paired with
    itself and ``ae.evaluate_agent_pair`` is run for ``n_games`` games.

    Args:
        n_games: Number of games passed to the evaluator as ``num_games``.
        target_agent: Agent used for both seats of the pair.
        ae: Evaluator providing ``evaluate_agent_pair``.
        mode: Pairing mode; anything other than ``'sp'`` is rejected.
        verbose: If true, print a progress message and forward it to the
            evaluator as ``info``.

    Returns:
        Whatever ``ae.evaluate_agent_pair`` returns for the self-play pair.

    Raises:
        NotImplementedError: If ``mode`` is not ``'sp'``.
    """
    if verbose:
        print("collecting data")

    if mode != 'sp':
        raise NotImplementedError

    self_play_pair = AgentPair(target_agent, target_agent, allow_duplicate_agents=True)
    return ae.evaluate_agent_pair(self_play_pair, num_games=n_games, info=verbose)


def stat_combine_groups(nums, means, stds):
    """Merge the size, mean and standard deviation of exactly two groups.

    The pooled standard deviation is the population form: the within-group
    sums of squares plus the between-group term
    ``n1 * n2 / n * (m1 - m2) ** 2``, all divided by the total count.

    Args:
        nums: The two group sizes.
        means: The two group means.
        stds: The two group standard deviations.

    Returns:
        Tuple ``(n, m, s)`` with the combined size, mean and standard
        deviation.

    Raises:
        AssertionError: If ``nums`` does not hold exactly two entries (the
            check is an ``assert`` and is skipped under ``python -O``).
    """
    assert len(nums) == 2, NotImplementedError

    count_a, count_b = nums
    mean_a, mean_b = means
    std_a, std_b = stds

    total = count_a + count_b
    pooled_mean = (mean_a * count_a + mean_b * count_b) / total

    # Within-group contributions followed by the between-group contribution.
    within_a = count_a * np.power(std_a, 2)
    within_b = count_b * np.power(std_b, 2)
    between = (count_a * count_b) / total * np.power(mean_a - mean_b, 2)
    pooled_std = np.sqrt((within_a + within_b + between) / total)

    return total, pooled_mean, pooled_std


def multi_stat_combine_groups(nums, means, stds):
    """Merge the size, mean and standard deviation of one or more groups.

    The first group seeds the running statistics; each further group is
    folded in, left to right, with ``stat_combine_groups``. A single group is
    returned as-is without any combination step.

    Args:
        nums: Group sizes.
        means: Group means, aligned with ``nums``.
        stds: Group standard deviations, aligned with ``nums``.

    Returns:
        Tuple ``(n, m, s)`` with the combined size, mean and standard
        deviation.
    """
    count, mean, std = nums[0], means[0], stds[0]

    for idx in range(1, len(nums)):
        count, mean, std = stat_combine_groups(
            (count, nums[idx]), (mean, means[idx]), (std, stds[idx])
        )

    return count, mean, std
