[PPO모델] ML Agent 를 사용한 Flappy bird 교육
짤막 상식 : Rigidbody의 Collision Detection 설정
리지드 바디를 Countinuous 로 바꿀시 연산량은 올라가지만 빠른 속독의 물체와의 충돌도 감지할수 있다.
PPO 모델을 사용한 Flappy Bird 훈련이다.
리워드는 아래와 같은 기준으로 주어진다.
천장에 닿을경우 : -1
바닥에 닿을경우 : -1
파이프에 닿을경우 : -1
살아있을경우 : +0.1
파이프를 통과하였을경우 : +1
이렇게 보상을 준비하고 훈련시켜 보았다.
우선 보상의 조건을 적게 주어도 살아있는동안 보상을 주면 꾸준히 Mean Reward 가 증가하는 것을 볼수 있었고 이걸 좀 더 빠르게 하기 위해 파이프를 통과하였을 경우의 보상을 주었다.
코드는 아래와 같다.
Env 스크립트
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using TMPro;
public class Env : MonoBehaviour
{
[SerializeField]
private Bird_Agent bird;
[SerializeField]
private TMP_Text txtReward;
public SpriteRenderer spriteRenderer;
public Spawner generator;
private Color color;
private void Start()
{
this.color = this.spriteRenderer.material.color;
//백그라운드에서 재생
Application.runInBackground = true;
this.bird.episodeBeginAction = () =>
{
//버드 위치 초기화
this.bird.transform.localPosition = Vector3.zero;
this.generator.StartGenerate();
};
this.bird.episodeEndAction = () =>
{
this.generator.StopGenerate();
this.generator.Clear();
this.StartCoroutine(this.ChageColor());
};
}
private IEnumerator ChageColor()
{
float time = 0;
this.spriteRenderer.material.color = Color.red;
while (true)
{
time += Time.deltaTime;
if (time > 1f) break;
yield return null;
}
this.spriteRenderer.material.color = this.color;
}
private void Update()
{
this.txtReward.text = bird.GetCumulativeReward().ToString("0.00");
}
}
Bird Agent 스크립트
using System;
using System.Collections;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;
using UnityEngine.Events;
public class Bird_Agent : Agent
{
public UnityAction episodeBeginAction;
public UnityAction episodeEndAction;
private Rigidbody2D rb;
[SerializeField]
private float forceMultiplier;
public override void Initialize()
{
//하는게 음슴 ( 그냥 두거나 안해도 될듯 )
//base.Initialize();
this.rb = this.GetComponent<Rigidbody2D>();
}
public override void OnEpisodeBegin()
{
//이벤트 전송
this.episodeBeginAction();
}
public override void CollectObservations(VectorSensor sensor)
{
}
public override void OnActionReceived(ActionBuffers actionBuffers)
{
var discreteActions = actionBuffers.DiscreteActions;
var jumpAction = discreteActions[0];
//Debug.LogFormat("<color=red>{0}</color>",jumpAction);
if (jumpAction == 1)
{
var force = Vector2.up * this.forceMultiplier;
this.rb.AddForce(force);
this.rb.velocity = Vector2.zero;
}
//if (this.MaxStep > 0) AddReward(-1f / this.MaxStep); // 기본적으로 0.0002 만큼 패널티를 주겠다.
//적절한 보상 (하늘에 떠있기만 해도 보상)
this.AddReward(0.1f);
}
private bool isJump;
private bool doingJump;
public override void Heuristic(in ActionBuffers actionsOut)
{
var v = Input.GetAxisRaw("Vertical");
this.doingJump = false;
if (v == 1 && this.isJump == false)
{
this.isJump = true;
this.doingJump = true;
}
if (v == 0)
{
this.isJump = false;
}
if (this.doingJump)
{
actionsOut.DiscreteActions.Array[0] = Convert.ToInt32(this.doingJump);
}
}
private void OnCollisionEnter2D(Collision2D collision)
{
if(collision.gameObject.CompareTag("Floor") || collision.gameObject.CompareTag("Ceiling") || collision.gameObject.CompareTag("Wall"))
{
// 패널티
this.AddReward(-1);
this.episodeEndAction();
//에피소드 종료
this.EndEpisode();
}
if (collision.gameObject.CompareTag("Score"))
{
this.AddReward(0.5f);
}
}
}
Spawner(장애물 생성) 스크립트
using System.Collections;
using System.Collections.Generic;
using System.IO.Pipes;
using UnityEngine;
public class Spawner : MonoBehaviour
{
[SerializeField] GameObject wall;
[SerializeField] float minY, maxY, rndY;
[SerializeField] Transform startPos;
[SerializeField] Env gm;
public float time;
bool canIncrease = true;
Coroutine coroutine;
// Start is called before the first frame update
public void StartGenerate()
{
if (this.coroutine != null)
{
this.StopCoroutine(this.coroutine);
}
this.coroutine = StartCoroutine(Spawning());
}
public void StopGenerate()
{
if(this.coroutine != null)
{
this.StopCoroutine(this.coroutine);
}
}
List<GameObject> pipList = new List<GameObject>();
public void Clear()
{
for(int i =0; i< pipList.Count; i++)
{
Destroy(pipList[i]);
}
this.pipList.Clear();
}
void CheckIncrease()
{
if (canIncrease == true)
{
canIncrease = false;
IncreaseTime();
}
else if (canIncrease == true)
{
canIncrease = false;
IncreaseTime();
}
else if (canIncrease == true)
{
canIncrease = false;
IncreaseTime();
}
else if (canIncrease == true)
{
canIncrease = false;
IncreaseTime();
}
else if (canIncrease == true)
{
canIncrease = false;
IncreaseTime();
}
}
void IncreaseTime()
{
if (time >= 1.2f)
{
time -= 0.2f;
StartCoroutine(ResetIncrease());
}
}
IEnumerator ResetIncrease()
{
yield return new WaitForSeconds(2f);
canIncrease = true;
StopCoroutine(ResetIncrease());
}
IEnumerator Spawning()
{
while (true)
{
yield return new WaitForSeconds(time);
rndY = Random.Range(minY, maxY);
var startPosiosion = this.startPos.transform.localPosition;
startPosiosion.y = rndY;
GameObject go = Instantiate(wall,this.gm.transform);
go.transform.localPosition = startPosiosion;
this.pipList.Add(go);
}
}
}
추가적으로 Decision Requester의 period 값을 1로 설정시 FIxed Update와 동일한 주기로 호출이 가능하지만 그만큼 과부하를 내기 때문에 조심하여야 한다.
Ps. 성과가 좋지 못해 파이프나 바닥 천장에 부딪치게 되면 감점을 -10 점 정도로 쎄게 줘보니 이전보다 좀 더 똑똑해 진걸 볼수 있었다. 하지만 감점이 너무 많았는지 평균적인 Mean Reward가 굉장히 낮았고 점수의 증가폭은 오히려 떨어지는 듯 보이기도....
결과Gif (주인을 잘못 만나 교육을 제대로 받지 못한듯 하다 ㅎㅎ;;)