import inspect

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms as transforms


# Shared ReLU module. run_k_of_n applies it to the cumulative regret matrix to
# keep only the positive part of the regrets before normalising them into a
# policy.
relu = nn.ReLU()


def sort_and_k_least(array, k_least):
    '''
    Find the positions of the ``k_least`` smallest entries of a tensor.

    Arguments:
    array: (Tensor) values to rank; ranking is along the last dimension
    k_least: (int) how many of the smallest entries to keep

    Returns:
    numpy array: indices of the smallest entries, ordered from smallest upward
    (all indices if ``k_least`` exceeds the number of entries)
    '''
    ascending_positions = torch.argsort(array, descending=False)
    smallest_positions = ascending_positions[:k_least]
    return smallest_positions.numpy()


def _load_ensemble_member(path, device):
    '''
    Load one saved ensemble model from ``path`` and move it to ``device``.

    The loaded object is called directly as a network, so the files appear to
    hold whole pickled ``nn.Module`` objects. Those cannot be read with
    ``weights_only=True``, which is the ``torch.load`` default since PyTorch
    2.6, so it is disabled explicitly when the installed torch supports the
    flag. ``map_location`` lets models saved on a GPU load on a CPU-only host.
    Unpickling can execute arbitrary code: only point this at trusted files.
    '''
    load_kwargs = {"map_location": device}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    return torch.load(path, **load_kwargs).to(device)


def sample_rewards_from_ensemble(method, n_samples, states, actions, models_numbers, models_dir, batch_size, device):
    '''
    Sample reward functions from a deep ensemble and calculate the expected value of the policy.

    ``n_samples`` model numbers are drawn uniformly *with replacement* from
    ``models_numbers``. Each model is loaded from ``models_dir`` and evaluated
    on ``states`` in batches, in eval mode and without building autograd graphs.

    Arguments:
    method: (str) "test" to play k-of-N on a sample by sample basis, "TEST" to play k-of-N for full dataset
    n_samples: (int) number of samples
    states: (Tensor) states on which the policy will be evaluated
    actions:(Tensor) policy, shape (states_size, action_size)
    models_numbers:(list or 1d numpy array) model number pool from which to sample
    models_dir: (str) directory path where trained models are saved
    batch_size: (int) batch size
    device: (str) device "cuda:0" or "cuda:1" or "cpu"

    Returns:
    Tensor shape (n_samples, ) for "TEST", (n_samples, states_size) otherwise:
        expected value for each sample (per state for "test")
    Tensor shape (n_samples, states_size, action_size): reward for each sample and each state
    1d numpy array: sampled model numbers
    1d numpy array: model numbers not sampled (sorted, unique)

    Raises:
    ValueError: if samples are requested but ``models_numbers`` is empty.
    '''
    if n_samples > 0 and len(models_numbers) == 0:
        raise ValueError("models_numbers is empty: no ensemble models left to sample from")

    ms = np.random.choice(models_numbers, n_samples)
    n_states = states.shape[0]
    n_rs = torch.zeros((n_samples, n_states, actions.shape[1]))

    full_dataset = method == "TEST"
    if full_dataset:
        estimated_rewards = torch.zeros((n_samples,))
    else:
        estimated_rewards = torch.zeros((n_samples, n_states))

    for s, model_number in enumerate(ms):
        model = _load_ensemble_member(models_dir + "/ensemble_model_{}".format(model_number), device)
        # Inference only: disable dropout/batch-norm updates and skip autograd.
        model.eval()
        with torch.no_grad():
            for start in range(0, n_states, batch_size):
                stop = start + batch_size
                n_rs[s, start:stop] = model(states[start:stop].to(device)).cpu()
        weighted = n_rs[s] * actions
        if full_dataset:
            estimated_rewards[s] = torch.sum(weighted).item()
        else:
            estimated_rewards[s] = torch.sum(weighted, 1)
    return estimated_rewards, n_rs, ms, np.setdiff1d(models_numbers, ms)


def check_dataset(dataset):
    """
    Load the test split of a supported dataset as a single image tensor.

    The dataset is downloaded into ./datasets if it is not already present.

    Arguments:
    dataset: (str) "MNIST", "E-MNIST" (letters split) or "MNIST-Fashion"

    Returns:
    state: (tensor) shape (n_images, 1, 28, 28), pixel values as produced by
        transforms.ToTensor()

    Raises:
    ValueError: if ``dataset`` is not one of the supported names.
    """
    # Validate first so an unsupported name fails before any download work.
    if dataset not in ("MNIST", "E-MNIST", "MNIST-Fashion"):
        raise ValueError("dataset should be MNIST, E-MNIST or MNIST-Fashion")

    transform = transforms.ToTensor()
    if dataset == "MNIST":
        test_set = torchvision.datasets.MNIST('datasets', train=False, download=True, transform=transform)
    elif dataset == "E-MNIST":
        test_set = torchvision.datasets.EMNIST(root="datasets", train=False, transform=transform,
                                               target_transform=None, download=True, split="letters")
    else:
        test_set = torchvision.datasets.FashionMNIST(root="datasets", train=False,
                                                     transform=transform, target_transform=None, download=True)

    # E-MNIST images get their two spatial axes swapped, presumably so they
    # share MNIST's orientation -- NOTE(review): confirm against the data.
    transpose_images = dataset == "E-MNIST"

    state = torch.zeros((len(test_set), 1, 28, 28))
    for i in range(len(test_set)):
        img = test_set[i][0]
        if transpose_images:
            img = torch.transpose(img, 1, -1)
        state[i] = img.reshape(1, 28, 28)
    return state


