Program Listing for File gym_walk.h

Return to documentation for file (src/bitrl/envs/gdrl/gym_walk.h)

/*
 * GymWalk environment from
 * <a href="https://github.com/mimoralea/gym-walk">gym_walk</a>
 *
 *
 *
 */

#ifndef GYM_WALK_H
#define GYM_WALK_H

#include "bitrl/bitrl_config.h"
#include "bitrl/bitrl_types.h"
#include "bitrl/envs/time_step.h"
#include "bitrl/envs/env_types.h"
#include "bitrl/envs/env_base.h"
#include "bitrl/extern/nlohmann/json/json.hpp"


#include <vector>
#include <tuple>
#include <string>
#include <any>
#include <unordered_map>
#include <memory>

#ifdef BITRL_DEBUG
#include <cassert>
#endif

namespace bitrl{
namespace envs::gdrl
{


    /**
     * @brief Client-side handle for the GymWalk environment.
     *
     * The class does not simulate anything itself: every operation
     * (make, step, reset, close, is_alive) is forwarded to the
     * RESTApiServerWrapper instance held by the object.
     *
     * @tparam state_size Forwarded as the first argument of
     * ScalarDiscreteEnv; n_states() reports state_space_type::size.
     */
    template<uint_t state_size>
    class GymWalk final: public EnvBase<TimeStep<uint_t>,
                                        ScalarDiscreteEnv<state_size, 2, 0, 0>
        >
    {
    public:

        /// Name used when registering the environment with the API server.
        static const std::string name;

        /// Endpoint registered together with \c name on the API server.
        static const std::string URI;

        /// The base class this environment derives from.
        using base_type = EnvBase<TimeStep<uint_t>,
                                  ScalarDiscreteEnv<state_size, 2, 0, 0>>;

        /// Type returned by step() and reset().
        using time_step_type = typename base_type::time_step_type;

        /// Type describing the state space.
        using state_space_type = typename base_type::state_space_type;

        /// Type describing the action space.
        using action_space_type = typename base_type::action_space_type;

        /// Type of a single action.
        using action_type = typename base_type::action_type;

        /// Type of a single state.
        using state_type = typename base_type::state_type;

        /// Transition list; presumably (probability, next state, reward, done)
        /// tuples -- TODO confirm against the server payload.
        using dynamics_t = std::vector<std::tuple<real_t, uint_t, real_t, bool>>;

        /// Build an environment that talks to \p api_server.
        GymWalk(const RESTApiServerWrapper& api_server);

        /// Build an environment that talks to \p api_server, passing
        /// \p cidx to the base class.
        GymWalk(const RESTApiServerWrapper& api_server,
                const uint_t cidx);

        /// Copy constructor.
        GymWalk(const GymWalk& other);

        /// Ask the server to create the environment with the given version.
        void make(const std::string& version,
                  const std::unordered_map<std::string, std::any>& options) override final;

        /// Query the server on whether this environment is alive.
        bool is_alive()const override final;

        /// Ask the server to close this environment.
        void close() override final;

        /// Apply \p action and return the resulting time step.
        time_step_type step(const action_type& action) override final;

        /// Reset the environment using the given \p seed.
        time_step_type reset(uint_t seed,
                             const std::unordered_map<std::string, std::any>& options)override final;

        /// Create a new environment on the same API server, using \p cidx.
        std::unique_ptr<base_type> make_copy(uint_t cidx)const override final;

        /// Number of states, as reported by the state space type.
        uint_t n_states()const noexcept{ return state_space_type::size; }

        /// Number of actions, as reported by the action space type.
        uint_t n_actions()const noexcept{ return action_space_type::size; }


    private:

        /// Extract the "dynamics" entry of a server response.
        dynamics_t build_dynamics_from_response_(const nlohmann::json& response)const;

        /// Build a time step from the "time_step" entry of a server response.
        time_step_type create_time_step_from_response_(const nlohmann::json& response) const;

        /// Wrapper through which all requests are sent.
        RESTApiServerWrapper api_server_;
    };

    // Definitions of the static members. Each definition of a static data
    // member of a class template needs its own template header.
    template<uint_t state_size>
    const std::string GymWalk<state_size>::name = "GymWalk";

    template<uint_t state_size>
    const std::string GymWalk<state_size>::URI =  "/gdrl/gym-walk-env";


    /**
     * @brief Construct the environment with copy index 0.
     *
     * Stores a copy of \p api_server and registers the environment
     * name/URI pair on it via register_if_not.
     *
     * @param api_server Wrapper used to reach the environment server.
     */
    template<uint_t state_size>
    GymWalk<state_size>::GymWalk(const RESTApiServerWrapper& api_server)
        :
        EnvBase<TimeStep<uint_t>,
                ScalarDiscreteEnv<state_size, 2, 0, 0>
        >(0, GymWalk<state_size>::name),
        // Initialise from the constructor argument; the member must not be
        // initialised from itself.
        api_server_(api_server)
    {
        api_server_.register_if_not(GymWalk<state_size>::name, GymWalk<state_size>::URI);
    }

    /**
     * @brief Construct the environment with an explicit copy index.
     *
     * @param api_server Wrapper used to reach the environment server;
     * a copy is stored in the object.
     * @param cidx Index forwarded to the base class.
     */
    template<uint_t state_size>
    GymWalk<state_size>::GymWalk(const RESTApiServerWrapper& api_server,
                                 const uint_t cidx)
        : base_type(cidx, "GymWalk"),
          api_server_(api_server)
    {
        // Make sure the server wrapper knows the endpoint of this environment.
        api_server_.register_if_not(name, URI);
    }


