Генерация обертки XMVector на шаблонах

2019.08.05

Intro

В библиотеке DirectXMath есть тип XMVector, который использует SIMD инструкции, для обеспечения большей производительности. Есть у него один недостаток - это упакованные данные и для получения привычного представления вектора (как скажем в glm), необходима конвертация через функции XMStore*/XMLoad*. Таких типов достаточно много и писать обертки к каждому дело утомительное, но можно поупражняться в шаблонах для автоматической генерации оберток.

Следует отметить, что использование кода ниже может привести к падению производительности вашего приложения.

Постановка задачи

Типичным представителем пары функций Store/Load являются:

_Use_decl_annotations_ inline XMVECTOR XM_CALLCONV XMLoadFloat2(const XMFLOAT2* pSource)
_Use_decl_annotations_ inline void     XM_CALLCONV XMStoreFloat2(XMFLOAT2* pDestination, FXMVECTOR  V)

Наша задача сгенерировать код класса по такому шаблону:

using Vector2f = VectorImpl<&DirectX::XMLoadFloat2, &DirectX::XMStoreFloat2>;
using Vector3i = VectorImpl<&DirectX::XMLoadSInt3,  &DirectX::XMStoreSInt3>;
using Vector4u = VectorImpl<&DirectX::XMLoadUInt4,  &DirectX::XMStoreUInt4>;

Эти классы должны:

  1. предоставлять доступ к данным через operator[](size_t)
  2. через публичные члены класса .x, .y, .z, .w.
  3. было бы не плохо перегрузить арифметику и
  4. варианты конвертации в другие размерности для данного типа (vec2 -> vec4 -> vec3).

Для этого нам потребуется:

  1. Получить тип вектора (XMFLOAT2, XMFLOAT3, …).
  2. Размерность данного вектора (2, 3 или 4 компонента).
  3. Тип компонента (float, int, unsigned int).
  4. Наследоваться от данного вектора.

Реализация

1. Получение типа вектора.

Будем брать типы из функций преобразования векторов в/из SIMD вид/вида. Для этого можно воспользоваться boost::type_traits:

using VectorN = typename function_traits<DirectX::XMLoadFloat2>::arg1_type

Или напишем свой класс:

template <typename T>
struct FunctionInfo; 

template <typename Res, typename... Args>
struct FunctionInfo<Res XM_CALLCONV(Args...)>
{
    using Result = Res;
    using Arguments = std::tuple<Args...>;
};

using VectorN = std::tuple_element_t<typename FunctionInfo<DirectX::XMLoadFloat2>::Arguments>;

Разберем как это работает:

  1. Сначала определим шаблонный класс без тела (forward declaration).
  2. Специализируем его для конкретного случая (T - сигнатура функции). Стоит отметить, что нужно учитывать все, что относится к сигнатуре включая модификаторы (такие как константность и прочее ). Именно поэтому спецификация шаблона включает XM_CALLCONV. Если вы пишите универсальную библиотеку (как boost::type_traits), вам придется написать тонну специализаций для каждой возможной комбинации.
  3. Сохраняем типы аргументов в псевдониме Arguments как std::tuple т.к. лень писать геттер типа по индексу, а std::tuple имеет все готовое.
  4. Исходный класс должен оставаться неопределенным и генерировать ошибку, в случае попадания в аргумент шаблона неподходящего типа. Можно определить его с static_assert(false, "some human readable description") для большей читаемости.

2. Получение размерности вектора.

Идея, как это сделать, была украдена из замечательной библиотеки magic_get Антона Полухина. Идея в следующем:

  • Каждый класс или структура имеет конструктор (определенный явно или сгенерированный компилятором).
  • Все PODы имеют один конструктор, который принимает в качестве аргументов все поля данной структуры.
  • Совершенно случайно все интересующие нас типы - PODы.

Другими словами, если мы знаем сигнатуру конструктора с максимальным количеством аргументов (помимо него компилятор сгенерирует прочие возможные конструкторы), мы получим число элементов в структуре и их типы. Т.е. задача сводится к нахождению максимальной сигнатуры контруктора для данного PODа. Используя C++ стандарта ниже 17 - это достаточно утомительная задача, но выполнимая (нужно генерировать много однотипного кода). Я использовал следующий вспомогательный класс:

template <auto load, auto store>
struct VectorInfo
{
    using InfoLoad = FunctionInfo<std::remove_pointer_t<decltype(load)>>;
    using InfoStore = FunctionInfo<std::remove_pointer_t<decltype(store)>>;

    // Класс который можно привести к любому типу.
    // Мы не компилируем его, а только используем во время компиляции,
    // поэтому реализация оператора нам не нужна.
    struct Placeholder
    {
        template <class Type> constexpr operator Type() const noexcept;
    };

