using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.UI;//
using UnityEngine.Tilemaps;//
using UnityEngine.SceneManagement;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;

/// <summary>
/// ML-Agents platformer agent.
/// Each decision picks one discrete action on branch 0: 0 = idle, 1 = move left,
/// 2 = move right, 3 = jump. The keyboard heuristic writes the same encoding.
/// Rewards:
/// - a small per-step penalty;
/// - a bonus each time the agent gets closer to the goal than ever before in the episode;
/// - -10 on death;
/// - +100 for reaching the goal.
/// Attempt, time and key-press telemetry are sent through DataHandler.
/// </summary>
public class RLAgent : Agent
{
    public int genType = 0; // 0 = MLLayoutGen, 1 = PERMLayoutGen; any other value keeps the inspector-assigned mLGen
    public LevelGenerator mLGen;
    public float speed; // horizontal speed of the character
    public float jumpForce; // vertical velocity applied when jumping
    public int direction; // -1 = left, 0 = idle, 1 = right (written by MoveAgent)

    private Rigidbody2D rb;
    private bool facingRight = true;
    private Vector2 spawnPoint; // captured from the rigidbody position in Start

    private bool isGrounded;
    public Transform groundCheck;
    public float checkRadius;
    public LayerMask whatIsGround;
    private TilemapCollider2D tilemapCollider; // assigned in Start; not read anywhere else in this class
    public Tilemap tilemap;
    public Animator anim;

    private List<ContactPoint2D> contactPoints; // allocated in Start; not read anywhere else in this class
    private SpriteRenderer spriteRenderer;
    private int extraJumps;
    public int extraJumpsValue;
    public float deadFlashingTimer;

    public int attempts;
    public Text attemptTxt;

    public string levelToLoad;

    public int currLevel;

    float distanceToEnd;
    float bestDistance;
    public StatsRecorder statsRecorder;

    #region UI and level progression
    [SerializeField] GameObject attemptLimitUI, StartTextUI;
    LevelLayoyutBehaviour LLB;
    [SerializeField] int gameLevel; // 0 = first hard level, 1 = training stage, 2 = last hard level
    [HideInInspector] public int levelCounter = 0;
    [SerializeField] Text cue;

    // Movement values restored at the end of the "GET READY" countdown.
    // The defaults equal the values StopMovement previously hard-coded.
    [SerializeField] float resumeSpeed = 3.5f;
    [SerializeField] float resumeJumpForce = 11f;
    #endregion

    #region input tracker
    DataHandler dH;
    int lineCounter = 0; // key-input lines buffered in dH.keyInputs since the last upload
    int index = 0; // sequence number prefixed to each uploaded batch
    Dictionary<KeyCode, KeyData> keys;

    // Keys whose presses are logged, with the text prefixed to each log line.
    // Releases seen in the same step are logged in this order.
    static readonly (KeyCode key, string label)[] TrackedKeys =
    {
        (KeyCode.A, "A key press for "),
        (KeyCode.LeftArrow, "LeftArrow key press for "),
        (KeyCode.D, "D key press for "),
        (KeyCode.RightArrow, "RightArrow key press for "),
        (KeyCode.Space, "Space press for "),
    };
    #endregion
    [SerializeField] int attemptLimits;

    [SerializeField] string currentLevel;

    Coroutine countdownRoutine; // the StopMovement countdown currently running, if any

    void Start()
    {
        RefreshLayoutGenerator();

        LLB = FindObjectOfType<LevelLayoyutBehaviour>();
        dH = FindObjectOfType<DataHandler>();
        deadFlashingTimer = 0f;
        extraJumps = extraJumpsValue;
        tilemapCollider = tilemap.GetComponent<TilemapCollider2D>();
        rb = GetComponent<Rigidbody2D>();
        spawnPoint = rb.position;
        contactPoints = new List<ContactPoint2D>();
        spriteRenderer = GetComponent<SpriteRenderer>();
        attempts = 0;
        attemptTxt.text = attempts.ToString();
        StartCoroutine(FlashStart());
        statsRecorder = Academy.Instance.StatsRecorder;
        mLGen.GenerateLevel();
        bestDistance = Mathf.Infinity;

        // One KeyData per tracked key; DetectKeyPress indexes this dictionary directly.
        keys = new Dictionary<KeyCode, KeyData>();
        foreach (var (key, _) in TrackedKeys)
        {
            keys.Add(key, new KeyData());
        }
        // Every 5 s, upload buffered key input if at least 100 lines have accumulated.
        InvokeRepeating(nameof(SendKeyInputs), 0f, 5f);
    }

