import torch
import pyspiel
import numpy as np
from bomb.CKA import bc_model_cka, mb_model_cka, get_similarity
from network.ensemble_mb_bc import mix_policy
from open_spiel.python import policy
from open_spiel.python.algorithms import exploitability


class alphamodel(torch.nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Linear -> sigmoid.

    Because of the final sigmoid every output value lies in [0, 1].
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Attribute names determine the state_dict keys ("state.*", "out.*"),
        # and creation order determines random initialisation; keep both.
        self.state = torch.nn.Linear(input_dim, hidden_dim)
        self.out = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, state):
        """Map inputs whose last dimension is ``input_dim`` to ``output_dim`` values."""
        hidden = self.state(state).relu()
        return self.out(hidden).sigmoid()

class deep_cfr_policy:
    """Wrap a torch policy network so it yields action probabilities.

    The network output is passed through a softmax over its last dimension,
    i.e. it is treated as unnormalised per-action scores.
    """

    def __init__(self, policy_model, device):
        """
        Args:
            policy_model: torch.nn.Module mapping information-state tensors to
                per-action scores; it is moved to ``device``.
            device: torch device (or device string) used for inference.
        """
        self._policy = policy_model.to(device)
        self._softmax_layer = torch.nn.Softmax(dim=-1).to(device)
        self.device = device

    def __call__(self, state):
        """Return the action distribution for a single game state.

        ``state`` must provide ``information_state_tensor()`` (presumably a
        ``pyspiel.State``). The vector is converted to float32 and evaluated
        without gradient tracking.

        Returns:
            numpy array of action probabilities with the batch dimension removed.
        """
        info_state_vector = np.array(state.information_state_tensor())
        if info_state_vector.ndim == 1:
            # Add a batch dimension so the network sees shape (1, features).
            info_state_vector = np.expand_dims(info_state_vector, axis=0)
        info_state_vector = torch.FloatTensor(info_state_vector).to(self.device)

        # step() runs under torch.no_grad(), so no autograd graph is built for
        # what is a pure inference call.
        strategy = self.step(info_state_vector)
        return strategy.cpu().squeeze(0).numpy()

    def step(self, state_vector):
        """Return softmax probabilities for an already-prepared tensor batch.

        Args:
            state_vector: tensor on ``self.device`` accepted by the wrapped network.

        Returns:
            Tensor of probabilities over the last dimension, with no grad tracking.
        """
        with torch.no_grad():
            strategy = self._policy(state_vector)
            strategy = self._softmax_layer(strategy)
        return strategy

def tabular_policy_from_callable(game, behavior_policy, device, players=None):
    """Materialise a tensor-based policy into an OpenSpiel ``TabularPolicy``.

    For every state enumerated by the tabular policy, the state's information
    tensor is batched, moved to ``device`` and passed to
    ``behavior_policy.step``. The probabilities of the legal actions are copied
    into the state's row of ``action_probability_array``; illegal actions get 0.

    Args:
        game: the pyspiel game.
        behavior_policy: object exposing ``step(tensor)`` that returns per-action
            probabilities for a batch of one.
        device: torch device for the input tensors.
        players: forwarded to ``policy.TabularPolicy``.

    Returns:
        The filled ``policy.TabularPolicy``.
    """
    result = policy.TabularPolicy(game, players)
    n_actions = game.num_distinct_actions()

    for row_idx, game_state in enumerate(result.states):
        legal = game_state.legal_actions(game_state.current_player())

        features = np.array(game_state.information_state_tensor())
        if features.ndim == 1:
            features = features[np.newaxis, ...]
        features = torch.FloatTensor(features).to(device)

        probs = behavior_policy.step(features).squeeze(0).tolist()
        legal_probs = {a: probs[a] for a in legal}

        result.action_probability_array[row_idx, :] = [
            legal_probs.get(a, 0.) for a in range(n_actions)
        ]
    return result

