// blob: 2bebd470b9fd1897fdaaf7b43e1aa92971c0ea8a [file] [log] [blame] [edit]
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright OpenBMC Authors
#pragma once
#include "async_resp.hpp"
#include "dbus_privileges.hpp"
#include "http_request.hpp"
#include "http_response.hpp"
#include "logging.hpp"
#include "routing/baserule.hpp"
#include "routing/dynamicrule.hpp"
#include "routing/taggedrule.hpp"
#include "routing/trie.hpp"
#include "verb.hpp"
#include <boost/beast/http/field.hpp>
#include <boost/beast/http/status.hpp>
#include <algorithm>
#include <array>
#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <format>
#include <functional>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <tuple>
#include <utility>
#include <vector>
namespace crow
{
class Router
{
public:
// Construct a router with no registered rules; every member is
// default-constructed.
Router() = default;
// Create a DynamicRule for the given route string and store it in allRules,
// which takes ownership of it.
//
// Returns a reference to the newly created rule so the caller can continue
// configuring it.
DynamicRule& newRuleDynamic(const std::string& rule)
{
    auto owned = std::make_unique<DynamicRule>(rule);
    // Bind the reference before ownership moves into the container; the
    // heap object itself does not move.
    DynamicRule& created = *owned;
    allRules.emplace_back(std::move(owned));
    return created;
}
template <uint64_t NumArgs>
auto& newRuleTagged(const std::string& rule)
{
if constexpr (NumArgs == 0)
{
using RuleT = TaggedRule<>;
std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
RuleT* ptr = ruleObject.get();
allRules.emplace_back(std::move(ruleObject));
return *ptr;
}
else if constexpr (NumArgs == 1)
{
using RuleT = TaggedRule<std::string>;
std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
RuleT* ptr = ruleObject.get();
allRules.emplace_back(std::move(ruleObject));
return *ptr;
}
else if constexpr (NumArgs == 2)
{
using RuleT = TaggedRule<std::string, std::string>;
std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
RuleT* ptr = ruleObject.get();
allRules.emplace_back(std::move(ruleObject));
return *ptr;
}
else if constexpr (NumArgs == 3)
{
using RuleT = TaggedRule<std::string, std::string, std::string>;
std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
RuleT* ptr = ruleObject.get();
allRules.emplace_back(std::move(ruleObject));
return *ptr;
}
else if constexpr (NumArgs == 4)
{
using RuleT =
TaggedRule<std::string, std::string, std::string, std::string>;
std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
RuleT* ptr = ruleObject.get();
allRules.emplace_back(std::move(ruleObject));
return *ptr;
}
else
{
using RuleT = TaggedRule<std::string, std::string, std::string,
std::string, std::string>;
std::unique_ptr<RuleT> ruleObject = std::make_unique<RuleT>(rule);
RuleT* ptr = ruleObject.get();
allRules.emplace_back(std::move(ruleObject));
return *ptr;
}
static_assert(NumArgs <= 5, "Max number of args supported is 5");
}
struct PerMethod
{
std::vector<BaseRule*> rules;
Trie<crow::Node> trie;
// rule index 0 has special meaning; preallocate it to avoid
// duplication.
PerMethod() : rules(1) {}
void internalAdd(std::string_view rule, BaseRule* ruleObject)
{
rules.emplace_back(ruleObject);
trie.add(rule, static_cast<unsigned>(rules.size() - 1U));
// directory case:
// request to `/about' url matches `/about/' rule
if (rule.size() > 2 && rule.back() == '/')
{
trie.add(rule.substr(0, rule.size() - 1),
static_cast<unsigned>(rules.size() - 1));
}
}
};
// Register ruleObject under the pattern `rule` in every routing table it
// applies to:
//  - one per-verb table for each bit set in ruleObject->methodsBitfield,
//  - notFoundRoutes when isNotFound is set,
//  - methodNotAllowedRoutes when isMethodNotAllowed is set,
//  - upgradeRoutes when isUpgrade is set.
// A null ruleObject is ignored.
void internalAddRuleObject(const std::string& rule, BaseRule* ruleObject)
{
    if (ruleObject == nullptr)
    {
        return;
    }
    for (size_t method = 0; method <= maxVerbIndex; method++)
    {
        // Shift a size_t, not an int literal: the mask is compared against
        // a wider bitfield, and shifting an int by its width or more is
        // undefined behaviour.
        const size_t methodBit = size_t{1} << method;
        if ((ruleObject->methodsBitfield & methodBit) > 0U)
        {
            perMethods[method].internalAdd(rule, ruleObject);
        }
    }
    if (ruleObject->isNotFound)
    {
        notFoundRoutes.internalAdd(rule, ruleObject);
    }
    if (ruleObject->isMethodNotAllowed)
    {
        methodNotAllowedRoutes.internalAdd(rule, ruleObject);
    }
    if (ruleObject->isUpgrade)
    {
        upgradeRoutes.internalAdd(rule, ruleObject);
    }
}
void validate()
{
for (std::unique_ptr<BaseRule>& rule : allRules)
{
if (rule)
{
std::unique_ptr<BaseRule> upgraded = rule->upgrade();
if (upgraded)
{
rule = std::move(upgraded);
}
rule->validate();
internalAddRuleObject(rule->rule, rule.get());
}
}
for (PerMethod& perMethod : perMethods)
{
perMethod.trie.validate();
}
}
// Result of looking a URL up in a single PerMethod table.
struct FindRoute
{
// Matched rule, or nullptr when the table has no rule for the URL.
BaseRule* rule = nullptr;
// Parameters reported by the trie lookup for the matched rule; left empty
// when nothing matched.
std::vector<std::string> params;
};
// Result of Router::findRoute() for one request.
struct FindRouteResponse
{
// Comma-separated (", ") list of the verbs that have a route for the
// request URL; empty when no verb has one.
std::string allowHeader;
// Route matched for the request's own verb; route.rule stays nullptr when
// that verb has no route for the URL.
FindRoute route;
};
// Look `url` up in the trie of the given routing table.
//
// Returns a FindRoute whose `rule` is nullptr when the trie reports index 0
// (no match); otherwise the matched rule together with the parameters
// reported by the trie.
// Throws std::runtime_error when the trie reports an index outside
// perMethod.rules.
static FindRoute findRouteByPerMethod(std::string_view url,
                                      const PerMethod& perMethod)
{
    Trie<crow::Node>::FindResult match = perMethod.trie.find(url);
    if (!(match.ruleIndex < perMethod.rules.size()))
    {
        throw std::runtime_error("Trie internal structure corrupted!");
    }

    // Index 0 is the reserved "nothing matched" slot.
    if (match.ruleIndex == 0U)
    {
        return FindRoute{};
    }

    FindRoute result;
    result.rule = perMethod.rules[match.ruleIndex];
    result.params = std::move(match.params);
    return result;
}
// Resolve the route for a request.
//
// The returned allowHeader lists, separated by ", ", every verb that has a
// route for the request URL. The returned route is the match for the
// request's own verb; its rule stays nullptr when the request method does
// not map to a known verb or that verb has no route for the URL.
//
// May throw std::runtime_error (propagated from findRouteByPerMethod) when a
// trie reports an out-of-range rule index.
FindRouteResponse findRoute(const Request& req) const
{
    FindRouteResponse response;
    // Check to see if this url exists at any verb
    for (size_t perMethodIndex = 0; perMethodIndex <= maxVerbIndex;
         perMethodIndex++)
    {
        // Make sure it's safe to dereference the array at that index
        static_assert(
            maxVerbIndex < std::tuple_size_v<decltype(perMethods)>);
        FindRoute route = findRouteByPerMethod(req.url().encoded_path(),
                                               perMethods[perMethodIndex]);
        if (route.rule == nullptr)
        {
            continue;
        }
        if (!response.allowHeader.empty())
        {
            response.allowHeader += ", ";
        }
        HttpVerb thisVerb = static_cast<HttpVerb>(perMethodIndex);
        response.allowHeader += httpVerbToString(thisVerb);
    }

    std::optional<HttpVerb> verb = httpVerbFromBoost(req.method());
    if (!verb)
    {
        return response;
    }
    size_t reqMethodIndex = static_cast<size_t>(*verb);
    if (reqMethodIndex >= perMethods.size())
    {
        return response;
    }

    FindRoute route = findRouteByPerMethod(req.url().encoded_path(),
                                           perMethods[reqMethodIndex]);
    if (route.rule != nullptr)
    {
        // `route` is not used again; move it so the params vector is not
        // deep-copied.
        response.route = std::move(route);
    }
    return response;
}
// Dispatch a connection-upgrade request.
//
// Looks the request path up in upgradeRoutes. When nothing matches, the
// response status is set to 404 and the function returns. Otherwise the
// matched rule's handleUpgrade() is invoked through validatePrivilege(),
// with the adaptor moved into the callback.
//
// Throws std::runtime_error when the trie reports an index outside the
// table's rule list.
template <typename Adaptor>
void handleUpgrade(const std::shared_ptr<Request>& req,
                   const std::shared_ptr<bmcweb::AsyncResp>& asyncResp,
                   Adaptor&& adaptor)
{
    Trie<crow::Node>::FindResult match =
        upgradeRoutes.trie.find(req->url().encoded_path());
    const unsigned matchedIndex = match.ruleIndex;

    // Index 0 is the reserved "nothing matched" slot.
    if (matchedIndex == 0U)
    {
        BMCWEB_LOG_DEBUG("Cannot match rules {}",
                         req->url().encoded_path());
        asyncResp->res.result(boost::beast::http::status::not_found);
        return;
    }

    const std::vector<BaseRule*>& candidates = upgradeRoutes.rules;
    if (!(matchedIndex < candidates.size()))
    {
        throw std::runtime_error("Trie internal structure corrupted!");
    }

    BaseRule& matched = *candidates[matchedIndex];
    BMCWEB_LOG_DEBUG("Matched rule (upgrade) '{}'", matched.rule);

    // TODO(ed) This should be able to use std::bind_front, but it doesn't
    // appear to work with the std::move on adaptor.
    validatePrivilege(
        req, asyncResp, matched,
        [req, &matched, asyncResp,
         adaptor = std::forward<Adaptor>(adaptor)]() mutable {
            matched.handleUpgrade(*req, asyncResp, std::move(adaptor));
        });
}
void handle(const std::shared_ptr<Request>& req,
const std::shared_ptr<bmcweb::AsyncResp>& asyncResp)
{
FindRouteResponse foundRoute = findRoute(*req);
if (foundRoute.route.rule == nullptr)
{
// Couldn't find a normal route with any verb, try looking for a 404
// route
if (foundRoute.allowHeader.empty())
{
foundRoute.route = findRouteByPerMethod(
req->url().encoded_path(), notFoundRoutes);
}
else
{
// See if we have a method not allowed (405) handler
foundRoute.route = findRouteByPerMethod(
req->url().encoded_path(), methodNotAllowedRoutes);
}
}
// Fill in the allow header if it's valid
if (!foundRoute.allowHeader.empty())
{
asyncResp->res.addHeader(boost::beast::http::field::allow,
foundRoute.allowHeader);
}
// If we couldn't find a real route or a 404 route, return a generic
// response
if (foundRoute.route.rule == nullptr)
{
if (foundRoute.allowHeader.empty())
{
asyncResp->res.result(boost::beast::http::status::not_found);
}
else
{
asyncResp->res.result(
boost::beast::http::status::method_not_allowed);
}
return;
}
BaseRule& rule = *foundRoute.route.rule;
std::vector<std::string> params = std::move(foundRoute.route.params);
BMCWEB_LOG_DEBUG("Matched rule '{}' {} / {}", rule.rule,
req->methodString(), rule.getMethods());
if (req->session == nullptr)
{
rule.handle(*req, asyncResp, params);
return;
}
validatePrivilege(
req, asyncResp, rule,
[req, asyncResp, &rule, params = std::move(params)]() {
rule.handle(*req, asyncResp, params);
});
}
void debugPrint()
{
for (size_t i = 0; i < perMethods.size(); i++)
{
BMCWEB_LOG_DEBUG("{}", httpVerbToString(static_cast<HttpVerb>(i)));
perMethods[i].trie.debugPrint();
}
}
// Collect pointers to the route strings of the rules that the per-verb
// tries report for `parent` via findRouteIndexes().
//
// The returned pointers refer to strings stored inside the rule objects; a
// rule registered under several verbs appears once per verb.
std::vector<const std::string*> getRoutes(const std::string& parent)
{
    std::vector<const std::string*> routes;
    for (const PerMethod& table : perMethods)
    {
        std::vector<unsigned> matches;
        table.trie.findRouteIndexes(parent, matches);
        for (const unsigned ruleIndex : matches)
        {
            const BaseRule* owner = table.rules[ruleIndex];
            routes.push_back(&owner->rule);
        }
    }
    return routes;
}
private:
// One routing table per HTTP verb, indexed by the numeric value of HttpVerb.
std::array<PerMethod, static_cast<size_t>(HttpVerb::Max)> perMethods;
// Rules flagged isNotFound; consulted by handle() when no verb has a route
// for the requested URL.
PerMethod notFoundRoutes;
// Rules flagged isUpgrade; consulted by handleUpgrade().
PerMethod upgradeRoutes;
// Rules flagged isMethodNotAllowed; consulted by handle() when the URL only
// has routes under verbs other than the requested one.
PerMethod methodNotAllowedRoutes;
// Owns every rule created through newRuleDynamic()/newRuleTagged(); the
// routing tables above hold non-owning pointers to these objects.
std::vector<std::unique_ptr<BaseRule>> allRules;
};
} // namespace crow