Program Listing for File connect2_env.h

Return to documentation for file (src/bitrl/envs/connect2/connect2_env.h)

/*
 * Connect2 environment: a small two-player board game, ported from
 * <a href="https://github.com/JoshVarty/AlphaZeroSimple">AlphaZeroSimple</a>.
 */
#ifndef CONNECT2_ENV_H
#define CONNECT2_ENV_H

#include "bitrl/bitrl_types.h"
#include "bitrl/envs/time_step.h"
#include "bitrl/envs/env_types.h"
#include "bitrl/envs/env_base.h"

#include <boost/noncopyable.hpp>

#include <any>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

namespace bitrl{
namespace envs{
namespace connect2{

/**
 * @brief Connect2 two-player board-game environment.
 *
 * The board contents live in board_, and the two players are identified by
 * the ids 1 and 2. State and action spaces are the discrete spaces given by
 * DiscreteVectorStateDiscreteActionEnv through EnvBase.
 *
 * Copy semantics: a copy constructor is user-declared, so no move operations
 * are implicitly generated. Copy assignment is implicitly deleted because of
 * the const data members below.
 */
class Connect2 final: public EnvBase<TimeStep<std::vector<uint_t>>,
                                     DiscreteVectorStateDiscreteActionEnv<53, 0, 4, uint_t > >
{

public:

    /// @brief Environment name (defined out of line).
    static const std::string name;

    /// @brief Base environment type this class specialises.
    using base_type = EnvBase<TimeStep<std::vector<uint_t> >,
                              DiscreteVectorStateDiscreteActionEnv<53, 0, 4, uint_t > >;

    using time_step_type    = base_type::time_step_type;
    using state_space_type  = base_type::state_space_type;
    using action_space_type = base_type::action_space_type;
    using action_type       = base_type::action_type;
    using state_type        = base_type::state_type;

    // Keep the base-class reset overloads visible; otherwise the override
    // declared below would hide them.
    using base_type::reset;

    /// @brief Default constructor.
    Connect2();

    /// @brief Construct with index @p cidx.
    /// NOTE(review): presumably the copy index also passed to make_copy — confirm.
    explicit Connect2(uint_t cidx);

    /// @brief Copy constructor.
    Connect2(const Connect2& other);

    /// @brief Create the environment.
    /// @param version Environment version string.
    /// @param options Creation options keyed by name.
    void make(const std::string& version,
              const std::unordered_map<std::string, std::any>& options) override;

    /// @brief Advance the environment by one step using @p action.
    /// @return The resulting time step.
    time_step_type step(const action_type& action) override;

    /// @brief Release the board and mark the environment as not created.
    void close() override;

    /// @brief Reset the environment.
    /// @return The resulting time step.
    /// NOTE(review): seed and options are unnamed here, which suggests the
    /// definition ignores them — confirm.
    time_step_type reset(uint_t /*seed*/,
                         const std::unordered_map<std::string, std::any>& /*options*/) override;

    /// @brief Return a new Connect2 built from this one.
    /// @param cidx Presumably the copy index of the new instance — confirm.
    [[nodiscard]] Connect2 make_copy(uint_t cidx) const;

    /// @brief Number of states (state_space_type::size).
    [[nodiscard]] uint_t n_states() const noexcept { return state_space_type::size; }

    /// @brief Number of actions (action_space_type::size).
    [[nodiscard]] uint_t n_actions() const noexcept { return action_space_type::size; }

    /// @brief Apply @p action on behalf of player @p pid.
    /// @return The resulting time step.
    time_step_type move(const uint_t pid, const action_type& action);

    /// @brief Whether @p player is (presumably) in a winning position on the board.
    [[nodiscard]] bool is_win(uint_t player) const noexcept;

    /// @brief Whether any legal move remains.
    [[nodiscard]] bool has_legal_moves() const noexcept;

    /// @brief The currently valid moves.
    [[nodiscard]] std::vector<uint_t> get_valid_moves() const;

private:

    // Discount factor. Initialised in-class so it can never be read
    // uninitialised; 1.0 means rewards are not discounted. A constructor's
    // member-initializer overrides this default.
    real_t discount_{1.0};

    // Id of the first player.
    const uint_t player_id_1_{1};

    // Id of the second player.
    const uint_t player_id_2_{2};

    // Value of win_val_ (2); NOTE(review): presumably the run length needed
    // to win — confirm in the definition of is_win.
    const uint_t win_val_{2};

    // Current board contents; NOTE(review): cell encoding (e.g. 0 = empty,
    // otherwise a player id) is assumed — confirm.
    std::vector<uint_t> board_;

    // Whether the current game has finished.
    bool is_finished_{false};

};

inline
void
Connect2::close(){
    board_ = std::vector<uint_t> ();
    this -> invalidate_is_created_flag_();
}

}
}
}

#endif