import torch
import numpy as np
from torch.distributions import Categorical
from torch.autograd import Variable
import scipy.signal
import gym
from gym.envs.classic_control.cartpole import CartPoleEnv
import pygame
from pygame import gfxdraw


class CartPoleEnvR(CartPoleEnv):
    """CartPoleEnv subclass that draws the environment itself with pygame.

    The pygame window and clock are created lazily on the first call to
    ``render``; ``__init__`` only sets the corresponding attributes to None.
    """

    def __init__(self):
        """Initialise the parent environment and the (not yet created) render state."""
        super().__init__()
        # Created on demand inside render().
        self.screen = None
        self.clock = None
        # Returned by render() for every mode other than "rgb_array".
        self.isopen = True

    def render(self, mode="human"):
        """Draw the current state onto a 600x400 pygame surface.

        Args:
            mode: "human" additionally pumps pygame events, throttles to
                ``self.metadata["render_fps"]`` and flips the display;
                "rgb_array" returns the screen pixels as a numpy array.

        Returns:
            None if ``self.state`` is None; for "rgb_array" the screen pixel
            array with its first two axes swapped; otherwise ``self.isopen``.
        """
        screen_width = 600
        screen_height = 400

        # Map the world x-range [-x_threshold, x_threshold] onto the screen width.
        world_width = self.x_threshold * 2
        scale = screen_width / world_width
        polewidth = 10.0
        polelen = scale * (2 * self.length)  # drawn pole length in pixels
        cartwidth = 50.0
        cartheight = 30.0

        # Nothing to draw before a state exists.
        if self.state is None:
            return None

        # Only x[0] (horizontal position) and x[2] (pole angle, radians) are used below.
        x = self.state

        if self.screen is None:
            pygame.init()
            pygame.display.init()
            self.screen = pygame.display.set_mode((screen_width, screen_height))
        if self.clock is None:
            self.clock = pygame.time.Clock()

        # Fresh white canvas every frame; it is flipped and blitted to the screen at the end.
        self.surf = pygame.Surface((screen_width, screen_height))
        self.surf.fill((255, 255, 255))

        # Cart rectangle, expressed relative to its own centre.
        l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
        axleoffset = cartheight / 4.0
        cartx = x[0] * scale + screen_width / 2.0  # horizontal centre of the cart
        carty = 100  # vertical centre of the cart rectangle; the track line is drawn at this y too
        cart_coords = [(l, b), (l, t), (r, t), (r, b)]
        cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords]
        gfxdraw.aapolygon(self.surf, cart_coords, (0, 0, 0))
        gfxdraw.filled_polygon(self.surf, cart_coords, (0, 0, 0))

        # Pole rectangle, expressed relative to the axle before rotation.
        l, r, t, b = (
            -polewidth / 2,
            polewidth / 2,
            polelen - polewidth / 2,
            -polewidth / 2,
        )

        pole_coords = []
        for coord in [(l, b), (l, t), (r, t), (r, b)]:
            # Rotate each corner by -x[2] radians about the axle, then move it to the axle position.
            coord = pygame.math.Vector2(coord).rotate_rad(-x[2])
            coord = (coord[0] + cartx, coord[1] + carty + axleoffset)
            pole_coords.append(coord)
        gfxdraw.aapolygon(self.surf, pole_coords, (202, 152, 101))
        gfxdraw.filled_polygon(self.surf, pole_coords, (202, 152, 101))

        # Axle: anti-aliased outline plus filled disc at the pole's pivot point.
        gfxdraw.aacircle(
            self.surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )
        gfxdraw.filled_circle(
            self.surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )

        # Horizontal track line across the full width at y == carty.
        gfxdraw.hline(self.surf, 0, screen_width, carty, (0, 0, 0))

        # Flip vertically before blitting; presumably because the shapes above
        # were laid out with y increasing upwards -- TODO confirm.
        self.surf = pygame.transform.flip(self.surf, False, True)
        self.screen.blit(self.surf, (0, 0))
        if mode == "human":
            pygame.event.pump()
            self.clock.tick(self.metadata["render_fps"])
            pygame.display.flip()

        if mode == "rgb_array":
            # Swap the first two axes of the pixel array (presumably
            # width-major -> height-major image layout; verify against pygame docs).
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
            )
        else:
            return self.isopen

def set_seed(env, seed):
    """Seed torch, numpy's global RNG and the environment with ``seed``.

    Args:
        env: Object exposing a ``seed(seed)`` method.
        seed: Value passed unchanged to all three seeders.
    """
    # Library-level generators first, then the environment itself.
    for apply_seed in (torch.manual_seed, np.random.seed):
        apply_seed(seed)
    env.seed(seed)


def get_flat_params_from(model):
    """Return all of ``model``'s parameters concatenated into one 1-D tensor.

    Parameters are taken in the order yielded by ``model.parameters()`` and
    each one is flattened with ``view(-1)`` before concatenation.
    """
    return torch.cat([p.data.view(-1) for p in model.parameters()])


