#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Aug 30 14:06:34 2024

@author: XXXX
"""

import numpy as np
import torch
import gymnasium
import pygame

from gymnasium.envs.box2d.car_racing import CarRacing

# Module-level copies of some of gymnasium's car_racing constants.
# NOTE: these assignments only create globals in THIS module; they do not
# modify the globals of gymnasium.envs.box2d.car_racing, which CarRacing's own
# methods keep reading. Outside this block, the code in this file reads only
# PLAYFIELD and TRACK_WIDTH; the others are kept for parity with gymnasium.
SCALE = 6.0  # Track scale
TRACK_RAD = 900 / SCALE  # Radius of gymnasium's morphed-circle track (not read elsewhere in this file)
PLAYFIELD = 2000 / SCALE  # Game over boundary
FPS = 50  # Frames per second
ZOOM = 2.7  # Camera zoom
TRACK_DETAIL_STEP = 21 / SCALE
TRACK_TURN_RATE = 0.31
TRACK_WIDTH = 20 / SCALE  # Half-width of the road: tile edges sit +/- TRACK_WIDTH from the centre line
BORDER = 8 / SCALE
BORDER_MIN_COUNT = 4
GRASS_DIM = PLAYFIELD / 20.0
MAX_SHAPE_DIM = (
    max(GRASS_DIM, TRACK_WIDTH, TRACK_DETAIL_STEP) * np.sqrt(2) * ZOOM * SCALE
)

# Create a new car racer class that essentially copies the gym racing environment
class TrackRacing(CarRacing):
    """Gymnasium ``CarRacing`` variant whose track is built from task sequences.

    Compared with the parent environment:

    1. The observation space is declared as a dict of track coordinates and
       car properties instead of an image. NOTE(review): ``step`` is not
       overridden in this class, so the parent's ``step`` still produces the
       observations; confirm this is intended.
    2. The track is an open chain of ``n_contexts`` segments (hard/soft left,
       straight, soft/hard right) that follows one of ``n_tasks`` fixed
       segment sequences, instead of a distorted loop. ``_create_track`` is
       replaced; ``reset`` is not overridden here.
    """

    def __init__(self, render_mode=None, verbose=False, lap_complete_percent=0.95,
                 domain_randomize=False, continuous=True,
                 task_args=None):
        """Initialise the parent environment, the task set and the observation space.

        Args:
            render_mode, verbose, lap_complete_percent, domain_randomize,
                continuous: passed positionally to ``CarRacing.__init__``.
            task_args: dict with the keys returned by ``default_parameters``;
                ``None`` selects those defaults.
        """
        super().__init__(render_mode, verbose, lap_complete_percent,
                         domain_randomize, continuous)

        # Order matters: set_operations reads task_sampling (from
        # set_parameters), and set_tasks needs both parameters and operations.
        self.set_parameters(task_args)
        self.set_operations()
        self.set_tasks()

        # Observation: relevant state variables instead of an image.
        #   track:  n_contexts x 2 track coordinates
        #   pos:    2 car coordinates
        #   vel:    2 car velocities
        #   dir:    1 car angle
        #   wheels: 4 wheel angular velocities
        #   steer:  1 wheel (steering) angle
        # NOTE(review): 'dir', 'wheels' and 'steer' are declared with shape
        # (1, 2), which does not match the counts listed above. Confirm against
        # the code that builds these observations before relying on the shapes.
        self.observation_space = gymnasium.spaces.Dict({
            'track': gymnasium.spaces.Box(-PLAYFIELD, PLAYFIELD, (self.n_contexts, 2), dtype=float),
            'pos': gymnasium.spaces.Box(-PLAYFIELD, PLAYFIELD, (1, 2), dtype=float),
            'vel': gymnasium.spaces.Box(-100, 100, (1, 2), dtype=float),
            'dir': gymnasium.spaces.Box(0, 2*np.pi, (1, 2), dtype=float),
            'wheels': gymnasium.spaces.Box(0, 200, (1, 2), dtype=float),
            'steer': gymnasium.spaces.Box(0, 2*np.pi, (1, 2), dtype=float)})

    def set_parameters(self, args=None):
        """Read task parameters from ``args`` (or the defaults) into attributes.

        Also derives the half-open task-index range ``[task_start, task_stop)``
        that ``_create_track`` samples from. Training uses the first
        ``int(task_train * n_tasks)`` tasks and testing uses the rest.
        """
        args = self.default_parameters() if args is None else args
        self.do_test = args['do_test']
        self.n_tasks = args['n_tasks']
        self.n_contexts = args['n_contexts']
        self.n_steps = args['n_steps']
        self.task_train = args['task_train']
        # Compute the split point once, so the training range [0, split) and
        # the test range [split, n_tasks) can never overlap. Flooring
        # task_train*n and (1-task_train)*n separately would put one task in
        # both sets whenever task_train*n_tasks is not an integer.
        split = int(self.task_train * self.n_tasks)
        self.task_start = split if self.do_test else 0
        self.task_stop = self.n_tasks if self.do_test else split
        self.task_sampling = args['task_sampling']

    def default_parameters(self):
        """Return the default task parameters.

        Keys: ``do_test`` (sample from held-out tasks), ``n_tasks`` (number of
        task sequences), ``n_contexts`` (segments per track), ``n_steps``
        (centre-line samples per track), ``task_train`` (fraction of tasks used
        for training), ``task_sampling`` (0 = near-uniform segment transitions,
        otherwise the structured transition matrix).
        """
        return {'do_test': False, 'n_tasks': 30, 'n_contexts':10, 'n_steps':200, 'task_train': 0.6, 'task_sampling': 0}

    def set_operations(self):
        """Define the road-segment types and the transition matrix between them."""
        # Track operations: name, min angle, max angle, min half-length, max half-length
        self.operations = [['hl', np.pi/8*5, np.pi/8*7, 0.75, 1], # hard left
                           ['sl', np.pi/8*1, np.pi/8*3, 0.75, 1], # soft left
                           ['gs', 0, 0, 0.25, 0.75], # straight
                           ['sr', -np.pi/8*3, -np.pi/8*1, 0.75, 1], # soft right
                           ['hr', -np.pi/8*7, -np.pi/8*5, 0.75, 1] # hard right
                          ]
        self.n_operations = len(self.operations)
        # Transition matrix: row = current operation, column = next operation,
        # both in the order of self.operations.
        if self.task_sampling == 0:
            # Almost uniform transitions, but avoid two hard corners in the same
            # direction, and slightly downweight doing the same thing twice
            self.transitions = [[0, 0.25, 0.25, 0.25, 0.25],
                                [0.225, 0.1, 0.225, 0.225, 0.225],
                                [0.225, 0.225, 0.1, 0.225, 0.225],
                                [0.225, 0.225, 0.225, 0.1, 0.225],
                                [0.25, 0.25, 0.25, 0.25, 0]]
        else:
            # Less uniform: a common (2/3) and a rare (1/3) transition
            self.transitions = [[0, 1/3, 0, 0, 2/3],
                                [0, 0, 0, 2/3, 1/3],
                                [2/3, 0, 1/3, 0, 0],
                                [0, 2/3, 0, 1/3, 0],
                                [1/3, 0, 2/3, 0, 0]]

    def set_tasks(self):
        """Generate the task sequences plus derived names, indices and one-hot vectors."""
        self.tasks = self.get_tasks()
        # e.g. 'hl-gs-sr-...' for each task
        self.task_names = ['-'.join([c[0] for c in t]) for t in self.tasks]
        # Operation index of every segment: shape (n_tasks, n_contexts)
        self.task_ops = np.array([[self.operations.index(c) for c in t] for t in self.tasks])
        # One-hot task vectors
        self.task_vec = torch.eye(self.n_tasks, dtype=torch.float)

    def get_tasks(self):
        """Sample ``n_tasks`` sequences of ``n_contexts`` operations.

        The first operation is uniform; each following one is drawn from the
        transition row of its predecessor. A private ``RandomState(0)`` makes
        the tasks identical on every run. It yields the same draws as the
        previous ``np.random.seed(0)`` approach but does not reseed NumPy's
        global RNG for the rest of the process.
        """
        rng = np.random.RandomState(0)
        n_ops = len(self.operations)
        tasks = []
        for _ in range(self.n_tasks):
            contexts = [rng.choice(n_ops)]
            for _ in range(self.n_contexts - 1):
                contexts.append(rng.choice(n_ops, p=self.transitions[contexts[-1]]))
            tasks.append([self.operations[c] for c in contexts])
        return tasks

    def _create_track(self, task=None):
        """Generate the road for one task. Always returns True.

        Args:
            task: index into ``self.tasks``. If None, a task is drawn uniformly
                from ``[self.task_start, self.task_stop)`` with
                ``self.np_random`` (gymnasium's per-env RNG, seeded through
                ``reset(seed=...)``).

        Side effects: sets ``self.road``, ``self.track`` (rows of
        [fraction complete, beta, x, y]), ``self.checkpoints`` and
        ``self.current_task``, and appends tile polygons to ``self.road_poly``.
        """
        if task is None:
            task = self.np_random.integers(self.task_start, self.task_stop)
        self.current_task = int(task)
        checkpoints = self._sample_checkpoints(task)
        track = self._checkpoints_to_track(checkpoints)
        self._build_road(track)
        self.track = track
        self.checkpoints = checkpoints
        # Debug plot only on request: opening a figure on every reset makes
        # open matplotlib figures accumulate.
        if self.verbose:
            self.plot_track()
        return True

    def _sample_checkpoints(self, task):
        """Sample checkpoint poses for ``task`` until a usable layout is found.

        Each operation adds two checkpoints. The first lies ``half_length`` ahead
        along the current heading, and the heading is then rotated by the
        sampled angle. The second lies a further ``half_length`` along the new
        heading. Layouts that cross themselves or leave the playfield (minus
        the road half-width) are rejected and resampled.

        Returns:
            Array of shape (2 * len(self.tasks[task]) + 1, 3) with rows [x, y, heading].
        """
        while True:
            checkpoints = [[0, 0, self.np_random.uniform(0, 2*np.pi)]]
            for _, min_angle, max_angle, min_half, max_half in self.tasks[task]:
                half_length = self.np_random.uniform(min_half, max_half) * PLAYFIELD / self.n_contexts
                angle = self.np_random.uniform(min_angle, max_angle)
                # First half: continue straight, then turn
                x, y, heading = checkpoints[-1]
                checkpoints.append([x + half_length*np.cos(heading),
                                    y + half_length*np.sin(heading),
                                    heading + angle])
                # Second half: travel along the new heading
                x, y, heading = checkpoints[-1]
                checkpoints.append([x + half_length*np.cos(heading),
                                    y + half_length*np.sin(heading),
                                    heading])
            checkpoints = np.array(checkpoints)

            if self._self_intersects(checkpoints):
                reason = 'intersect'
            elif np.any(np.abs(checkpoints[:, :2]) > PLAYFIELD - TRACK_WIDTH):
                # PLAYFIELD is the game-over boundary; such a track can't be finished
                reason = 'leaving the playfield'
            else:
                return checkpoints
            if self.verbose:
                print(f'Discarded track because of {reason}')

    @staticmethod
    def _self_intersects(checkpoints):
        """Return True if any two non-adjacent checkpoint segments cross."""
        n = len(checkpoints)
        for i in range(n - 3):
            for j in range(i + 2, n - 1):
                if intersect(checkpoints[i, :2], checkpoints[i + 1, :2],
                             checkpoints[j, :2], checkpoints[j + 1, :2]):
                    return True
        return False

    def _checkpoints_to_track(self, checkpoints):
        """Resample the checkpoint polyline into ``n_steps`` smoothed points.

        Returns:
            Array of shape (n_steps, 4) with rows [fraction complete, beta, x, y],
            where beta is the angle of the local forward difference minus 90 deg.
        """
        # Cumulative distance along the checkpoints
        seg_lengths = np.sqrt(np.sum(np.diff(checkpoints[:, :2], axis=0)**2, axis=1))
        distance = np.concatenate([np.zeros(1), np.cumsum(seg_lengths, axis=0)])
        # n_steps equally spaced samples along the path (the end point is dropped)
        samples = np.linspace(distance[0], distance[-1], self.n_steps + 1)[:-1]
        track = np.stack([np.interp(samples, distance, coord)
                          for coord in checkpoints[:, :2].transpose()])
        # Box-filter width: roughly the number of samples per checkpoint. Clamp
        # to 1, because a zero width (n_steps < len(checkpoints)) gives an empty
        # kernel, which np.convolve rejects.
        smooth_size = max(1, int(self.n_steps / len(checkpoints)))
        # Pad both ends with the end values to avoid zero-padding effects, then
        # smooth and slice the padding off again
        track = np.concatenate([np.ones((2, smooth_size)) * track[:, 0][:, None],
                                track,
                                np.ones((2, smooth_size)) * track[:, -1][:, None]], axis=1)
        kernel = np.ones(smooth_size) / smooth_size
        track = np.stack([np.convolve(coord, kernel, mode='same')[smooth_size:-smooth_size]
                          for coord in track]).transpose()
        # Angle of the difference vector to the next point (last one repeated),
        # minus 90 deg
        delta = np.diff(track, axis=0)
        delta = np.concatenate([delta, delta[-1, :][None, :]], axis=0)
        beta = np.arctan2(delta[:, 1], delta[:, 0]) - np.pi/2
        # Columns: fraction complete, beta, x, y
        return np.concatenate([np.linspace(0, 1, self.n_steps)[:, None],
                               beta[:, None], track], axis=1)

    def _build_road(self, track):
        """Create one Box2D sensor tile between each pair of consecutive track points.

        Copied from gymnasium's ``CarRacing._create_track``. Tile edges lie
        ``TRACK_WIDTH`` either side of each point along its beta direction.
        NOTE(review): this open track yields len(track) - 1 tiles. If the
        parent's completion check compares the visited-tile count with
        ``len(self.track)`` (as upstream gymnasium does), driving the whole
        track may never terminate the episode. Confirm.
        """
        self.road = []
        for i in range(len(track) - 1):
            _, beta1, x1, y1 = track[i + 1]
            _, beta2, x2, y2 = track[i]
            road1_l = (
                x1 - TRACK_WIDTH * np.cos(beta1),
                y1 - TRACK_WIDTH * np.sin(beta1),
            )
            road1_r = (
                x1 + TRACK_WIDTH * np.cos(beta1),
                y1 + TRACK_WIDTH * np.sin(beta1),
            )
            road2_l = (
                x2 - TRACK_WIDTH * np.cos(beta2),
                y2 - TRACK_WIDTH * np.sin(beta2),
            )
            road2_r = (
                x2 + TRACK_WIDTH * np.cos(beta2),
                y2 + TRACK_WIDTH * np.sin(beta2),
            )
            vertices = [road1_l, road1_r, road2_r, road2_l]
            self.fd_tile.shape.vertices = vertices
            t = self.world.CreateStaticBody(fixtures=self.fd_tile)
            t.userData = t
            c = 0.01 * (i % 3) * 255  # alternate tile shading
            t.color = self.road_color + c
            t.road_visited = False
            t.road_friction = 1.0
            t.idx = i
            t.fixtures[0].sensor = True
            self.road_poly.append(([road1_l, road1_r, road2_r, road2_l], t.color))
            self.road.append(t)

    def plot_track(self):
        """Debug helper: plot the most recently generated track with matplotlib.

        Shows the checkpoints ('x-'), the smoothed centre line ('o') and points
        one unit either side of it along beta ('k.'). Returns the new figure.
        """
        from matplotlib import pyplot as plt
        checkpoints, track = self.checkpoints, self.track
        fig = plt.figure()
        plt.plot(*checkpoints[:, :2].transpose(), 'x-')
        plt.plot(*track[:, 2:].transpose(), 'o')
        plt.plot(track[:, 2] + np.cos(track[:, 1]), track[:, 3] + np.sin(track[:, 1]), 'k.')
        plt.plot(track[:, 2] - np.cos(track[:, 1]), track[:, 3] - np.sin(track[:, 1]), 'k.')
        return fig

# Quick-and-dirty check for line intersection (https://stackoverflow.com/a/9997374)
def ccw(A, B, C):
    """Return True if points A, B, C make a strictly counter-clockwise turn.

    Compares the two cross-product terms of (B - A) x (C - A); collinear
    points give False.
    """
    lhs = (C[1] - A[1]) * (B[0] - A[0])
    rhs = (B[1] - A[1]) * (C[0] - A[0])
    return lhs > rhs

# Return true if line segments AB and CD intersect
def intersect(A, B, C, D):
    """Return True if segment AB properly crosses segment CD.

    The segments cross when C and D lie on opposite sides of AB and A and B
    lie on opposite sides of CD. Collinear or touching cases are not handled.
    """
    def _turns_left(p, q, r):
        # Same orientation test as ccw(p, q, r)
        return (r[1] - p[1]) * (q[0] - p[0]) > (q[1] - p[1]) * (r[0] - p[0])

    return (_turns_left(A, C, D) != _turns_left(B, C, D)
            and _turns_left(A, B, C) != _turns_left(A, B, D))

# Human-playable game. Copied from gym
def play():
    """Drive TrackRacing with the keyboard (adapted from gymnasium's demo).

    Arrow keys set steering / gas / brake, Enter restarts the episode, and
    Escape or closing the window quits. Progress is printed every 200 steps
    and at the end of each episode.
    """
    a = np.array([0.0, 0.0, 0.0])  # [steer, gas, brake], mutated by register_input
    # Local flags shared with register_input via nonlocal, so play() leaves no
    # module globals behind (a global named `quit` would shadow the builtin).
    quit_requested = False
    restart = False

    def register_input():
        """Apply pending pygame events to the action array and control flags."""
        nonlocal quit_requested, restart
        for event in pygame.event.get():
            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_LEFT:
                    a[0] = -1.0
                if event.key == pygame.K_RIGHT:
                    a[0] = +1.0
                if event.key == pygame.K_UP:
                    a[1] = +0.3  # gas (previously commented-out value: +1.0)
                if event.key == pygame.K_DOWN:
                    a[2] = +0.2  # brake (previously +0.8; 1.0 blocks the wheels to zero rotation)
                if event.key == pygame.K_RETURN:
                    restart = True
                if event.key == pygame.K_ESCAPE:
                    quit_requested = True

            if event.type == pygame.KEYUP:
                if event.key == pygame.K_LEFT:
                    a[0] = 0
                if event.key == pygame.K_RIGHT:
                    a[0] = 0
                if event.key == pygame.K_UP:
                    a[1] = 0
                if event.key == pygame.K_DOWN:
                    a[2] = 0

            if event.type == pygame.QUIT:
                quit_requested = True

    env = TrackRacing(render_mode="human")
    try:
        while not quit_requested:
            env.reset()
            total_reward = 0.0
            steps = 0
            restart = False
            while True:
                register_input()
                _, r, terminated, truncated, _ = env.step(a)
                total_reward += r
                if steps % 200 == 0 or terminated or truncated:
                    print("\naction " + str([f"{x:+0.2f}" for x in a]))
                    print(f"step {steps} total_reward {total_reward:+0.2f}")
                steps += 1
                if terminated or truncated or restart or quit_requested:
                    break
    finally:
        # Close the pygame window even if the loop is interrupted (e.g. Ctrl+C)
        env.close()
