diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 3b06c0e1f..0a1766fc8 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -62,6 +62,7 @@ SET(common_sources mutex.cpp mysql_request_result.cpp mysql_request_row.cpp + mysql_stmt.cpp opcode_map.cpp opcodemgr.cpp packet_dump.cpp @@ -586,6 +587,7 @@ SET(common_headers mutex.h mysql_request_result.h mysql_request_row.h + mysql_stmt.h op_codes.h opcode_dispatch.h opcodemgr.h diff --git a/common/dbcore.cpp b/common/dbcore.cpp index 2fad0f2a7..82a9bc396 100644 --- a/common/dbcore.cpp +++ b/common/dbcore.cpp @@ -7,6 +7,7 @@ #include "timer.h" #include "dbcore.h" +#include "mysql_stmt.h" #include #include @@ -436,3 +437,8 @@ MySQLRequestResult DBcore::QueryDatabaseMulti(const std::string &query) return r; } + +mysql::PreparedStmt DBcore::Prepare(std::string query) +{ + return mysql::PreparedStmt(*mysql, std::move(query), m_mutex); +} diff --git a/common/dbcore.h b/common/dbcore.h index 3cc206012..cefdc6522 100644 --- a/common/dbcore.h +++ b/common/dbcore.h @@ -17,6 +17,8 @@ #define CR_SERVER_GONE_ERROR 2006 #define CR_SERVER_LOST 2013 +namespace mysql { class PreparedStmt; } + class DBcore { public: enum eStatus { @@ -48,6 +50,11 @@ public: } void SetMutex(Mutex *mutex); + // only safe on connections shared with other threads if results buffered + // unsafe to use off main thread due to internal server logging + // throws std::runtime_error on failure + mysql::PreparedStmt Prepare(std::string query); + protected: bool Open( const char *iHost, diff --git a/common/mysql_stmt.cpp b/common/mysql_stmt.cpp new file mode 100644 index 000000000..0c71aa53c --- /dev/null +++ b/common/mysql_stmt.cpp @@ -0,0 +1,586 @@ +#include "mysql_stmt.h" +#include "eqemu_logsys.h" +#include "mutex.h" +#include "timer.h" +#include + +namespace mysql +{ + +void PreparedStmt::StmtDeleter::operator()(MYSQL_STMT* stmt) noexcept +{ + // The connection must be locked when closing the stmt to avoid mysql errors + // in case another thread tries to use it during the close. If the mutex is + // changed to one that throws then exceptions need to be caught here. + LockMutex lock(mutex); + mysql_stmt_close(stmt); +} + +PreparedStmt::PreparedStmt(MYSQL& mysql, std::string query, Mutex* mutex, StmtOptions opts) + : m_stmt(mysql_stmt_init(&mysql), { mutex }), m_query(std::move(query)), m_mutex(mutex), m_options(opts) +{ + LockMutex lock(m_mutex); + if (mysql_stmt_prepare(m_stmt.get(), m_query.c_str(), static_cast(m_query.size())) != 0) + { + ThrowError(fmt::format("Prepare error: {}", GetStmtError())); + } + + m_params.resize(mysql_stmt_param_count(m_stmt.get())); + m_inputs.resize(m_params.size()); +} + +void PreparedStmt::ThrowError(const std::string& error) +{ + LogMySQLError("{}", error); + throw std::runtime_error(error); +} + +std::string PreparedStmt::GetStmtError() +{ + auto err = mysql_stmt_errno(m_stmt.get()); + auto str = mysql_stmt_error(m_stmt.get()); + return fmt::format("({}) [{}] for query [{}]", err, str, m_query); +} + +template +void PreparedStmt::BindInput(size_t index, T value) +{ + if (index >= m_inputs.size()) + { + ThrowError(fmt::format("Cannot bind input, index {} out of range", index)); + } + + impl::Bind& arg = m_inputs[index]; + arg.is_null = std::is_same_v; + + MYSQL_BIND& bind = m_params[index]; + bind.is_unsigned = std::is_unsigned_v; + bind.is_null = &arg.is_null; + bind.length = &arg.length; + + auto old_type = bind.buffer_type; + + if constexpr (std::is_arithmetic_v) + { + if (arg.buffer.size() < sizeof(T)) + { + arg.buffer.resize(std::max(sizeof(T), sizeof(int64_t))); + bind.buffer = arg.buffer.data(); + m_need_bind = true; + } + memcpy(arg.buffer.data(), &value, sizeof(T)); + } + + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) + { + bind.buffer_type = MYSQL_TYPE_TINY; + } + else if constexpr (std::is_same_v || std::is_same_v) + { + bind.buffer_type = MYSQL_TYPE_SHORT; + } + else if constexpr (std::is_same_v || std::is_same_v) + { + bind.buffer_type = MYSQL_TYPE_LONG; + } + else if constexpr (std::is_same_v || std::is_same_v) + { + bind.buffer_type = MYSQL_TYPE_LONGLONG; + } + else if constexpr (std::is_same_v) + { + bind.buffer_type = MYSQL_TYPE_FLOAT; + } + else if constexpr (std::is_same_v) + { + bind.buffer_type = MYSQL_TYPE_DOUBLE; + } + else if constexpr (std::is_same_v) + { + bind.buffer_type = MYSQL_TYPE_STRING; + if (arg.buffer.empty() || arg.buffer.size() < value.size()) + { + arg.buffer.resize(static_cast((value.size() + 1) * 1.5)); + bind.buffer = arg.buffer.data(); + bind.buffer_length = static_cast(arg.buffer.size()); + m_need_bind = true; + } + std::copy(value.begin(), value.end(), arg.buffer.begin()); + arg.length = static_cast(value.size()); + } + else if constexpr (!std::is_same_v) + { + static_assert(false_v, "Cannot bind unsupported type"); + } + + if (old_type != bind.buffer_type) + { + m_need_bind = true; + } +} + +void PreparedStmt::BindInput(size_t index, const char* str) +{ + BindInput(index, std::string_view(str)); +} + +void PreparedStmt::BindInput(size_t index, const std::string& str) +{ + BindInput(index, std::string_view(str)); +} + +StmtResult PreparedStmt::Execute() +{ + CheckArgs(0); + return DoExecute(); +} + +StmtResult PreparedStmt::Execute(const std::vector& args) +{ + CheckArgs(args.size()); + for (size_t i = 0; i < args.size(); ++i) + { + std::visit([&](const auto& arg) { BindInput(i, arg); }, args[i]); + } + return DoExecute(); +} + +template +StmtResult PreparedStmt::Execute(const std::vector& args) +{ + CheckArgs(args.size()); + for (size_t i = 0; i < args.size(); ++i) + { + BindInput(i, args[i]); + } + return DoExecute(); +} + +void PreparedStmt::CheckArgs(size_t argc) +{ + if (argc != m_params.size()) + { + ThrowError(fmt::format("Bad arg count (got {}, expected {}) for [{}]", argc, m_params.size(), m_query)); + } +} + +StmtResult PreparedStmt::DoExecute() +{ + BenchTimer timer; + LockMutex lock(m_mutex); + + if (m_need_bind && mysql_stmt_bind_param(m_stmt.get(), m_params.data()) != 0) + { + ThrowError(fmt::format("Bind param error: {}", GetStmtError())); + } + + m_need_bind = false; + + if (mysql_stmt_execute(m_stmt.get()) != 0) + { + ThrowError(fmt::format("Execute error: {}", GetStmtError())); + } + + my_bool attr = m_options.use_max_length; + mysql_stmt_attr_set(m_stmt.get(), STMT_ATTR_UPDATE_MAX_LENGTH, &attr); + + if (m_options.buffer_results && mysql_stmt_store_result(m_stmt.get()) != 0) + { + ThrowError(fmt::format("Store result error: {}", GetStmtError())); + } + + // Result buffers are bound on first execute and re-used if needed + if (m_results.empty()) + { + BindResults(); + } + + StmtResult res(m_stmt.get(), m_results.size()); + + if (m_results.empty()) + { + LogMySQLQuery("{} -- ({} row(s) affected) ({:.6f}s)", m_query, res.RowsAffected(), timer.elapsed()); + } + else + { + LogMySQLQuery("{} -- ({} row(s) returned) ({:.6f}s)", m_query, res.RowCount(), timer.elapsed()); + } + + return res; +} + +void PreparedStmt::BindResults() +{ + MYSQL_RES* res = mysql_stmt_result_metadata(m_stmt.get()); + if (!res) + { + return; // did not produce a result set + } + + MYSQL_FIELD* fields = mysql_fetch_fields(res); + m_columns.resize(mysql_num_fields(res)); + m_results.resize(m_columns.size()); + + for (int i = 0; i < static_cast(m_columns.size()); ++i) + { + impl::BindColumn& col = m_columns[i].m_col; + MYSQL_BIND& bind = m_results[i]; + + col.index = i; + col.name = fields[i].name; + col.buffer_type = fields[i].type; + col.is_unsigned = (fields[i].flags & UNSIGNED_FLAG) != 0; + col.buffer.resize(GetResultBufferSize(fields[i])); + + bind.buffer_type = col.buffer_type; + bind.buffer = col.buffer.data(); + bind.buffer_length = static_cast(col.buffer.size()); + bind.is_unsigned = col.is_unsigned; + bind.is_null = &col.is_null; + bind.length = &col.length; + bind.error = &col.error; + } + + mysql_free_result(res); + + if (!m_results.empty() && mysql_stmt_bind_result(m_stmt.get(), m_results.data()) != 0) + { + ThrowError(fmt::format("Bind result error: {}", GetStmtError())); + } +} + +int PreparedStmt::GetResultBufferSize(const MYSQL_FIELD& field) const +{ + switch (field.type) + { + case MYSQL_TYPE_TINY: + return sizeof(int8_t); + case MYSQL_TYPE_SHORT: + return sizeof(int16_t); + case MYSQL_TYPE_INT24: + case MYSQL_TYPE_LONG: + return sizeof(int32_t); + case MYSQL_TYPE_LONGLONG: + return sizeof(int64_t); + case MYSQL_TYPE_FLOAT: + return sizeof(float); + case MYSQL_TYPE_DOUBLE: + return sizeof(double); + case MYSQL_TYPE_TIME: + case MYSQL_TYPE_DATE: + case MYSQL_TYPE_DATETIME: + case MYSQL_TYPE_TIMESTAMP: + return sizeof(MYSQL_TIME); + default: // if max_length is unavailable for strings buffers are resized on fetch + return field.max_length + 1; // ensure valid buffer created + } +} + +StmtRow PreparedStmt::Fetch() +{ + StmtRow row; + if (!m_columns.empty()) + { + int rc = mysql_stmt_fetch(m_stmt.get()); + if (rc == 1) + { + ThrowError(fmt::format("Fetch error: {}", GetStmtError())); + } + + if (rc != MYSQL_NO_DATA) + { + if (rc == MYSQL_DATA_TRUNCATED) + { + FetchTruncated(); + } + row = StmtRow(m_columns); + } + } + return row; +} + +void PreparedStmt::FetchTruncated() +{ + for (int i = 0; i < static_cast(m_columns.size()); ++i) + { + impl::BindColumn& col = m_columns[i].m_col; + if (col.error) + { + MYSQL_BIND& bind = m_results[i]; + col.buffer.resize(static_cast(col.length * 1.5)); + bind.buffer = col.buffer.data(); + bind.buffer_length = static_cast(col.buffer.size()); + + mysql_stmt_fetch_column(m_stmt.get(), &bind, i, 0); + } + } + + if (mysql_stmt_bind_result(m_stmt.get(), m_results.data()) != 0) + { + ThrowError(fmt::format("Fetch rebind result error: {}", GetStmtError())); + } +} + +// --------------------------------------------------------------------------- + +StmtResult::StmtResult(MYSQL_STMT* stmt, size_t columns) +{ + m_num_cols = static_cast(columns); + m_num_rows = mysql_stmt_num_rows(stmt); // requires buffered results + m_affected = mysql_stmt_affected_rows(stmt); + m_insert_id = mysql_stmt_insert_id(stmt); +} + +// --------------------------------------------------------------------------- + +const StmtColumn* StmtRow::GetColumn(size_t index) const +{ + return index < m_columns.size() ? &m_columns[index] : nullptr; +} + +const StmtColumn* StmtRow::GetColumn(std::string_view name) const +{ + auto it = std::ranges::find_if(m_columns, + [name](const StmtColumn& col) { return col.Name() == name; }); + + return it != m_columns.end() ? &(*it) : nullptr; +} + +std::optional StmtRow::operator[](size_t index) const +{ + return GetStr(index); +} + +std::optional StmtRow::operator[](std::string_view name) const +{ + return GetStr(name); +} + +std::optional StmtRow::GetStr(size_t index) const +{ + const StmtColumn* col = GetColumn(index); + return col ? col->GetStr() : std::nullopt; +} + +std::optional StmtRow::GetStr(std::string_view name) const +{ + const StmtColumn* col = GetColumn(name); + return col ? col->GetStr() : std::nullopt; +} + +template requires std::is_arithmetic_v +std::optional StmtRow::Get(size_t index) const +{ + const StmtColumn* col = GetColumn(index); + return col ? col->Get() : std::nullopt; +} + +template requires std::is_arithmetic_v +std::optional StmtRow::Get(std::string_view name) const +{ + const StmtColumn* col = GetColumn(name); + return col ? col->Get() : std::nullopt; +} + +// --------------------------------------------------------------------------- + +static time_t MakeTime(const MYSQL_TIME& mt) +{ + // buffer mt given in mysql session time zone (assumes local) + std::tm tm{}; + tm.tm_year = mt.year - 1900; + tm.tm_mon = mt.month - 1; + tm.tm_mday = mt.day; + tm.tm_hour = mt.hour; + tm.tm_min = mt.minute; + tm.tm_sec = mt.second; + tm.tm_isdst = -1; + return std::mktime(&tm); +} + +static int MakeSeconds(const MYSQL_TIME& mt) +{ + return (mt.neg ? -1 : 1) * static_cast(mt.hour * 3600 + mt.minute * 60 + mt.second); +} + +static uint64_t MakeBits(std::span data) +{ + // byte stream for bits is in big endian + uint64_t bits = 0; + for (size_t i = 0; i < data.size() && i < sizeof(uint64_t); ++i) + { + bits |= static_cast(data[data.size() - i - 1] & 0xff) << (i * 8); + } + return bits; +} + +template +static T FromString(std::string_view sv) +{ + if constexpr (std::is_same_v) + { + // return false for empty (zero-length) strings + return !sv.empty(); + } + else + { + // non numbers return a zero initialized T (could return nullopt instead) + T value = {}; + std::from_chars(sv.data(), sv.data() + sv.size(), value); + return value; + } +} + +static std::string FormatTime(enum_field_types type, const MYSQL_TIME& mt) +{ + switch (type) + { + case MYSQL_TYPE_TIME: // hhh:mm:ss '-838:59:59' to '838:59:59' + return fmt::format("{}{:02d}:{:02d}:{:02d}", mt.neg ? "-" : "", mt.hour, mt.minute, mt.second); + case MYSQL_TYPE_DATE: // YYYY-MM-DD '1000-01-01' to '9999-12-31' + return fmt::format("{}-{:02d}-{:02d}", mt.year, mt.month, mt.day); + case MYSQL_TYPE_DATETIME: // YYYY-MM-DD hh:mm:ss '1000-01-01 00:00:00' to '9999-12-31 23:59:59' + case MYSQL_TYPE_TIMESTAMP: // YYYY-MM-DD hh:mm:ss '1970-01-01 00:00:01' UTC to '2038-01-19 03:14:07' UTC + return fmt::format("{}-{:02d}-{:02d} {:02d}:{:02d}:{:02d}", mt.year, mt.month, mt.day, mt.hour, mt.minute, mt.second); + default: + return std::string(); + } +} + +std::optional StmtColumn::GetStrView() const +{ + if (m_col.is_null) + { + return std::nullopt; + } + + switch (m_col.buffer_type) + { + case MYSQL_TYPE_NEWDECIMAL: + case MYSQL_TYPE_TINY_BLOB: + case MYSQL_TYPE_MEDIUM_BLOB: + case MYSQL_TYPE_LONG_BLOB: + case MYSQL_TYPE_BLOB: + case MYSQL_TYPE_VAR_STRING: + case MYSQL_TYPE_STRING: + return std::make_optional(reinterpret_cast(m_col.buffer.data()), m_col.length); + default: + return std::nullopt; + } +} + +std::optional StmtColumn::GetStr() const +{ + if (m_col.is_null) + { + return std::nullopt; + } + + switch (m_col.buffer_type) + { + case MYSQL_TYPE_TINY: + return m_col.is_unsigned ? fmt::format_int(BitCast()).c_str() : fmt::format_int(BitCast()).c_str(); + case MYSQL_TYPE_SHORT: + return m_col.is_unsigned ? fmt::format_int(BitCast()).c_str() : fmt::format_int(BitCast()).c_str(); + case MYSQL_TYPE_INT24: + case MYSQL_TYPE_LONG: + return m_col.is_unsigned ? fmt::format_int(BitCast()).c_str() : fmt::format_int(BitCast()).c_str(); + case MYSQL_TYPE_LONGLONG: + return m_col.is_unsigned ? fmt::format_int(BitCast()).c_str() : fmt::format_int(BitCast()).c_str(); + case MYSQL_TYPE_FLOAT: + return fmt::format("{}", BitCast()); + case MYSQL_TYPE_DOUBLE: + return fmt::format("{}", BitCast()); + case MYSQL_TYPE_TIME: + case MYSQL_TYPE_DATE: + case MYSQL_TYPE_DATETIME: + case MYSQL_TYPE_TIMESTAMP: + return FormatTime(m_col.buffer_type, BitCast()); + case MYSQL_TYPE_BIT: + return fmt::format_int(*Get()).c_str(); + case MYSQL_TYPE_NEWDECIMAL: + case MYSQL_TYPE_TINY_BLOB: + case MYSQL_TYPE_MEDIUM_BLOB: + case MYSQL_TYPE_LONG_BLOB: + case MYSQL_TYPE_BLOB: + case MYSQL_TYPE_VAR_STRING: + case MYSQL_TYPE_STRING: + return std::make_optional(reinterpret_cast(m_col.buffer.data()), m_col.length); + default: + return std::nullopt; + } +} + +template requires std::is_arithmetic_v +std::optional StmtColumn::Get() const +{ + if (m_col.is_null) + { + return std::nullopt; + } + + switch (m_col.buffer_type) + { + case MYSQL_TYPE_TINY: + return m_col.is_unsigned ? static_cast(BitCast()) : static_cast(BitCast()); + case MYSQL_TYPE_SHORT: + return m_col.is_unsigned ? static_cast(BitCast()) : static_cast(BitCast()); + case MYSQL_TYPE_INT24: + case MYSQL_TYPE_LONG: + return m_col.is_unsigned ? static_cast(BitCast()) : static_cast(BitCast()); + case MYSQL_TYPE_LONGLONG: + return m_col.is_unsigned ? static_cast(BitCast()) : static_cast(BitCast()); + case MYSQL_TYPE_FLOAT: + return static_cast(BitCast()); + case MYSQL_TYPE_DOUBLE: + return static_cast(BitCast()); + case MYSQL_TYPE_TIME: // return as total seconds + return static_cast(MakeSeconds(BitCast())); + case MYSQL_TYPE_DATE: + case MYSQL_TYPE_DATETIME: + case MYSQL_TYPE_TIMESTAMP: // return as epoch timestamp + return static_cast(MakeTime(BitCast())); + case MYSQL_TYPE_BIT: + return static_cast(MakeBits({ m_col.buffer.data(), m_col.length })); + case MYSQL_TYPE_NEWDECIMAL: + case MYSQL_TYPE_TINY_BLOB: + case MYSQL_TYPE_MEDIUM_BLOB: + case MYSQL_TYPE_LONG_BLOB: + case MYSQL_TYPE_BLOB: + case MYSQL_TYPE_VAR_STRING: + case MYSQL_TYPE_STRING: + return FromString({ reinterpret_cast(m_col.buffer.data()), m_col.length }); + default: + return std::nullopt; + } +} + +// --------------------------------------------------------------------------- + +// explicit template instantiations for supported types +template void PreparedStmt::BindInput(size_t, std::string_view); +template void PreparedStmt::BindInput(size_t, std::nullptr_t); +template StmtResult PreparedStmt::Execute(const std::vector&); +template StmtResult PreparedStmt::Execute(const std::vector&); +template StmtResult PreparedStmt::Execute(const std::vector&); + +#define INSTANTIATE(T) \ + template void PreparedStmt::BindInput(size_t, T); \ + template StmtResult PreparedStmt::Execute(const std::vector&); \ + template std::optional StmtRow::Get(size_t) const; \ + template std::optional StmtRow::Get(std::string_view) const; \ + template std::optional StmtColumn::Get() const; + +INSTANTIATE(bool); +INSTANTIATE(int8_t); +INSTANTIATE(uint8_t); +INSTANTIATE(int16_t); +INSTANTIATE(uint16_t); +INSTANTIATE(int32_t); +INSTANTIATE(uint32_t); +INSTANTIATE(int64_t); +INSTANTIATE(uint64_t); +INSTANTIATE(float); +INSTANTIATE(double); + +} // namespace mysql diff --git a/common/mysql_stmt.h b/common/mysql_stmt.h new file mode 100644 index 000000000..eb1483f17 --- /dev/null +++ b/common/mysql_stmt.h @@ -0,0 +1,221 @@ +#pragma once + +#include "mysql.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class Mutex; + +namespace mysql +{ + +// support MySQL 8.0.1+ API which removed the my_bool type +#if !defined(MARIADB_VERSION_ID) && MYSQL_VERSION_ID >= 80001 +using my_bool = bool; +#endif + +template +inline constexpr bool false_v = false; + +namespace impl +{ + +struct Bind +{ + std::vector buffer; + unsigned long length = 0; + my_bool is_null = false; + my_bool error = false; +}; + +struct BindColumn : Bind +{ + int index = 0; + std::string name; + bool is_unsigned = false; + enum_field_types buffer_type = {}; +}; + +} // namespace impl + +// --------------------------------------------------------------------------- + +struct StmtOptions +{ + // Enable buffering (storing) entire result set after executing a statement + bool buffer_results = true; + + // Enable MySQL to update max_length of fields in execute result set (requires buffering) + bool use_max_length = true; +}; + +// --------------------------------------------------------------------------- + +// Holds ownership of bound column value buffer +class StmtColumn +{ +public: + int Index() const { return m_col.index; } + bool IsNull() const { return m_col.is_null; } + bool IsUnsigned() const { return m_col.is_unsigned; } + enum_field_types Type() const { return m_col.buffer_type; } + const std::string& Name() const { return m_col.name; } + + // Get view of column value buffer + std::span GetBuf() const { return { m_col.buffer.data(), m_col.length }; } + + // Get view of column string value. Returns nullopt if value is NULL or not a string + std::optional GetStrView() const; + + // Get column value as string. Returns nullopt if value is NULL or field type unsupported + std::optional GetStr() const; + + // Get column value as numeric T. Returns nullopt if value NULL or field type unsupported + template requires std::is_arithmetic_v + std::optional Get() const; + +private: + // uses memcpy for type punning buffer data to avoid UB with strict aliasing + template + T BitCast() const + { + T val; + assert(sizeof(T) == m_col.length); + memcpy(&val, m_col.buffer.data(), sizeof(T)); + return val; + } + + friend class PreparedStmt; // access to allocate and bind buffers + friend class StmtResult; // access to resize truncated buffers + impl::BindColumn m_col; +}; + +// --------------------------------------------------------------------------- + +// Provides a non-owning view of PreparedStmt column value buffers +// Evaluates false if it does not contain a valid row +class StmtRow +{ +public: + StmtRow() = default; + StmtRow(std::span columns) : m_columns(columns) {}; + + explicit operator bool() const noexcept { return !m_columns.empty(); } + + int ColumnCount() const { return static_cast(m_columns.size()); } + const StmtColumn* GetColumn(size_t index) const; + const StmtColumn* GetColumn(std::string_view name) const; + + // Get specified column value as string + // Returns nullopt if column invalid, value is NULL, or field type unsupported + std::optional operator[](size_t index) const; + std::optional operator[](std::string_view name) const; + std::optional GetStr(size_t index) const; + std::optional GetStr(std::string_view name) const; + + // Get specified column value as numeric T + // Returns nullopt if column invalid, value is NULL, or field type unsupported + template requires std::is_arithmetic_v + std::optional Get(size_t index) const; + + template requires std::is_arithmetic_v + std::optional Get(std::string_view name) const; + + auto begin() const { return m_columns.begin(); } + auto end() const { return m_columns.end(); } + +private: + std::span m_columns; +}; + +// --------------------------------------------------------------------------- + +// Result meta data for an executed prepared statement +class StmtResult +{ +public: + StmtResult() = default; + StmtResult(MYSQL_STMT* stmt, size_t columns); + + int ColumnCount() const { return m_num_cols; } + uint64_t RowCount() const { return m_num_rows; } + uint64_t RowsAffected() const { return m_affected; } + uint64_t LastInsertID() const { return m_insert_id; } + +private: + int m_num_cols = 0; + uint64_t m_num_rows = 0; + uint64_t m_affected = 0; + uint64_t m_insert_id = 0; +}; + +// --------------------------------------------------------------------------- + +class PreparedStmt +{ +public: + // Supported argument types for execute + using param_t = std::variant; + + PreparedStmt() = delete; + PreparedStmt(MYSQL& mysql, std::string query, Mutex* mutex, StmtOptions opts = {}); + + const std::string& GetQuery() const { return m_query; } + StmtOptions GetOptions() const { return m_options; } + void SetOptions(StmtOptions options) { m_options = options; } + void FreeResult() { mysql_stmt_free_result(m_stmt.get()); } + + // Execute the prepared statement with specified arguments + // Throws exception on error + template + StmtResult Execute(const std::vector& args); + StmtResult Execute(const std::vector& args); + StmtResult Execute(); + + // Fetch the next row into column buffers (overwrites previous row values) + // Return value evaluates false if no more rows to fetch + // Throws exception on error + StmtRow Fetch(); + +private: + void CheckArgs(size_t argc); + StmtResult DoExecute(); + void BindResults(); + void FetchTruncated(); + int GetResultBufferSize(const MYSQL_FIELD& field) const; + void ThrowError(const std::string& error); + std::string GetStmtError(); + + // bind an input value to a query parameter by index + template + void BindInput(size_t index, T value); + void BindInput(size_t index, const char* str); + void BindInput(size_t index, const std::string& str); + + struct StmtDeleter + { + Mutex* mutex = nullptr; + void operator()(MYSQL_STMT* stmt) noexcept; + }; + +private: + std::unique_ptr m_stmt; + std::vector m_params; // input binds + std::vector m_results; // result binds + std::vector m_inputs; // execute buffers (addresses bound) + std::vector m_columns; // fetch buffers (addresses bound) + std::string m_query; + StmtOptions m_options = {}; + bool m_need_bind = true; + Mutex* m_mutex = nullptr; // connection mutex +}; + +} // namespace mysql