# File-name tag used for each supported dataset when saving policies.
_POLICY_FILE_TAGS = {"MNIST": "mnist", "E-MNIST": "emnist", "MNIST-Fashion": "fashion"}


def _regret_matching(total_regret):
    '''
    Turn a (n_states, n_actions) cumulative-regret matrix into a policy.

    Each row is the positive part of its regrets normalised to sum to one.
    Rows with no positive regret fall back to the uniform policy, which
    avoids the division by zero that previously produced NaNs.
    '''
    positive = relu(total_regret)
    norm = positive.sum(1, keepdim=True)
    uniform = torch.full_like(positive, 1.0 / positive.shape[1])
    return torch.where(norm > 0, positive / norm, uniform)


def _save_progress(output_policies_dir, tag, run, k, n, n_itr, actions, expected_value):
    '''Write the current policy and the expected-value history of one run.'''
    suffix = "{}-of-{}_n_itr_{}".format(k, n, n_itr)
    np.save(output_policies_dir + "/run_{}_{}_actions_{}".format(run, tag, suffix), actions.numpy())
    np.save(output_policies_dir + "/run_{}_expected_value_{}_{}".format(run, tag, suffix), expected_value)


def run_k_of_n(ks, ns, n_runs, n_itr, method, n_models, batch_size, models_dir,
               output_policies_dir, device, dataset, n_actions):
    '''
    Run k-of-N CFR.

    For each (k, n) pair and each run, start from the uniform policy. Each
    iteration samples n reward models from the ensemble and averages the k
    with the lowest estimated value. The resulting regrets are accumulated,
    and the policy is updated by regret matching. The policy and the
    expected-value history are saved after every iteration, overwriting the
    same files, so a partially finished run still leaves results on disk.

    Arguments:
    ks: (list) k values
    ns: (list) n values, paired index-wise with ks
    n_runs: (int) the number of times to run k-of-N CFR
    n_itr: (int) the number of k-of-N CFR iterations
    method: (str) "test" to play k-of-N on a sample by sample basis, "TEST" play k-of-N for full dataset
    n_models: (int) number of models in the ensemble
    batch_size: (int) batch size
    models_dir: (str) directory path where deep models have been saved
    output_policies_dir: (str) directory path where you want to save k-of-N policies
    device: (str) device "cuda:0" or "cuda:1" or "cpu"
    dataset: (str) "MNIST", "E-MNIST" or "MNIST-Fashion"
    n_actions: (int) the number of actions

    Raises:
    ValueError: if ``dataset`` is not one of the supported names.
    '''
    if dataset not in _POLICY_FILE_TAGS:
        raise ValueError("dataset should be MNIST, E-MNIST or MNIST-Fashion")
    tag = _POLICY_FILE_TAGS[dataset]

    state = check_dataset(dataset)
    n_states = state.shape[0]
    for i, k in enumerate(ks):
        n = ns[i]
        for run in range(n_runs):
            expected_value = np.zeros(n_itr)
            # Uniform initial policy over actions for every state.
            actions = torch.softmax(torch.ones((n_states, n_actions), device="cpu"), dim=1)
            total_regret = torch.zeros_like(actions)
            # `np.int` was removed in NumPy 1.24; it was an alias of builtin int.
            models_numbers = np.arange(0, n_models, 1, dtype=int)
            for itr in range(n_itr):
                # NOTE(review): the pool is replaced by the models *not* drawn in
                # this iteration, so models are never reused across iterations
                # and the pool shrinks. sample_rewards_from_ensemble errors out
                # once the pool is empty -- confirm n_itr * n fits n_models.
                n_estimated_rewards, n_rs, _, models_numbers = sample_rewards_from_ensemble(
                    method, n, state, actions, models_numbers, models_dir, batch_size, device)

                if method == "TEST":
                    k_index = sort_and_k_least(n_estimated_rewards, k)
                    expected_value[itr] = n_estimated_rewards[k_index].sum()
                    mean_rs = torch.mean(n_rs[k_index], 0)
                    # Instantaneous regret: each action's reward minus the
                    # policy's expected reward in that state.
                    total_regret += mean_rs - torch.sum(actions * mean_rs, 1, keepdim=True)
                else:
                    for s in range(n_states):
                        k_index = sort_and_k_least(n_estimated_rewards[:, s], k)
                        expected_value[itr] += n_estimated_rewards[k_index, s].sum()
                        mean_rs = torch.mean(n_rs[k_index, s], 0)
                        total_regret[s] += mean_rs - torch.sum(mean_rs * actions[s])

                actions = _regret_matching(total_regret)

                print("{}-of-{} run no {} iteration number {}".format(k, n, run, itr))

                _save_progress(output_policies_dir, tag, run, k, n, n_itr, actions, expected_value)
