Program Listing for File time_step.h

Return to documentation for file (src/rlenvs/envs/time_step.h)

#ifndef TIME_STEP_H
#define TIME_STEP_H

#include "rlenvs/rlenvs_types_v2.h"
#include "rlenvs/envs/time_step_type.h"
#include "rlenvs/utils/io/io_utils.h"

#include <any>
#include <ostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace rlenvscpp {


/// @brief Bundles the data describing one step: the step type, a reward,
/// an observation, a discount factor and a map of named extra values.
///
/// @tparam StateTp Type used to store the observation.
template<typename StateTp>
class TimeStep
{
public:

    /// Type of the stored observation.
    using state_type = StateTp;

    /// @brief Default constructor. Sets the type to TimeStepTp::INVALID_TYPE,
    /// the reward to 0, the discount to 1, value-initialises the observation
    /// and leaves the extra map empty.
    TimeStep();

    /// @brief Builds a step from a type, a reward and an observation.
    /// The discount factor is set to 1 and the extra map is left empty.
    TimeStep(TimeStepTp type, real_t reward, const state_type&  obs);

    /// @brief Builds a step from a type, a reward, an observation and a
    /// discount factor. The extra map is left empty.
    TimeStep(TimeStepTp type, real_t reward, const state_type&  obs, real_t discount_factor);

    /// @brief Builds a step from a type, a reward, an observation, a
    /// discount factor and a map of named extra values.
    /// @param extra Map used to initialise the extra values; taken by
    /// rvalue reference.
    TimeStep(TimeStepTp type, real_t reward, const state_type& obs,
             real_t discount_factor, std::unordered_map<std::string, std::any>&& extra);

    /// @brief Copy constructor.
    TimeStep(const TimeStep& other);

    /// @brief Copy assignment.
    TimeStep& operator=(const TimeStep& other);

    /// @brief Move constructor. Calls clear() on @p other afterwards.
    TimeStep(TimeStep&& other)noexcept;

    /// @brief Move assignment. Calls clear() on @p other afterwards.
    TimeStep& operator=(TimeStep&& other)noexcept;

    /// @return true if the step type is TimeStepTp::FIRST.
    bool first()const noexcept{ return type_ == TimeStepTp::FIRST; }

    /// @return true if the step type is TimeStepTp::MID.
    bool mid()const noexcept{ return type_ == TimeStepTp::MID; }

    /// @return true if the step type is TimeStepTp::LAST.
    bool last()const noexcept{ return type_ == TimeStepTp::LAST; }

    /// @return The type of this step.
    TimeStepTp type()const noexcept{ return type_; }

    /// @return A copy of the stored observation.
    state_type observation()const{ return obs_; }

    /// @return The stored reward.
    real_t reward()const noexcept{ return reward_; }

    /// @return The stored discount factor.
    real_t discount()const noexcept{ return discount_; }

    /// @return true if the step type is TimeStepTp::LAST (same test as last()).
    bool done()const noexcept{ return last(); }

    /// @brief Resets this step to the same values the default constructor sets.
    void clear()noexcept;

    /// @brief Read access to the extra value stored under @p name.
    /// @tparam T Type the stored std::any is cast to.
    /// @throws std::logic_error if no entry named @p name exists.
    /// @throws std::bad_any_cast if the stored value is not of type T.
    template<typename T>
    const T& get_extra(std::string name)const;

    /// @return Read-only reference to the map of extra values.
    const std::unordered_map<std::string, std::any>& info()const noexcept{ return extra_; }

    /// @return Mutable reference to the map of extra values.
    std::unordered_map<std::string, std::any>& info()noexcept{ return extra_; }

private:

    /// Type of the step.
    TimeStepTp type_;

    /// Reward attached to the step.
    real_t reward_;

    /// Observation attached to the step.
    state_type obs_;

    /// Discount factor attached to the step.
    real_t discount_;

    /// Named extra values attached to the step.
    std::unordered_map<std::string, std::any> extra_;

};

/// @brief Default constructor: invalid step type, zero reward,
/// value-initialised observation, unit discount and no extra values.
template<typename StateTp>
TimeStep<StateTp>::TimeStep()
    : type_(TimeStepTp::INVALID_TYPE)
    , reward_(0.0)
    , obs_()
    , discount_(1.0)
    , extra_()
{
}

/// @brief Builds a step from the given type, reward, observation and
/// discount factor. The extra map starts empty.
template<typename StateTp>
TimeStep<StateTp>::TimeStep(TimeStepTp type,
                            real_t reward,
                            const state_type& obs,
                            real_t discount_factor)
    : type_(type)
    , reward_(reward)
    , obs_(obs)
    , discount_(discount_factor)
    , extra_()
{
}

/// @brief Builds a step from the given type, reward and observation by
/// delegating to the four-argument constructor with a discount of 1.0.
template<typename StateTp>
TimeStep<StateTp>::TimeStep(TimeStepTp type,
                            real_t reward,
                            const state_type& obs)
    : TimeStep(type, reward, obs, 1.0)
{
}

