게임 인공지능 프로그래밍/수업내용

[PPO모델] ML Agent 를 사용한 Flappy bird 교육

Bueong_E 2023. 3. 30. 17:20
반응형
SMALL

 

레이를 쏘아 장애물을 탐지하는 Flappy Bird

짤막 상식 : Rigidbody의 Collision Detection 설정

리지드 바디를 Countinuous 로 바꿀시 연산량은 올라가지만 빠른 속독의 물체와의 충돌도 감지할수 있다.


PPO 모델을 사용한 Flappy Bird  훈련이다.

리워드는 아래와 같은 기준으로 주어진다.

 

천장에 닿을경우 : -1

바닥에 닿을경우 : -1

파이프에 닿을경우 : -1

살아있을경우 : +0.1

파이프를 통과하였을경우 : +1

 

이렇게 보상을 준비하고 훈련시켜 보았다.

우선 보상의 조건을 적게 주어도 살아있는동안 보상을 주면 꾸준히 Mean Reward 가 증가하는 것을 볼수 있었고 이걸 좀 더 빠르게 하기 위해 파이프를 통과하였을 경우의 보상을 주었다.

파이프 사이의 공간에 Collision  을 주고 지나갈시 점수를 주었다.
열심히 훈련중인 Flappy Bird 들

코드는 아래와 같다.

 

Env 스크립트 

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using TMPro;

public class Env : MonoBehaviour
{
    [SerializeField]
    private Bird_Agent bird;
    [SerializeField]
    private TMP_Text txtReward;

    public SpriteRenderer spriteRenderer;
    public Spawner generator;

    private Color color;
    private void Start()
    {
        this.color = this.spriteRenderer.material.color;
        //백그라운드에서 재생
        Application.runInBackground = true;
        this.bird.episodeBeginAction = () =>
        {
            //버드 위치 초기화
            this.bird.transform.localPosition = Vector3.zero;
            this.generator.StartGenerate();

        };
        this.bird.episodeEndAction = () =>
        {
            this.generator.StopGenerate();
            this.generator.Clear();
            this.StartCoroutine(this.ChageColor());
        };
    }

    private IEnumerator ChageColor()
    {
        float time = 0;
        this.spriteRenderer.material.color = Color.red;
        while (true) 
        { 
            time += Time.deltaTime;
            if (time > 1f) break;
            yield return null;
        }
        this.spriteRenderer.material.color = this.color;
    }

    private void Update() 
    {
        this.txtReward.text = bird.GetCumulativeReward().ToString("0.00");
    }


}

Bird Agent 스크립트

using System;
using System.Collections;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;
using UnityEngine.Events;


public class Bird_Agent : Agent
{
    public UnityAction episodeBeginAction;
    public UnityAction episodeEndAction;
    private Rigidbody2D rb;
    
    [SerializeField]
    private float forceMultiplier;


    public override void Initialize()
    {
        //하는게 음슴 ( 그냥 두거나 안해도 될듯 )
        //base.Initialize();
        this.rb = this.GetComponent<Rigidbody2D>();   
    }

    public override void OnEpisodeBegin()
    {
        //이벤트 전송
        this.episodeBeginAction();
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        

    }

    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        var discreteActions = actionBuffers.DiscreteActions;
        var jumpAction = discreteActions[0];
        //Debug.LogFormat("<color=red>{0}</color>",jumpAction);

        if (jumpAction == 1)
        {
            var force = Vector2.up * this.forceMultiplier;
            this.rb.AddForce(force);
            this.rb.velocity = Vector2.zero;
        }