    // Функция возвращает количество свойств (мемберов) PODа.
    template <typename Class, typename... TArgs>
    static constexpr size_t count()
    {
        // Можем впихнуть больше?
        constexpr const bool is_next = 
            std::is_constructible_v<Class, Placeholder, TArgs...>;
        // Можем вызвать конструктор с текущим набором аргументов?
        constexpr const bool is_current =
             std::is_constructible_v<Class, TArgs...>;

        // Рекурсивно вызываем себя для большего набора аргументов, пока не упремся.
        if constexpr (is_next || !is_current) {
            return count<Class, Placeholder, TArgs...>();
        } else {
            return sizeof...(TArgs);
        }
    }
};

3. Тип компоненты вектора.

Честно говоря, тут я срезал углы т.к. в моем случае не так много вариаций:

// I - количество компонент в векторе (вычисляем в предыдущем пункте).
// Class - тип самого вектора (определяем в пером пункте).
template <size_t I, typename Class>
static constexpr auto returnFirst()
{
    static_assert(I > 1 && I < 5);
    if constexpr (I == 2) {
        auto [x, y] = Class{};
        return x;
    }
    if constexpr (I == 3) {
        auto [x, y, z] = Class{};
        return x;
    }
    if constexpr (I == 4) {
        auto [x, y, z, w] = Class{};
        return x;
    }
}

Для получения всех компонент используется structured binding declaration из C++17. Это работает только для PODов, но нам большего и не нужно. Есть возможность написать универсальную функцию т.к. мы знаем количество компонент и можем сгенерировать выражения в теле if constexpr, но сейчас остановимся на этом.

Эта функция возвращает x для любого типа вектора (наши вектора гомогенны). Получить тип компоненты можно следующим образом:

using Component = decltype(returnFirst<length, VectorN>());

Тут все достаточно стандартно.

4. Наследование и финальный интерфейс.

template <auto load, auto store, typename Info = VectorInfo<load, store>>
struct VectorImpl : Info::VectorN, Info
{
    using Vector = typename Info::Vector;
    using VectorN = typename Info::VectorN;
    using Component = typename Info::Component;
    static const constexpr size_t length = Info::length;

    using VectorN::VectorN;

    constexpr VectorImpl() = default;
    constexpr VectorImpl(const VectorImpl&) = default;
    constexpr VectorImpl(VectorImpl&&) = default;

    VectorImpl(Vector vector) { *this = vector; }
    VectorImpl& operator=(const VectorImpl& vec) = default;
    inline VectorImpl& operator=(Vector vector)
    {
        this->store_(this, vector);
        return *this;
    }


    inline operator VectorN() const { return *this; }
    inline operator Vector() const { return this->load_(this); }
    inline VectorImpl& operator+=(Vector add) { return *this = DirectX::XMVectorAdd(*this, add); }
    inline VectorImpl& operator-=(Vector sub) { return *this = DirectX::XMVectorSubtract(*this, sub); }
    inline VectorImpl& operator*=(Vector mul) { return *this = DirectX::XMVectorMultiply(*this, mul); }
    inline VectorImpl& operator/=(Vector div) { return *this = DirectX::XMVectorDivide(*this, div); }
    inline Component& operator[](size_t i)
    {
        assert(length > i);
        std::byte* begin = reinterpret_cast<std::byte*>(this);
        Component* component = reinterpret_cast<Component*>(begin + i * sizeof(Component));
        return *component;
    }
    inline const Component& operator[](size_t i) const
    {
        assert(length > i);
        const std::byte* begin = reinterpret_cast<const std::byte*>(this);
        const Component* component = reinterpret_cast<const Component*>(begin + i * sizeof(Component));
        return *component;
    }

    friend inline VectorImpl operator+(const VectorImpl& left, Vector right)
    {
        VectorImpl tmp = left;
        tmp += right;
        return tmp;
    }
    friend inline VectorImpl operator-(const VectorImpl& left, Vector right)
    {
        VectorImpl tmp = left;
        tmp -= right;
        return tmp;
    }
    friend inline VectorImpl operator*(const VectorImpl& left, Vector right)
    {
        VectorImpl tmp = left;
        tmp *= right;
        return tmp;
    }
    friend inline VectorImpl operator/(const VectorImpl& left, Vector right)
    {
        VectorImpl tmp = left;
        tmp /= right;
        return tmp;
    }
    friend inline bool operator==(const VectorImpl& left, const VectorImpl& right)
    {
        for (size_t i = 0; i < left.length; ++i) { if (left[i] != right[i]) return false; } return true;
    }
};

Эпилог

Это больше бессмысленно, чем практически полезно. Самое интересное - игры с шаблонами и получение различной информации о типах во время компиляции. Надеюсь, это интересно не только мне.