import gym
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from NN import *
from FQE_utils import get_screen, resize_dataset
    
def integrateV(policy_net, Q, s, s_internal):
    """
        Compute the integral term used in Q's regression targets.

        Multiplies policy_net.get_pi_vec(s_internal) element-wise with
        Q.get_integral_vecs(s) and sums each row (dim 1). This is presumably
        sum_a pi(a|s) * Q(s, a), the policy-weighted value of each state.

        Both network queries run under torch.no_grad().
        Output: a numpy vector with one entry per row of the inputs.
    """
    with torch.no_grad():
        action_weights = policy_net.get_pi_vec(s_internal)
        action_values = Q.get_integral_vecs(s)
    weighted = action_weights * action_values
    return weighted.sum(dim=1).numpy()

def train(net, trainloader, optimizer, max_epoch=20):
    """
        Train an NN by regressing the outputs of the taken actions onto labels.

        Each batch yielded by `trainloader` must be an (inputs, actions, labels)
        triple:
          - `net(inputs)` returns a tensor whose dim 1 indexes actions.
          - `actions` is an integer index tensor accepted by
            `Tensor.gather(1, actions)`. FQEV builds it with shape (N, 1).
          - `labels` are the regression targets for the gathered outputs.

        Only the gathered (chosen-action) outputs contribute to the MSE loss.

        Args:
            net: torch module to train; its parameters are updated in place.
            trainloader: iterable of (inputs, actions, labels) batches.
            optimizer: optimizer over `net`'s parameters.
            max_epoch: number of full passes over `trainloader`.

        Returns:
            `net` itself (the same object that was passed in).
    """
    # Use torch.nn explicitly instead of relying on a star import to provide `nn`.
    criterion = torch.nn.MSELoss()
    for _ in range(max_epoch):  # loop over the dataset multiple times
        for inputs, actions, labels in trainloader:
            optimizer.zero_grad()

            # forward + backward + optimize
            # DataLoader's default collate turns numpy arrays into tensors.
            outputs = net(inputs).gather(1, actions)
            # Labels assembled from numpy can be float64 while the net emits
            # float32. MSELoss fails in backward on mixed dtypes, so match the
            # labels to the output dtype.
            loss = criterion(outputs, labels.to(outputs.dtype))
            loss.backward()
            optimizer.step()
    return net
    
def FQEV(env, D, target_policy_net, resize, screen_height, screen_width, Q_training_param, H, Q=None, sample_s=1000):
    """
        Fitted Q Evaluation of `target_policy_net` over a horizon of H steps.

        The fit runs backwards from h = H down to h = 1. At each step:
          - the regression targets are
            reward + integrateV(target_policy_net, Q, next_state, next_state_internal),
            computed with the Q fitted at the previous (later) step;
          - a fresh QVNet is warm-started from that Q's weights and trained
            on those targets.

        Args:
            env: gym-style environment. This function reads `action_space.n`
                and, when sample_s > 0, calls `reset()` and `get_screen(env, ...)`.
            D: indexed by h for h in H, H-1, ..., 1. Each D[h] is passed through
                `resize_dataset(resize, ...)`. Each transition e is read as:
                e[0][0] = state, e[1] = action,
                e[2] = (next_state, next_state_internal), e[3] = reward.
            target_policy_net: the evaluated policy; used via integrateV.
            resize: forwarded to resize_dataset and get_screen.
            screen_height, screen_width: QVNet constructor arguments.
            Q_training_param: dict with keys 'max_epoch', 'lr', 'batch_size'.
            H: horizon, i.e. the number of backward regression steps.
            Q: optional initial Q-network. A new QVNet is built when None.
                NOTE(review): at h = H this network supplies the bootstrap
                value, so it should presumably evaluate to ~0 beyond the
                horizon. Confirm how QVNet is initialized.
            sample_s: number of env.reset() start states used to estimate the
                policy value; a value <= 0 skips the estimate.

        Returns:
            (result, Q): `result` is the mean of integrateV over the sampled
            start states (None when sample_s <= 0); `Q` is the last fitted net.
    """
    Q_max_epoch = Q_training_param['max_epoch']
    Q_lr = Q_training_param['lr']
    Q_batch_size = Q_training_param['batch_size']
    nA = env.action_space.n

    # initialize nets
    if Q is None:
        Q = QVNet(screen_height, screen_width, nA)

    # Backward recursion: Q_h is regressed onto r + V_{h+1} computed from Q_{h+1}.
    for h in range(H, 0, -1):
        d_h = resize_dataset(resize, D[h])
        sample_set = np.array([e[0][0] for e in d_h])
        actions_set = np.array([[e[1]] for e in d_h])  # (N, 1) indices for gather
        next_state_lst = np.array([e[2][0] for e in d_h])
        next_state_internal_lst = np.array([e[2][1] for e in d_h])
        reward_vec = np.array([e[3] for e in d_h])

        Q_label_set = reward_vec + integrateV(target_policy_net, Q, next_state_lst, next_state_internal_lst)
        Q_label_set = np.expand_dims(Q_label_set, axis=1)  # match the gathered (N, 1) outputs

        training_data = CustomDataset(sample_set, actions_set, Q_label_set)
        trainloader = DataLoader(training_data, batch_size=Q_batch_size, shuffle=True)

        # Warm-start a new network from the previous step's weights.
        new_Q = QVNet(screen_height, screen_width, nA)
        new_Q.load_state_dict(Q.state_dict())
        Q_optimizer = optim.SGD(new_Q.parameters(), lr=Q_lr, momentum=0.9)
        Q = train(new_Q, trainloader, Q_optimizer, max_epoch=Q_max_epoch)
        del d_h  # drop the resized dataset before the next step / sampling phase

    result = None
    if sample_s > 0:
        sampled_s_lst, sampled_s_internal_lst = [], []
        for _ in range(sample_s):
            state_inner = env.reset()
            # NOTE(review): both screens are grabbed right after reset. Unless
            # get_screen advances the env, this difference state is likely
            # zero. Confirm it matches how the states in D were built.
            last_screen = get_screen(env, resize)
            current_screen = get_screen(env, resize)
            sampled_s_lst.append(current_screen - last_screen)
            sampled_s_internal_lst.append(state_inner)
        sampled_s_lst = np.array(sampled_s_lst)
        sampled_s_internal_lst = np.array(sampled_s_internal_lst)
        # Monte-Carlo average of the start-state value under the fitted Q.
        result = np.mean(integrateV(target_policy_net, Q, sampled_s_lst, sampled_s_internal_lst))
    return result, Q
    
