import pickle
import gym
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
from PG_utils import set_seed

def flatten(lst):
    """Return a new list holding the items of every sub-iterable in ``lst``, in order.

    Only one level of nesting is removed.
    """
    return [item for sub in lst for item in sub]

def load_result(filename):
    """Unpickle and return the object stored in ``filename``.

    The file is closed even if unpickling raises.

    Security: ``pickle.load`` can execute arbitrary code while loading. Only use this on
    files this project wrote itself (e.g. via ``store_result``), never on untrusted input.
    """
    with open(filename, 'rb') as f:
        return pickle.load(f)

def store_result(obj, filename):
    """Pickle ``obj`` into ``filename``, overwriting any existing file.

    The file is closed even if pickling raises (e.g. for an unpicklable ``obj``).
    """
    with open(filename, 'wb') as f:
        pickle.dump(obj, f)
    
class StreamingDataset:
    """Sequence-like view over per-step pickle files ``<filename>_h-<h>.pickle``.

    Nothing is cached: every ``ds[h]`` re-reads the file for step ``h`` from disk.
    ``h`` is not bounds-checked against ``H``.

    Args:
        filename: Path prefix shared by all per-step files.
        H: Value reported by ``len()``.
        K: When not -1, ``ds[h]`` returns only the first ``K`` items of the loaded
            object; -1 means "return everything".
    """

    def __init__(self, filename, H, K=-1):
        self.filename = filename
        self.H = H
        self.K = K

    def __len__(self):
        # Length comes from the caller-supplied H, not from the files on disk.
        return self.H

    def __getitem__(self, h):
        data = load_result(self.filename + '_h-' + str(h) + '.pickle')
        if self.K != -1:
            data = data[:self.K]
        return data
    
def sample_iid_visual(env, resize, policy_net, eps, K, start_h, end_h, filename, seed_init=1000):
    """Collect K step-``h`` transitions for every ``h`` in ``[start_h, end_h]`` and pickle them.

    For each ``h`` the environment is seeded via ``set_seed(env, seed_init + h)``. Then K
    episodes are rolled out for at most ``h`` steps each.

    Actions come from ``policy_net.get_action(state_inner, eps=eps)``. They use the
    environment's internal observation, not the visual state. The visual state is the
    difference of two consecutive frames from ``get_screen``, so it is a numpy array,
    not a torch tensor.

    Every episode contributes exactly one transition, so each per-``h`` list has K entries:
      * the real transition taken at step ``h``, if the episode was still running; or
      * an absorbing self-loop ``(s, a, s, 0.0)`` if the episode ended before step ``h``.

    Transitions have the form
    ``((state, state_inner), action, (next_state, next_state_inner), reward)``. Here
    ``*_inner`` is whatever ``env.reset()`` / ``env.step()`` return as the observation
    (a 4-dimensional vector per the original note; presumably CartPole — not checked here).

    Each step's list is written to ``f"{filename}_h-{h}.pickle"``. This is the naming
    scheme ``StreamingDataset`` reads.

    NOTE(review): with ``h < 1`` the rollout loop never runs and an empty list is stored.
    """
    for h in range(start_h, end_h + 1):
        set_seed(env, seed_init + h)
        d_h = []
        for _ in range(K):
            state_inner = env.reset()
            # No env step happens between these two grabs. The frames are presumably
            # identical, so the initial visual state is an all-zero difference.
            last_screen = get_screen(env, resize)
            current_screen = get_screen(env, resize)
            state = current_screen - last_screen
            done, h_curr = False, 1
            while (not done) and (h_curr <= h):
                action = policy_net.get_action(state_inner, eps=eps)
                next_state_inner, reward, done, _ = env.step(action)
                # Visual next state = difference of the two most recent frames
                last_screen = current_screen
                current_screen = get_screen(env, resize)
                next_state = current_screen - last_screen
                # Only the transition taken at step h is kept
                if h_curr == h:
                    d_h.append(((state, state_inner), action, (next_state, next_state_inner), reward))
                state = next_state
                state_inner = next_state_inner
                h_curr += 1
            if done and (h_curr <= h):
                # The episode terminated before step h: record an absorbing self-loop with
                # reward 0. The action is still sampled from the policy for consistency.
                action = policy_net.get_action(state_inner, eps=eps)
                d_h.append(((state, state_inner), action, (state, state_inner), 0.0))
        store_result(d_h, f"{filename}_h-{h}.pickle")

# Modified from https://github.com/pytorch/tutorials/blob/master/intermediate_source/reinforcement_q_learning.py
def get_screen(env, resize, top=0.4, bottom=0.8):
    """Render ``env``, crop a horizontal band, rescale to [0, 1], resize, return numpy.

    Args:
        env: Environment whose ``render(mode='rgb_array')`` returns an H x W x C image
            (presumably uint8 RGB, given the division by 255).
        resize: Callable applied to the cropped C x H' x W float32 torch tensor. Its
            result must support ``.numpy()``.
        top, bottom: Crop bounds as fractions of the image height. The defaults
            (0.4, 0.8) keep the band where the CartPole cart sits.

    Returns:
        The numpy array produced by ``resize(...).numpy()``. No batch dimension is added.
    """
    # HWC -> CHW
    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
    screen_height = screen.shape[1]
    screen = screen[:, int(screen_height * top):int(screen_height * bottom)]

    # Contiguous float32 copy scaled to [0, 1], then hand it to torch for resizing
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    return resize(torch.from_numpy(screen)).numpy()
    
def resize_dataset(resize, d_h):
    """Return a copy of ``d_h`` with ``resize`` applied to both visual states of each transition.

    A transition is indexed as ``((state, state_inner), action, (next_state, next_state_inner), reward)``.
    Each visual array goes numpy -> ``torch.from_numpy`` -> ``resize`` -> ``.numpy()``.
    Inner states, actions and rewards are carried over untouched. ``d_h`` itself is not
    modified.
    """
    def _resize_np(arr):
        return resize(torch.from_numpy(arr)).numpy()

    return [
        ((_resize_np(t[0][0]), t[0][1]), t[1], (_resize_np(t[2][0]), t[2][1]), t[3])
        for t in d_h
    ]
    