import numpy as np
import torch
import gym
import time
from .dictlist import DictList
import utils 

def default_preprocess_obss(obss, device=None):
    """
    Turn a batch of raw observations into one ``torch.float`` tensor.

    :param obss: sequence of array-like observations.
    :param device: optional torch device for the resulting tensor.
    :return: ``torch.float`` tensor built from ``np.array(obss)``.
    """
    # Stack through NumPy first; building a tensor straight from a Python
    # list of arrays is slow in PyTorch.
    return torch.tensor(np.array(obss), dtype=torch.float, device=device)

def get_obss_preprocessor(obs_space):
    """
    Build an observation-shape description and a matching batch preprocessor.

    Two kinds of gym observation spaces are supported:
      * ``gym.spaces.Box``: every observation is itself the image;
      * ``gym.spaces.Dict`` with an ``"image"`` entry (e.g. MiniGrid or a
        modified Bullet env): the image is read from ``obs["image"]``.

    :param obs_space: the environment's gym observation space.
    :return: ``(shapes, preprocess_obss)`` where ``shapes`` is
        ``{"image": <image shape>}`` and ``preprocess_obss(obss, device=None)``
        returns a ``DictList`` with an ``"image"`` tensor for the batch.
    :raises ValueError: for any other observation space.
    """
    if isinstance(obs_space, gym.spaces.Box):
        shapes = {"image": obs_space.shape}

        def preprocess_obss(obss, device=None):
            # The whole observation is the image.
            batch = preprocess_images(obss, device=device)
            return DictList({"image": batch})

        return shapes, preprocess_obss

    if isinstance(obs_space, gym.spaces.Dict) and "image" in obs_space.spaces:
        shapes = {"image": obs_space.spaces["image"].shape}

        def preprocess_obss(obss, device=None):
            # Only the "image" entry of each dict observation is kept.
            images = [obs["image"] for obs in obss]
            batch = preprocess_images(images, device=device)
            return DictList({"image": batch})

        return shapes, preprocess_obss

    raise ValueError("Unknown observation space: " + str(obs_space))


def preprocess_images(images, device=None):
    """
    Convert a batch of images into a single ``torch.float`` tensor.

    :param images: sequence of array-likes, one image per element.
    :param device: optional torch device for the resulting tensor.
    :return: ``torch.float`` tensor built from ``np.array(images)``.
    """
    # Work around a PyTorch slowdown: creating a tensor directly from a list
    # of arrays is very slow, so stack into one NumPy array first.
    batch = np.array(images)
    tensor = torch.tensor(batch, dtype=torch.float, device=device)
    return tensor


def flatten_grads(parameters):
    """
    Flatten the gradients of all parameters into a single column vector.

    :param parameters: an iterable (generator or list) of tensors whose
        ``.grad`` has been populated.
    :return: a tuple ``(flat, indices)``:
        * ``flat``: tensor of shape ``[#params, 1]`` holding every gradient,
          concatenated in iteration order;
        * ``indices``: list of ``(start, end)`` pairs, one per parameter, such
          that ``flat[start:end]`` is that parameter's flattened gradient.
          **Note: ``end`` is exclusive.**
    :raises TypeError: if a parameter has no gradient (``p.grad is None``).
    """
    grads = []
    for p in parameters:
        if p.grad is None:
            # Frozen or unused parameters have no gradient; fail with a clear
            # message instead of torch's generic argument-parsing error.
            raise TypeError(
                "flatten_grads: a parameter has no gradient (p.grad is None)"
            )
        # flatten() turns a 0-d gradient into a 1-element vector as well.
        grads.append(torch.flatten(p.grad))

    indices = []
    start = 0
    for g in grads:
        end = start + g.numel()
        indices.append((start, end))
        start = end

    flat = torch.cat(grads).view(-1, 1)
    return flat, indices