def set_flat_params_to(model, flat_params):
    """Copy consecutive slices of ``flat_params`` into ``model``'s parameters.

    Args:
        model: Object whose ``parameters()`` yields the tensors to overwrite.
        flat_params: 1-D tensor holding the new values, laid out in the order
            of ``model.parameters()``.
    """
    offset = 0
    for param in model.parameters():
        count = param.numel()
        chunk = flat_params[offset:offset + count]
        # Restore the parameter's own shape before the in-place copy.
        param.data.copy_(chunk.view(param.size()))
        offset += count


def get_flat_grads_from(model):
    """Return all parameter gradients of ``model`` as one flat numpy array.

    Gradients are taken in the order yielded by ``model.parameters()``. A
    parameter whose ``grad`` is None (for example one that did not take part
    in the last backward pass) contributes zeros of its own size, so the
    result always has one entry per parameter element.

    Returns:
        1-D ``numpy.ndarray`` of concatenated gradients.
    """
    grads = []
    for param in model.parameters():
        if param.grad is None:
            # Zero-fill so positions stay aligned with the parameter layout
            # instead of failing with AttributeError on ``None.data``.
            grads.append(torch.zeros_like(param.data).view(-1))
        else:
            # reshape (not view) also copes with non-contiguous gradients.
            grads.append(param.grad.data.reshape(-1))
    flat_grads = torch.cat(grads)
    return flat_grads.numpy()


def print_score(episode_durations):
    """Print and return the mean of the most recent 100 episode durations.

    Nothing is printed while fewer than 100 durations are available.

    Args:
        episode_durations: Sequence of numeric episode durations.

    Returns:
        The printed mean as a ``numpy.float32``, or None when fewer than
        100 durations were supplied.
    """
    window = 100
    durations_t = torch.FloatTensor(episode_durations)
    if len(durations_t) < window:
        return None
    # Only the latest window is ever reported, so average just that slice
    # rather than every sliding window.
    latest_mean = np.float32(durations_t[-window:].mean().item())
    print(latest_mean)
    return latest_mean


def discount(x, gamma):
    """Return the discounted cumulative sums of ``x``.

    Element ``i`` of the result is ``x[i] + gamma * x[i+1] + gamma**2 * x[i+2] + ...``.
    Implemented by running a first-order linear filter over the reversed
    sequence and reversing the output again.
    """
    reversed_input = x[::-1]
    filtered = scipy.signal.lfilter([1.0], [1.0, -gamma], reversed_input)
    return filtered[::-1]


def get_cumu_discounted_rewards(rewards, gamma):
    """Return the discounted return from the first step of every episode.

    Args:
        rewards: Iterable of per-episode reward sequences.
        gamma: Discount factor forwarded to ``discount``.

    Returns:
        List with one value per episode: element 0 of that episode's
        discounted cumulative sums.
    """
    return [discount(epi_rewards, gamma)[0] for epi_rewards in rewards]


def calculate_scores(rewards, gamma):
    """Return per-step discounted returns, each further scaled by ``gamma**step``.

    Args:
        rewards: Iterable of per-episode reward sequences.
        gamma: Discount factor used both for ``discount`` and for the
            additional ``gamma**step`` weighting.

    Returns:
        List of lists; entry ``[e][t]`` is ``discount(rewards[e], gamma)[t] * gamma**t``.
    """
    scores = []
    for epi_rewards in rewards:
        returns = discount(epi_rewards, gamma)
        scores.append([ret * gamma**step for step, ret in enumerate(returns)])
    return scores


def process(data, length=1):
    """Concatenate per-episode data and return it as a 2-D array.

    Args:
        data: Iterable of per-episode iterables.
        length: Size of the second axis of the returned array.

    Returns:
        ``numpy.ndarray`` of shape ``(total_items, length)``.
    """
    flattened = [item for epi_data in data for item in epi_data]
    stacked = np.array(flattened)
    # The row count is fixed to the number of collected items, so a
    # mismatching ``length`` still raises instead of silently re-chunking.
    return stacked.reshape([stacked.shape[0], length])


def get_features(states, actions, feature):
    """Apply ``feature`` to every (state, action) pair, episode by episode.

    Args:
        states: Sequence of per-episode state sequences.
        actions: Sequence of per-episode action sequences, indexed in
            parallel with ``states``.
        feature: Callable invoked as ``feature(state, action)``.

    Returns:
        List of lists mirroring the structure of ``states``.
    """
    features = []
    for epi_idx, epi_states in enumerate(states):
        epi_actions = actions[epi_idx]
        # Actions are looked up by index so a too-short action list still raises.
        features.append(
            [feature(state, epi_actions[step]) for step, state in enumerate(epi_states)]
        )
    return features