    /**
     * @brief Copy constructor.
     *
     * Copies the base-class part and the API server wrapper of \p other.
     * Unlike the other constructors it does not call register_if_not.
     */
    template<uint_t state_size>
    GymWalk<state_size>::GymWalk(const GymWalk<state_size>& other)
        : base_type(other),
          api_server_(other.api_server_)
    {}

    /**
     * @brief Extract the transition dynamics from a server response.
     *
     * @param response JSON document holding a "dynamics" entry.
     * @return The "dynamics" entry converted to dynamics_t.
     */
    template<uint_t state_size>
    typename GymWalk<state_size>::dynamics_t
    GymWalk<state_size>::build_dynamics_from_response_(const nlohmann::json& response)const{
        // Same conversion the implicit JSON-to-value operator performs,
        // spelled out explicitly.
        return response["dynamics"].get<dynamics_t>();
    }

    /**
     * @brief Build a time step from a server response.
     *
     * Reads "step_type", "reward", "discount" and "observation" from the
     * "time_step" entry of \p response. The time step is created with an
     * empty info map; the "info" entry of the response is not forwarded.
     *
     * @param response JSON document holding a "time_step" entry.
     * @return The time step described by the response.
     */
    template<uint_t state_size>
    typename GymWalk<state_size>::time_step_type
    GymWalk<state_size>::create_time_step_from_response_(const nlohmann::json& response)const{

        // Look the payload up once instead of once per field.
        const auto& payload = response["time_step"];

        const auto& step_type   = payload["step_type"];
        const auto& reward      = payload["reward"];
        const auto& discount    = payload["discount"];
        const auto& observation = payload["observation"];

        return time_step_type(time_step_type_from_int(step_type),
                              reward, observation, discount,
                              std::unordered_map<std::string, std::any>());
    }

    /**
     * @brief Create the environment on the server.
     *
     * Does nothing if the environment has already been created. On
     * success the version is recorded and the environment is flagged
     * as created.
     *
     * @param version Version string sent to the server.
     * @param options Currently not forwarded to the server; a
     * default-constructed JSON value is sent instead, mirroring reset().
     */
    template<uint_t state_size>
    void
    GymWalk<state_size>::make(const std::string& version,
                              const std::unordered_map<std::string, std::any>& /*options*/){

        if(this->is_created()){
            return;
        }

        // Options payload expected by the server wrapper. std::any values
        // cannot be serialised generically, so nothing is forwarded.
        nlohmann::json ops;
        auto response = api_server_.make(this->env_name(),
                                         this->cidx(),
                                         version,
                                         ops);

        this->set_version_(version);
        this->make_created_();
    }

    /**
     * @brief Apply an action to the environment.
     *
     * If the current time step is the last of an episode, the
     * environment is reset (with the fixed seed 42) and the time step
     * returned by reset() is returned instead of stepping.
     *
     * @param action Action sent to the server.
     * @return The time step built from the server response; it is also
     * stored as the current time step.
     */
    template<uint_t state_size>
    typename GymWalk<state_size>::time_step_type
    GymWalk<state_size>::step(const action_type&  action){

        // Use the same macro that guards the <cassert> include at the top of
        // this file so that the assert is both declared and active.
#ifdef BITRL_DEBUG
        assert(this->is_created() && "Environment has not been created");
#endif

        if(this->get_current_time_step_().last()){
            return this->reset(42, std::unordered_map<std::string, std::any>());
        }

        const auto response  = api_server_.step(this -> env_name(),
                                                this -> cidx(),
                                                action);

        this->get_current_time_step_() = this->create_time_step_from_response_(response);
        return this->get_current_time_step_();
    }

    /**
     * @brief Ask the server whether this environment is alive.
     *
     * @return The "result" entry of the server reply.
     */
    template<uint_t state_size>
    bool
    GymWalk<state_size>::is_alive()const{
        auto reply = api_server_.is_alive(this->env_name(), this->cidx());
        return reply["result"];
    }

    /**
     * @brief Close the environment on the server.
     *
     * Does nothing if the environment has not been created. Otherwise
     * the close request is sent and the created flag is invalidated.
     */
    template<uint_t state_size>
    void
    GymWalk<state_size>::close(){

        if(!this->is_created()){
            return;
        }

        auto response = this -> api_server_.close(this->env_name(), this -> cidx());
        this -> invalidate_is_created_flag_();
    }

    /**
     * @brief Reset the environment.
     *
     * If the environment has not been created, a default-constructed
     * time step is returned (debug builds assert instead).
     *
     * @param seed Seed sent to the server.
     * @return The time step built from the server response; it is also
     * stored as the current time step.
     *
     * The options argument is not used; a default-constructed JSON value
     * is sent to the server.
     */
    template<uint_t state_size>
    typename GymWalk<state_size>::time_step_type
    GymWalk<state_size>::reset(uint_t seed,
                               const std::unordered_map<std::string, std::any>& /*options*/){

        if(!this->is_created()){
            // Same macro that guards the <cassert> include at the top of
            // this file.
#ifdef BITRL_DEBUG
            assert(this->is_created() && "Environment has not been created");
#endif
            return time_step_type();
        }

        auto response = this -> api_server_.reset(this->env_name(),
                                                  this -> cidx(), seed,
                                                  nlohmann::json());

        // Store the fresh time step; otherwise the stale one from before the
        // reset would be returned.
        this->get_current_time_step_() = this->create_time_step_from_response_(response);
        return this -> get_current_time_step_();
    }

    template<uint_t state_size>
    GymWalk<state_size>
    GymWalk<state_size>::make_copy(uint_t cidx)const{

        GymWalk<state_size> copy(api_server_ ,cidx);
        std::unordered_map<std::string, std::any> ops;
        auto version = this -> version();
        copy.make(version, ops);
        return copy;
    }

}
}

#endif // GYM_WALK_H