        //if (this.MaxStep > 0) AddReward(-1f / this.MaxStep); // 기본적으로 0.0002 만큼 패널티를 주겠다.
        //적절한 보상 (하늘에 떠있기만 해도 보상)
        this.AddReward(0.1f);
    }

    private bool isJump;
    private bool doingJump;

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var v = Input.GetAxisRaw("Vertical");

        this.doingJump = false;

        if (v == 1 && this.isJump == false)
        {
            this.isJump = true;
            this.doingJump = true;
        }
        if (v == 0)
        {
            this.isJump = false;
        }
        if (this.doingJump)
        {
            actionsOut.DiscreteActions.Array[0] = Convert.ToInt32(this.doingJump);
        }        
    }


    private void OnCollisionEnter2D(Collision2D collision)
    {
        if(collision.gameObject.CompareTag("Floor") || collision.gameObject.CompareTag("Ceiling") || collision.gameObject.CompareTag("Wall"))
        {
            // 패널티
            this.AddReward(-1);
            this.episodeEndAction();
            //에피소드 종료
            this.EndEpisode();           
        }
        if (collision.gameObject.CompareTag("Score"))
        {
            this.AddReward(0.5f);
        }
    }
}

Spawner(장애물 생성) 스크립트

using System.Collections;
using System.Collections.Generic;
using System.IO.Pipes;
using UnityEngine;


public class Spawner : MonoBehaviour
{

    [SerializeField] GameObject wall;
    [SerializeField] float  minY, maxY, rndY;
    [SerializeField] Transform startPos;

    [SerializeField] Env gm;

    public float time;
    bool canIncrease = true;

    Coroutine coroutine;

    // Start is called before the first frame update
    

    public void StartGenerate()
    {
        if (this.coroutine != null)
        {
            this.StopCoroutine(this.coroutine);
        }
        this.coroutine = StartCoroutine(Spawning());
    }


    public void StopGenerate()
    {
        if(this.coroutine != null)
        {
            this.StopCoroutine(this.coroutine);
        }
    }

    List<GameObject> pipList = new List<GameObject>(); 

    public void Clear()
    {
        for(int i =0; i< pipList.Count; i++)
        {
            Destroy(pipList[i]);
        }
        this.pipList.Clear();    
    }

    
    void CheckIncrease()
    {
        if (canIncrease == true)
        {
            canIncrease = false;
            IncreaseTime();
        }
        else if (canIncrease == true)
        {
            canIncrease = false;
            IncreaseTime();
        }
        else if (canIncrease == true)
        {
            canIncrease = false;
            IncreaseTime();
        }
        else if (canIncrease == true)
        {
            canIncrease = false;
            IncreaseTime();
        }
        else if (canIncrease == true)
        {
            canIncrease = false;
            IncreaseTime();
        }
    }

    void IncreaseTime()
    {
        if (time >= 1.2f)
        {
            time -= 0.2f;
            StartCoroutine(ResetIncrease());
        }
        
    }

    IEnumerator ResetIncrease()
    {
        yield return new WaitForSeconds(2f);
        canIncrease = true;
        StopCoroutine(ResetIncrease());
    }

    IEnumerator Spawning()
    {
        while (true)
        {
           
            yield return new WaitForSeconds(time);
            rndY = Random.Range(minY, maxY);

            var startPosiosion = this.startPos.transform.localPosition;
            startPosiosion.y = rndY;

            GameObject go = Instantiate(wall,this.gm.transform);
            go.transform.localPosition = startPosiosion;

            this.pipList.Add(go);   
        }
    }
}

추가적으로 Decision Requester의 period 값을 1로 설정시 FIxed Update와 동일한 주기로 호출이 가능하지만 그만큼 과부하를 내기 때문에 조심하여야 한다.

Default 값은 5 이다.

Ps. 성과가 좋지 못해 파이프나 바닥 천장에 부딪치게 되면 감점을 -10 점 정도로 쎄게 줘보니 이전보다 좀 더 똑똑해 진걸 볼수 있었다. 하지만 감점이 너무 많았는지 평균적인 Mean Reward가 굉장히 낮았고  점수의 증가폭은 오히려 떨어지는 듯 보이기도....

 

결과Gif (주인을 잘못 만나 교육을 제대로 받지 못한듯 하다 ㅎㅎ;;)

반응형
LIST