def recover_flattened(flat_params, indices, model):
    """
    Split a flattened column vector back into tensors shaped like the model's parameters.

    :param flat_params: tensor of shape ``[#params, 1]`` (e.g. the first value
        returned by ``flatten_grads``).
    :param indices: list of ``(start, end)`` pairs (``end`` exclusive), one per
        parameter, in the same order as ``model.parameters()``.
    :param model: object whose ``parameters()`` supplies the target shapes.
    :return: list where entry ``i`` is ``flat_params[start_i:end_i]`` viewed
        with the shape of the i-th model parameter (views, not copies).
    """
    pieces = [flat_params[s:e] for (s, e) in indices]
    for i, p in enumerate(model.parameters()):
        # Pass the torch.Size itself instead of unpacking it: for a 0-d
        # (scalar) parameter, ``*p.shape`` would call ``view()`` with no
        # arguments, which fails instead of producing a scalar.
        pieces[i] = pieces[i].view(p.shape)
    return pieces


def handle_logs(logs,
                txt_logger,
                tb_writer,
                csv_file,
                csv_logger,
                update,
                start_time,
                num_frames,
                status,
                update_start_time,
                update_end_time):
    """
    Report one training update to the text logger, TensorBoard and the CSV log.

    :param logs: dict of statistics for this update. Keys read here:
        "num_frames", "return_per_episode", "entropy", "critic_loss",
        "actor_loss", "loss", "grad_norm", "epsilon", "mu".
        "return_per_episode" may be 2-D (episodes x return components) or
        1-D (one scalar return per episode, treated as a single component).
    :param txt_logger: logger whose ``info`` receives a one-line summary.
    :param tb_writer: writer; every field is sent via
        ``add_scalar(field, value, num_frames)``.
    :param csv_file: file object backing ``csv_logger``; flushed after writing.
    :param csv_logger: csv writer; the header row is written first when
        ``status["num_frames"] == 0``.
    :param update: index of the current update (logged as-is).
    :param start_time: ``time.time()`` value when training started.
    :param num_frames: total frames so far; also the TensorBoard step.
    :param status: dict; only ``status["num_frames"]`` is read.
    :param update_start_time: time this update began.
    :param update_end_time: time this update ended.
    :return: the "mean" entry of ``utils.synthesize`` for the LAST return
        component (column) of ``logs["return_per_episode"]``.
    :raises ValueError: if ``logs["return_per_episode"]`` has no return
        components (zero columns).
    """
    # Frames processed in this update divided by the update's wall time.
    fps = logs["num_frames"] / (update_end_time - update_start_time)
    duration = int(time.time() - start_time)

    header = ["update", "frames", "FPS", "duration"]
    out_str = "U-{} | F {:06} | FPS {:04.0f} | D {} | "
    data = [update, num_frames, fps, duration]

    returns = np.array(logs["return_per_episode"])
    # Single-objective runs may log one scalar per episode; view that as one
    # column so the per-component loop below covers both layouts.
    if returns.ndim == 1:
        returns = returns.reshape(-1, 1)
    if returns.shape[1] == 0:
        raise ValueError("handle_logs: logs['return_per_episode'] has no return components")

    for i in range(returns.shape[1]):
        # NOTE(review): the four placeholders assume utils.synthesize returns
        # exactly four statistics (presumably mean/std/min/max) — confirm.
        out_str += "RPE:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | "
        stats = utils.synthesize(returns[:, i])
        header += ["RPE" + str(i + 1) + "_" + key for key in stats.keys()]
        data.extend(stats.values())

    header += ["entropy","critic_loss", "actor_loss", "loss", "grad_norm", "epsilon", "mu"]
    data += [logs["entropy"], logs["critic_loss"], logs["actor_loss"], logs["loss"], logs["grad_norm"], logs["epsilon"], logs["mu"]]
    out_str += "H {:.3f} | cL {:.3f} | aL {:.3f} | L {:.3f} | ∇ {:.3f} | eps {:.3f} | mu {:.3f}"

    txt_logger.info(out_str.format(*data))

    # Write tensorboard stuff
    for field, value in zip(header, data):
        tb_writer.add_scalar(field, value, num_frames)

    # write csv; the header goes in only on the very first logged update
    if status["num_frames"] == 0:
        csv_logger.writerow(header)

    csv_logger.writerow(data)
    csv_file.flush()

    # `stats` is bound: the check above guarantees at least one loop pass.
    return stats["mean"]

