Program Listing for File connect2_env.cpp
Return to documentation for file (src/bitrl/envs/connect2/connect2_env.cpp)
#include "bitrl/envs/connect2/connect2_env.h"
#include "bitrl/envs/time_step.h"
#include <algorithm>
#include <unordered_map>
#include <stdexcept>
#include <any>
#include <memory>
namespace bitrl{
namespace envs{
namespace connect2{
// Human-readable identifier of this environment type.
const std::string Connect2::name{"Connect2"};
/// @brief Default constructor.
///
/// Creates an environment with index 0 and discount factor 1.0. The board
/// is left empty here; it is sized to 4 cells by make()/reset().
Connect2::Connect2()
    :
    EnvBase<TimeStep<std::vector<uint_t>>,
            DiscreteVectorStateDiscreteActionEnv<53, 0, 4, uint_t > >(0, "Connect2"),
    discount_(1.0),
    board_(),
    // Initialise explicitly: move() reads is_finished_ and may be reached
    // before reset() has ever been called.
    is_finished_(false)
{}
/// @brief Construct an environment with a caller-chosen index.
///
/// @param cidx Index forwarded to the EnvBase constructor.
///
/// The discount factor is fixed at 1.0. The board is left empty here; it is
/// sized to 4 cells by make()/reset().
Connect2::Connect2(uint_t cidx)
    :
    EnvBase<TimeStep<std::vector<uint_t>>,
            DiscreteVectorStateDiscreteActionEnv<53, 0, 4, uint_t > >(cidx, "Connect2"),
    discount_(1.0),
    board_(),
    // Initialise explicitly: move() reads is_finished_ and may be reached
    // before reset() has ever been called.
    is_finished_(false)
{}
/// @brief Copy constructor.
///
/// Copies the base-environment state, the discount factor, the current
/// board, and the finished flag from @p other.
Connect2::Connect2(const Connect2& other)
    :
    EnvBase<TimeStep<std::vector<uint_t>>,
            DiscreteVectorStateDiscreteActionEnv<53, 0, 4, uint_t > >(other),
    // Copy the source's discount instead of hard-coding 1.0, so a copy
    // never silently diverges from the original.
    discount_(other.discount_),
    board_(other.board_),
    is_finished_(other.is_finished_)
{}
void
Connect2::make(const std::string& /*version*/,
const std::unordered_map<std::string, std::any>& /*options*/){
board_.resize(4, 0);
this -> set_version_("v1");
this -> make_created_();
}
/// @brief Play @p action on behalf of the first player.
///
/// Thin wrapper over move() using player_id_1_ as the acting player; see
/// move() for validation, rewards, and the exceptions it throws.
Connect2::time_step_type
Connect2::step(const action_type& action){
    return this->move(player_id_1_, action);
}
/// @brief Start a new game.
///
/// Clears the finished flag, replaces the board with 4 empty cells, and
/// stores and returns a FIRST time step with reward 0.0. The seed and
/// options arguments are ignored.
Connect2::time_step_type
Connect2::reset(uint_t /*seed*/,
                const std::unordered_map<std::string, std::any>& /*options*/){

    is_finished_ = false;
    board_.assign(4, 0);

    auto& current = this -> get_current_time_step_();
    current = Connect2::time_step_type(TimeStepTp::FIRST, 0.0, board_, discount_);
    return current;
}
/// @brief Check whether @p player has won.
///
/// A player wins when they own win_val_ *adjacent* cells on the board.
/// The previous implementation counted all of the player's pieces anywhere
/// on the board, so e.g. [1, 2, 1, 0] was wrongly reported as a win for
/// player 1 even though the two pieces are not connected.
///
/// @param player Player id whose pieces are inspected.
/// @return true if a run of win_val_ consecutive cells belongs to @p player.
bool
Connect2::is_win(uint_t player)const noexcept{

    // Length of the current run of consecutive cells owned by the player;
    // unsigned to match the board's cell type.
    uint_t run = 0;

    for(const auto cell : board_){
        // Any cell not owned by the player breaks the run.
        run = (cell == player) ? run + 1 : 0;
        if(run == win_val_){
            return true;
        }
    }

    return false;
}
std::vector<uint_t>
Connect2::get_valid_moves()const{
std::vector<uint_t> val_moves_;
val_moves_.reserve(4);
for(uint_t i=0; i<board_.size(); ++i){
if(board_[i] == 0){
val_moves_.push_back(i);
}
}
return val_moves_;
}
/// @brief Whether at least one cell of the board is still empty (zero).
///
/// @return true if another move can be played, false when the board is full.
bool
Connect2::has_legal_moves()const noexcept{
    return std::any_of(board_.cbegin(), board_.cend(),
                       [](const auto cell){ return cell == 0; });
}
/// @brief Place a piece for player @p pid on cell @p action.
///
/// @param pid    Acting player; must be 1 or 2.
/// @param action Index of the board cell to occupy.
/// @return If the game had already finished, the FIRST step produced by
///         reset(). Otherwise the resulting step:
///         - LAST, reward 1.0 when @p pid wins;
///         - LAST, reward 0.0 when the board fills up with no winner (draw);
///         - MID, reward 0.0 otherwise.
///         The extra map carries "valid_moves" (std::vector<uint_t>) for the
///         resulting board.
/// @throws std::logic_error for an invalid player id, an out-of-range
///         action, or an action targeting an occupied cell.
Connect2::time_step_type
Connect2::move(const uint_t pid, const action_type& action){

    if(pid != 1 && pid != 2){
        throw std::logic_error("Invalid player id: " + std::to_string(pid));
    }

    if(action >= board_.size()){
        throw std::logic_error("Invalid action id: " + std::to_string(action));
    }

    // A finished game is restarted transparently rather than rejected.
    if(is_finished_){
        return reset();
    }

    // Guard clause: a cell can only be played once.
    if(board_[action] != 0){
        throw std::logic_error("Move: " + std::to_string(action) + " is invalid");
    }

    board_[action] = pid;

    auto step_type = TimeStepTp::MID;
    auto reward = 0.0;

    // Check for a win first: the player may have won even though empty
    // cells remain.
    if(is_win(pid)){
        step_type = TimeStepTp::LAST;
        reward = 1.0;
        is_finished_ = true;
    }
    else if(!has_legal_moves()){
        // The board is full and nobody has won. The opponent cannot have won
        // earlier, since the game would already have been finished, so this
        // is a draw rather than a loss for the mover.
        step_type = TimeStepTp::LAST;
        reward = 0.0;
        is_finished_ = true;
    }

    std::unordered_map<std::string, std::any> extra;
    extra["valid_moves"] = std::any(get_valid_moves());

    return Connect2::time_step_type(step_type, reward,
                                    board_, discount_,
                                    std::move(extra));
}
/// @brief Create a new environment with index @p cidx and the same version
///        as this one.
///
/// The new environment is made with empty options. The board and finished
/// state of *this are NOT copied.
///
/// @param cidx Index of the new environment.
/// @return A freshly made Connect2 instance.
Connect2
Connect2::make_copy(uint_t cidx)const{
    Connect2 fresh(cidx);
    const std::unordered_map<std::string, std::any> no_options;
    fresh.make(this -> version(), no_options);
    return fresh;
}
}
}
}