    /// <summary>
    /// Re-resolves mLGen from the scene according to genType.
    /// </summary>
    /// <returns>
    /// True when genType is 0 or 1 and mLGen was re-assigned.
    /// False for any other genType, in which case mLGen is left untouched.
    /// </returns>
    bool RefreshLayoutGenerator()
    {
        switch (genType)
        {
            case 0:
                mLGen = FindObjectOfType<MLLayoutGen>();
                return true;
            case 1:
                mLGen = FindObjectOfType<PERMLayoutGen>();
                return true;
            default:
                return false;
        }
    }

    /// <summary>
    /// Start of an episode: move the agent back to its spawn point and reset the
    /// best-distance tracker used for the progress reward.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        rb.position = spawnPoint;
        bestDistance = Mathf.Infinity;
        attemptTxt.text = attempts.ToString();
    }

    /// <summary>
    /// Adds a single observation: the straight-line distance from the agent to the goal.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        distanceToEnd = Vector2.Distance(transform.position, mLGen.goalPosition);
        sensor.AddObservation(distanceToEnd);
    }

    /// <summary>
    /// Applies one step of the chosen action, then updates the step and progress rewards.
    /// Discrete branch 0 encodes 0 = idle, 1 = left, 2 = right, 3 = jump.
    /// Flags recorded by the keyboard heuristic (OnState) are OR-ed with the decoded action,
    /// so simultaneous key presses keep working.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actionBuffer)
    {
        isGrounded = Physics2D.OverlapCircle(groundCheck.position, checkRadius, whatIsGround);
        if (deadFlashingTimer > 0f)
        {
            spriteRenderer.enabled = (Mathf.RoundToInt(deadFlashingTimer * 1000f) % 100) > 50; // blink 10 times per second
            deadFlashingTimer -= Time.deltaTime;
            if (deadFlashingTimer <= 0f)
            {
                spriteRenderer.enabled = true;
            }
        }

        // Decode the policy's action. Without this, only heuristic (keyboard) input
        // would ever move the agent.
        actions |= DecodeDiscreteAction(actionBuffer.DiscreteActions[0]);

        bool goLeft = actions.HasFlag(Actioninput.Left);
        bool goRight = actions.HasFlag(Actioninput.Right);

        if (goLeft)
        {
            MoveAgent(1);
        }
        if (goRight)
        {
            MoveAgent(2);
        }
        if (actions.HasFlag(Actioninput.Jump))
        {
            JumpAgent();
        }
        if (!goLeft && !goRight)
        {
            MoveAgent(0);
        }

        AddReward(-0.01f); // small per-step penalty
        distanceToEnd = Vector2.Distance(transform.position, mLGen.goalPosition);
        if (distanceToEnd < bestDistance)
        {
            AddReward(1.0f); // reward only new closest approaches to the goal
            bestDistance = distanceToEnd;
        }
        DetectKeyPress();

        ResetActions();
    }

    /// <summary>
    /// Maps the discrete action on branch 0 to an Actioninput flag.
    /// Unknown values (including 0) map to Nothing.
    /// </summary>
    static Actioninput DecodeDiscreteAction(int action)
    {
        switch (action)
        {
            case 1: return Actioninput.Left;
            case 2: return Actioninput.Right;
            case 3: return Actioninput.Jump;
            default: return Actioninput.Nothing;
        }
    }

    /// <summary>
    /// Sets horizontal velocity from a move code (1 = left, 2 = right, anything else = stop).
    /// Also updates the animator and flips the sprite to face the movement direction.
    /// </summary>
    void MoveAgent(int moveInput)
    {
        if (moveInput == 1)
        {
            direction = -1;
        }
        else if (moveInput == 2)
        {
            direction = 1;
        }
        else
        {
            direction = 0;
        }

        rb.velocity = new Vector2(direction * speed, rb.velocity.y);
        anim.SetFloat("Speed", Mathf.Abs(direction));

        bool facingWrongWay = (direction > 0 && !facingRight) || (direction < 0 && facingRight);
        if (facingWrongWay)
        {
            Flip();
        }

        if (direction != 0)
        {
            anim.Play("Run");
        }
    }

    /// <summary>
    /// Jumps if grounded or if extra (air) jumps remain.
    /// Being grounded refills extraJumps to extraJumpsValue.
    /// </summary>
    void JumpAgent()
    {
        if (isGrounded == true)
        {
            extraJumps = extraJumpsValue;
        }

        if (extraJumps > 0)
        {
            rb.velocity = Vector2.up * jumpForce;
            anim.Play("Jump");
            extraJumps--;
        }
        else if (extraJumps == 0 && isGrounded == true)
        {
            // Reached when extraJumpsValue is 0: a plain grounded jump.
            rb.velocity = Vector2.up * jumpForce;
            anim.Play("Jump");
        }
    }

    /// <summary>Mirrors the sprite horizontally and toggles facingRight.</summary>
    void Flip()
    {
        facingRight = !facingRight;
        Vector3 Scaler = transform.localScale;
        Scaler.x *= -1;
        transform.localScale = Scaler;
    }

    /// <summary>
    /// Handles a death. In order, it:
    /// - applies the death penalty and restarts the movement countdown;
    /// - respawns the agent and starts the blink timer;
    /// - updates the attempt counters;
    /// - regenerates the level when the attempt limit is reached;
    /// - records stats and ends the episode.
    /// Called on spike contact and by RespawnWithNewLevel.
    /// </summary>
    public void Die()
    {
        AddReward(-10.0f);
        RestartCountdown();
        rb.position = spawnPoint;
        deadFlashingTimer = 0.5f;
        attempts++;
        attemptTxt.text = attempts.ToString();
        // Count one more attempt overall. `attempts` is the running per-level count,
        // so adding it here would over-count (1 + 2 + 3 + ...).
        DataHandler.s_TotalAttempt += 1;
        if (genType == 1)
        {
            mLGen.recordResponse(GetCumulativeReward());
        }
        if (attempts >= attemptLimits) // reset the level once the attempt limit is hit
        {
            AttemptLimit();
        }
        statsRecorder.Add("CompletedEpisode", 0.0f);
        statsRecorder.Add("BestDistance", bestDistance);
        EndEpisode();
    }

    /// <summary>
    /// Attempt limit reached: show the attempt-limit UI and generate a new level.
    /// Also returns the agent to spawn and clears the per-level attempt count.
    /// </summary>
    void AttemptLimit()
    {
        StartCoroutine(FlashPause(true));
        if (RefreshLayoutGenerator())
        {
            mLGen.GenerateLevel();
        }
        rb.position = spawnPoint;
        attempts = 0;
    }

    /// <summary>
    /// Kills the scene's agent, replays the start banner and reports a level reset.
    /// Logs a warning and does nothing if no RLAgent exists in the scene.
    /// </summary>
    public static void RespawnWithNewLevel()
    {
        RLAgent agent = FindObjectOfType<RLAgent>();
        if (agent == null)
        {
            Debug.LogWarning("RespawnWithNewLevel: no RLAgent found in the scene.");
            return;
        }
        agent.Die();
        agent.CallFlashStart();
        agent.PlayerResetsLevel();
    }

    public void CallFlashStart()
    {
        StartCoroutine(FlashStart());
    }

    /// <summary>
    /// Reacts to trigger colliders: "Goal" completes the level, "Spike" kills the agent.
    /// </summary>
    /// <param name="collision">The trigger collider that was entered.</param>
    private void OnTriggerEnter2D(Collider2D collision)
    {
        Debug.Log(collision);
        if (collision.tag == "Goal")
        {
            OnGoalReached();
        }
        if (collision.tag == "Spike")
        {
            Die();
        }
    }

    /// <summary>Handles reaching the goal according to gameLevel.</summary>
    void OnGoalReached()
    {
        switch (gameLevel)
        {
            case 0: // first hard level: log time and go to the next scene
                GetTimeData();
                SceneManager.LoadScene(SceneManager.GetActiveScene().buildIndex + 1);
                break;
            case 1: // training stage: fresh level, big reward, end the episode
                levelCounter++;
                StartCoroutine(FlashStart());
                if (RefreshLayoutGenerator())
                {
                    mLGen.RemoveAllTiles();
                    mLGen.GenerateLevel();
                }
                rb.position = spawnPoint;
                GetTimeData();
                LevelEndSend();
                attempts = 0;
                if (levelCounter >= 10) // the player has completed all 10 levels
                {
                    SceneManager.LoadScene(SceneManager.GetActiveScene().buildIndex + 1);
                }
                AddReward(100f);
                statsRecorder.Add("CompletedEpisode", 1.0f);
                statsRecorder.Add("BestDistance", 0.0f);
                if (genType == 1)
                {
                    mLGen.recordResponse(GetCumulativeReward());
                }
                EndEpisode();
                break;
            case 2: // last hard level
                GetTimeData();
                SceneManager.LoadScene("End");
                break;
        }
    }

    public void PlayerResetsLevel()
    {
        StartCoroutine(DataHandler.Post(DataHandler.s_PlayerID + " Resets level")); // Disable for now as i am shifting the URl
    }

    /// <summary>
    /// Posts the attempt and time summary for the current level, then clears the per-level timer.
    /// </summary>
    void GetTimeData()
    {
        StartCoroutine(DataHandler.Post(DataHandler.s_PlayerID + " Total Attempt: "
            + DataHandler.s_TotalAttempt + " Current Level " + currentLevel + gameLevel + "\n" + "Time taken: " + TimeTracker.localTime
            + "\n" + "Total Time: " + TimeTracker.totalTime));
        TimeTracker.localTime = 0;
    }

    /// <summary>
    /// Keyboard control: D = right (2), A = left (1), Space = jump (3).
    /// Every held key is recorded as a flag. The discrete action keeps the last one written,
    /// so Space wins over A, and A wins over D.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActionsOut = actionsOut.DiscreteActions;
        if (Input.GetKey(KeyCode.D))
        {
            OnState(Actioninput.Right);
            discreteActionsOut[0] = 2;
        }
        if (Input.GetKey(KeyCode.A))
        {
            OnState(Actioninput.Left);
            discreteActionsOut[0] = 1;
        }
        if (Input.GetKey(KeyCode.Space))
        {
            OnState(Actioninput.Jump);
            discreteActionsOut[0] = 3;
        }
    }

    [System.Flags]
    public enum Actioninput
    {
        Nothing = 0,
        Left = 1 << 0,
        Right = 1 << 1,
        Jump = 1 << 2
    }

    Actioninput actions;

    void OnState(Actioninput input)
    {
        actions |= input;
    }

    void ResetActions()
    {
        actions = Actioninput.Nothing;
    }

    /// <summary>
    /// Starts the "GET READY" countdown and cancels any countdown already running.
    /// Overlapping countdowns could otherwise restore movement early and fight over the cue text.
    /// </summary>
    void RestartCountdown()
    {
        if (countdownRoutine != null)
        {
            StopCoroutine(countdownRoutine);
        }
        countdownRoutine = StartCoroutine(StopMovement());
    }

    /// <summary>
    /// Freezes movement for a 3-second on-screen countdown.
    /// Then restores resumeSpeed / resumeJumpForce and clears the cue 2 seconds later.
    /// </summary>
    IEnumerator StopMovement()
    {
        cue.text = "GET READY 3 !";
        speed = 0;
        jumpForce = 0;
        yield return new WaitForSeconds(1);
        cue.text = "GET READY 2 !";
        yield return new WaitForSeconds(1);
        cue.text = "GET READY 1 !";
        yield return new WaitForSeconds(1);
        cue.text = "GO GO GO !!!";
        speed = resumeSpeed;
        jumpForce = resumeJumpForce;
        yield return new WaitForSeconds(2);
        cue.text = " ";
        countdownRoutine = null;
    }

    /// <summary>
    /// Shows the attempt-limit UI for 5 seconds.
    /// Afterwards it either continues with the generated level (restarting the countdown)
    /// or loads the next scene.
    /// </summary>
    /// <param name="levelGen">False always advances to the next scene.</param>
    IEnumerator FlashPause(bool levelGen)
    {
        // Decided before the pause.
        // NOTE(review): the goal path loads the next scene once levelCounter >= 10,
        // while this lets levelCounter reach 11. Confirm which cap is intended.
        bool keepPlaying = levelGen && levelCounter <= 10;

        attemptLimitUI.SetActive(true);
        yield return new WaitForSeconds(5);
        attemptLimitUI.SetActive(false);

        if (keepPlaying)
        {
            RestartCountdown();
            levelCounter++;
        }
        else
        {
            SceneManager.LoadScene(SceneManager.GetActiveScene().buildIndex + 1);
        }
    }

    /// <summary>Shows the start banner for 5 seconds, then runs the movement countdown.</summary>
    IEnumerator FlashStart()
    {
        StartTextUI.SetActive(true);
        yield return new WaitForSeconds(5);
        StartTextUI.SetActive(false);
        RestartCountdown();
    }

    /// <summary>
    /// Records key-press telemetry for the tracked keys.
    /// The press start is taken from the first step a key is seen held. On release:
    /// - the hold time is appended to dH.keyInputs;
    /// - the key's press count is incremented;
    /// - the hold time is added to the key's total duration.
    /// NOTE(review): this runs from OnActionReceived (physics-step cadence), while
    /// Input.GetKeyUp is per-frame. A release can therefore be missed or seen in several
    /// steps; consider sampling input in Update.
    /// </summary>
    void DetectKeyPress()
    {
        float now = Time.realtimeSinceStartup;
        foreach (var (key, label) in TrackedKeys)
        {
            KeyData data = keys[key];

            // Record only the start of a hold. Refreshing it every step would make the
            // logged duration just the time since the previous step.
            if (Input.GetKey(key) && data.tempTime <= 0f)
            {
                data.tempTime = now;
            }

            if (Input.GetKeyUp(key))
            {
                // If the press start was never seen, log 0 rather than time-since-startup.
                var held = data.tempTime > 0f ? now - data.tempTime : 0f;
                ++data.times;
                data.duration += held;
                dH.keyInputs += label + held + "\n";
                lineCounter++;
                data.tempTime = 0;
            }
        }
    }

    /// <summary>Uploads buffered key input once at least 100 lines have accumulated.</summary>
    void SendKeyInputs()
    {
        if (lineCounter >= 100)
        {
            UploadKeyInputs("");
        }
    }

    /// <summary>Uploads all buffered key input at the end of a level.</summary>
    void LevelEndSend()
    {
        UploadKeyInputs(" end of level.");
    }

    /// <summary>
    /// Posts the buffered key-input text (with a trailing suffix) as the next numbered
    /// batch, then clears the buffer.
    /// </summary>
    void UploadKeyInputs(string suffix)
    {
        StartCoroutine(DataHandler.PostInputs(index + "ML Agent Data set for " + DataHandler.s_PlayerID + " " + dH.keyInputs + suffix));
        index++;
        lineCounter = 0;
        dH.keyInputs = "";
    }
}