def select_alpha(game, bc_location, mb_location, device,
                 weights_list=None, num_proportions=11):
    """Grid-search the BC mixing weight for each dataset proportion.

    For every proportion index ``p`` in ``range(num_proportions)``, this function:

    - loads the BC model at ``bc_location.format(p)``;
    - loads the model-based network at ``mb_location.format(p)``, wrapped in
      ``deep_cfr_policy``;
    - blends them with ``mix_policy`` for each candidate weight;
    - converts each blend to a tabular policy;
    - scores it with ``exploitability.nash_conv``.

    Args:
        game: the pyspiel game both policies act in.
        bc_location: path template for BC checkpoints (one ``{}`` placeholder).
        mb_location: path template for model-based checkpoints (one ``{}``).
        device: torch device used for loading and inference.
        weights_list: candidate ``bc_weight`` values; defaults to 0..1.0 in
            steps of 0.1.
        num_proportions: number of checkpoint indices to evaluate (default 11).

    Returns:
        ``(results, min_weights)``:

        - ``results[p]`` is the lowest NashConv found for proportion ``p``.
        - ``min_weights`` is the weight achieving that minimum for the *last*
          proportion evaluated only. It is ``None`` if ``num_proportions`` is 0.

    Raises:
        ValueError: if ``weights_list`` is empty.
    """
    if weights_list is None:
        weights_list = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    if not weights_list:
        raise ValueError("weights_list must contain at least one weight")

    results = []
    min_weights = None
    for proportion in range(num_proportions):
        # map_location lets checkpoints saved on CUDA load on CPU-only hosts.
        bc_model = torch.load(
            bc_location.format(proportion), map_location=device).to(device)
        mb_network = torch.load(
            mb_location.format(proportion), map_location=device).to(device)
        mb_model = deep_cfr_policy(mb_network, device=device)

        conv_list = []
        for w in weights_list:
            mixed_policy = mix_policy(bc_model, mb_model, bc_weight=w)
            tabular = tabular_policy_from_callable(game, mixed_policy, device)
            conv_list.append(exploitability.nash_conv(game, tabular))

        best_conv = min(conv_list)
        results.append(best_conv)
        # NOTE(review): overwritten on every iteration, so only the last
        # proportion's best weight is returned, while `results` holds NashConv
        # values rather than weights. Confirm this is what callers need.
        min_weights = weights_list[conv_list.index(best_conv)]

    return results, min_weights

# Settings for the game whose policies are used to train the alpha predictor.
game_name = "kuhn_poker"
n_players = 2
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

game = pyspiel.load_game(game_name, {"players": n_players})

# Checkpoint path templates; "{}" is filled with the dataset-proportion index.
bc_location = "./mix_dataset_bc_policy/kuhn_poker_2_players/policy_proportion_{}/seed_1_hidden_layer_32_buffer_1000_lr_0.05_train_epoch_1000_batch_size_32_policy.pkl"
mb_location = "./mix_dataset_mb_policy/kuhn_poker_2_players/train_data_1000/seed_2_policy_train_epoch_2000_proportion_{}.pkl"

# Compute the CKA similarity between the BC and model-based checkpoints for
# each dataset-proportion index 0..10.
similarity = []
for i in range(11):
    bc_model_1 = torch.load(bc_location.format(i), map_location=device)
    mb_model_1 = torch.load(mb_location.format(i), map_location=device)

    # Presumably wraps each model for CKA feature extraction (see bomb.CKA).
    mb_model_1 = mb_model_cka(mb_model_1)
    bc_model_1 = bc_model_cka(bc_model_1)

    similarity_value = get_similarity(game, bc_model_1, mb_model_1)
    similarity.append(similarity_value)

# Grid-search the BC mixing weight for every dataset proportion.
# NOTE(review): select_alpha()[0] holds the *minimum NashConv* per proportion,
# not the best mixing weight, although the model name suggests alpha is the
# target. The sigmoid output layer also bounds predictions to [0, 1].
# Confirm which quantity the predictor is meant to learn.
alpha_results = select_alpha(game, bc_location, mb_location, device)
x = similarity
y = alpha_results[0]

# Train the predictor: similarity features -> target value.
# NOTE(review): input width 4 assumes each get_similarity() result has
# 4 entries — confirm against bomb.CKA.
num_epochs = 5000
log_every = 500
model = alphamodel(4, 64, 1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_func = torch.nn.MSELoss()

x_input = torch.FloatTensor(x).to(device)
# Shape (len(y), 1) to match the model's single output column.
y_input = torch.FloatTensor(y).unsqueeze(1).to(device)

for epoch in range(num_epochs):
    prediction = model(x_input)
    loss = loss_func(prediction, y_input)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Log a plain float periodically instead of dumping the full tensor
    # repr on every one of the epochs.
    if epoch % log_every == 0 or epoch == num_epochs - 1:
        print("epoch {}: loss {:.6f}".format(epoch, loss.item()))

torch.save(model, "./{}_{}_player_alpha_model.pkl".format(game_name, n_players))