diff --git a/common/rulesys.cpp b/common/rulesys.cpp index 13e342fe5..670c090b5 100644 --- a/common/rulesys.cpp +++ b/common/rulesys.cpp @@ -236,7 +236,7 @@ void RuleManager::SaveRules(Database *database, const char *ruleset_name) { } bool RuleManager::LoadRules(Database *database, const char *ruleset_name) { - + int ruleset_id = GetRulesetID(database, ruleset_name); if (ruleset_id < 0) { Log.Out(Logs::Detail, Logs::Rules, "Failed to find ruleset '%s' for load operation. Canceling.", ruleset_name); @@ -248,6 +248,26 @@ bool RuleManager::LoadRules(Database *database, const char *ruleset_name) { m_activeRuleset = ruleset_id; m_activeName = ruleset_name; + /* Load default ruleset values first if we're loading something other than default */ + if (strcasecmp(ruleset_name, "default") != 0){ + std::string default_ruleset_name = "default"; + int default_ruleset_id = GetRulesetID(database, default_ruleset_name.c_str()); + if (default_ruleset_id < 0) { + Log.Out(Logs::Detail, Logs::Rules, "Failed to find default ruleset '%s' for load operation. Canceling.", default_ruleset_name.c_str()); + return(false); + } + Log.Out(Logs::Detail, Logs::Rules, "Loading rule set '%s' (%d)", default_ruleset_name, default_ruleset_id); + + std::string query = StringFormat("SELECT rule_name, rule_value FROM rule_values WHERE ruleset_id = %d", default_ruleset_id); + auto results = database->QueryDatabase(query); + if (!results.Success()) + return false; + + for (auto row = results.begin(); row != results.end(); ++row) + if (!SetRule(row[0], row[1], nullptr, false)) + Log.Out(Logs::Detail, Logs::Rules, "Unable to interpret rule record for %s", row[0]); + } + std::string query = StringFormat("SELECT rule_name, rule_value FROM rule_values WHERE ruleset_id=%d", ruleset_id); auto results = database->QueryDatabase(query); if (!results.Success())