/// @brief Builds a step from the given type, reward, observation,
/// discount factor and map of named extra values.
///
/// @param type Type of the step.
/// @param reward Reward attached to the step.
/// @param obs Observation; copied into the step.
/// @param discount_factor Discount factor attached to the step.
/// @param extra Map of extra values; its contents are moved into the step.
template<typename StateTp>
TimeStep<StateTp>::TimeStep(TimeStepTp type, real_t reward, const state_type& obs, real_t discount_factor,
                            std::unordered_map<std::string, std::any>&& extra)
    :
    type_(type),
    reward_(reward),
    obs_(obs),
    discount_(discount_factor),
    // A named rvalue reference is an lvalue inside the function, so without
    // std::move the map would be copied rather than moved.
    extra_(std::move(extra))
{}

/// @brief Copy constructor: copies every member of @p other.
template<typename StateTp>
TimeStep<StateTp>::TimeStep(const TimeStep& other)
    : type_(other.type_)
    , reward_(other.reward_)
    , obs_(other.obs_)
    , discount_(other.discount_)
    , extra_(other.extra_)
{
}

/// @brief Copy assignment: assigns every member of @p other to this
/// object, in declaration order.
/// @return Reference to this object.
template<typename StateTp>
TimeStep<StateTp>&
TimeStep<StateTp>::operator=(const TimeStep& other)
{
    this->type_     = other.type_;
    this->reward_   = other.reward_;
    this->obs_      = other.obs_;
    this->discount_ = other.discount_;
    this->extra_    = other.extra_;

    return *this;
}

/// @brief Move constructor.
///
/// The observation and the extra map are moved out of @p other; the
/// remaining members are copied. Afterwards @p other is reset with
/// clear().
template<typename StateTp>
TimeStep<StateTp>::TimeStep(TimeStep&& other)noexcept
    :
      type_(other.type_),
      reward_(other.reward_),
      // other.obs_ / other.extra_ are lvalues; std::move is required so
      // that they are moved instead of copied.
      obs_(std::move(other.obs_)),
      discount_(other.discount_),
      extra_(std::move(other.extra_))
{
    // Reset the source so it reads as an invalid, empty step.
    other.clear();
}

/// @brief Move assignment.
///
/// The observation and the extra map are moved out of @p other; the
/// remaining members are copied. Afterwards @p other is reset with
/// clear(). Assigning an object to itself leaves it unchanged.
///
/// @return Reference to this object.
template<typename StateTp>
TimeStep<StateTp>&
TimeStep<StateTp>::operator=(TimeStep&& other)noexcept{

    // Without this guard a self-move would end with other.clear()
    // wiping the contents of *this.
    if(this != &other){
        type_ = other.type_;
        reward_ = other.reward_;
        obs_ = std::move(other.obs_);
        discount_ = other.discount_;
        extra_ = std::move(other.extra_);
        other.clear();
    }
    return *this;
}

/// @brief Resets the step: invalid type, zero reward, value-initialised
/// observation, unit discount and an empty extra map.
template<typename StateTp>
void
TimeStep<StateTp>::clear()noexcept
{
    this->type_     = TimeStepTp::INVALID_TYPE;
    this->reward_   = 0.0;
    this->obs_      = state_type();
    this->discount_ = 1.0;
    this->extra_.clear();
}

/// @brief Returns the extra value stored under @p name, cast to T.
///
/// @tparam T Type the stored std::any is cast to.
/// @param name Key of the extra value.
/// @return Const reference to the value held by the stored std::any.
/// @throws std::logic_error if the map has no entry named @p name.
/// @throws std::bad_any_cast if the stored value is not of type T.
template<typename StateTp>
template<typename T>
const T&
TimeStep<StateTp>::get_extra(std::string name)const
{
    const auto entry = extra_.find(name);

    // Found: hand back a reference into the stored std::any.
    if(entry != extra_.end()){
        return std::any_cast<const T&>(entry->second);
    }

    throw std::logic_error("Property " + name + " does not exist");
}


/// @brief Writes the step type, reward, observation and discount of
/// @p step to @p out, one per line, flushing after each line.
/// @return The stream @p out.
template<typename StateTp>
inline
std::ostream& operator<<(std::ostream& out, const TimeStep<StateTp>& step)
{
    out << "Step type....." << TimeStepEnumUtils::to_string(step.type()) << std::endl
        << "Reward........" << step.reward() << std::endl
        << "Observation..." << step.observation() << std::endl
        << "Discount..... " << step.discount() << std::endl;

    return out;
}


/// @brief Overload for steps whose observation is a std::vector: writes
/// the step type, reward, observation and discount of @p step to @p out.
/// The observation is written through utils::io::print_vector.
/// @return The stream @p out.
template<typename T>
std::ostream& operator<<(std::ostream& out,
                         const TimeStep<std::vector<T>>& step)
{
    out << "Step type....." << TimeStepEnumUtils::to_string(step.type()) << std::endl
        << "Reward........" << step.reward() << std::endl;

    // observation() returns by value, so keep the copy in a named local
    // before handing it to the printing helper.
    auto values = step.observation();
    out << "Observation...";
    rlenvscpp::utils::io::print_vector(out, values);

    out << "Discount..... " << step.discount() << std::endl;
    return out;
}

}

#endif // TIME_STEP_H