Files
hate/Router.cpp
2025-11-22 16:25:58 +01:00

507 lines
10 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "Router.h"

#include <algorithm>
#include <cctype>
#include <fstream>
#include <iostream> // Diagnostic logging only
#include <istream>
#include <map>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <streambuf>
#include <string>
#include <vector>

#include "Socket.h"
#include "StringUtil.h"
// Reason phrases written after the status code in the response status line,
// keyed by HTTP status code. They go on the wire verbatim, so they should be
// the plain-ASCII phrases registered with IANA (RFC 9110, section 15).
static std::map<int, const char*> STATUS_STR =
{
    { 100, "Continue" },
    { 101, "Switching Protocols" },
    { 102, "Processing" },
    { 103, "Early Hints" },
    { 104, "Upload Resumption Supported" },
    { 200, "OK" },
    { 201, "Created" },
    { 202, "Accepted" },
    { 203, "Non-Authoritative Information" },
    { 204, "No Content" },
    { 205, "Reset Content" },
    { 206, "Partial Content" },
    { 207, "Multi-Status" },
    { 208, "Already Reported" },
    { 226, "IM Used" },
    { 300, "Multiple Choices" },
    { 301, "Moved Permanently" },
    { 302, "Found" },
    { 303, "See Other" },
    { 304, "Not Modified" },
    { 305, "Use Proxy" },
    { 306, "Switch Proxy" },
    { 307, "Temporary Redirect" },
    { 308, "Permanent Redirect" },
    { 400, "Bad Request" },
    { 401, "Unauthorized" },
    { 402, "Payment Required" },
    { 403, "Forbidden" },
    { 404, "Not Found" },
    { 405, "Method Not Allowed" },
    { 406, "Not Acceptable" },
    { 407, "Proxy Authentication Required" },
    { 408, "Request Timeout" },
    { 409, "Conflict" },
    { 410, "Gone" },
    { 411, "Length Required" },
    { 412, "Precondition Failed" },
    { 413, "Content Too Large" },
    { 414, "URI Too Long" },
    { 415, "Unsupported Media Type" },
    { 416, "Range Not Satisfiable" },
    { 417, "Expectation Failed" },
    { 418, "I'm a teapot" },
    { 421, "Misdirected Request" },
    { 422, "Unprocessable Content" },
    { 423, "Locked" },
    { 424, "Failed Dependency" },
    { 425, "Too Early" },
    { 426, "Upgrade Required" },
    { 427, "Unassigned" },
    { 428, "Precondition Required" },
    { 429, "Too Many Requests" },
    { 430, "Unassigned" },
    { 431, "Request Header Fields Too Large" },
    { 451, "Unavailable For Legal Reasons" },
    { 500, "Internal Server Error" },
    { 501, "Not Implemented" },
    { 502, "Bad Gateway" },
    { 503, "Service Unavailable" },
    { 504, "Gateway Timeout" },
    { 505, "HTTP Version Not Supported" },
    { 506, "Variant Also Negotiates" },
    { 507, "Insufficient Storage" },
    { 508, "Loop Detected" },
    { 509, "Unassigned" },
    { 510, "Not Extended" },
    { 511, "Network Authentication Required" },
};
namespace
{
    /// Raised while parsing the request head when the client's input is
    /// malformed or the connection ends early. It is thrown while the Request
    /// is being constructed, i.e. before Router::handle dispatches anything.
    class ClientError : public std::runtime_error
    {
    public:
        explicit ClientError(const char* what) : std::runtime_error(what)
        {
        }
    };

    /// Orders strings ignoring case (via std::tolower). HTTP field names are
    /// case-insensitive, so "content-length" and "Content-Length" are one key.
    struct CaseInsensitiveLess
    {
        bool operator()(const std::string& a, const std::string& b) const
        {
            return std::lexicographical_compare(
                a.begin(), a.end(), b.begin(), b.end(),
                [](char x, char y)
                {
                    return std::tolower(static_cast<unsigned char>(x)) <
                           std::tolower(static_cast<unsigned char>(y));
                });
        }
    };

    /// Input stream buffer behind Request. The constructor consumes the
    /// request line and the header fields from the client stream; afterwards
    /// the buffer yields the remaining bytes (the body) one at a time.
    class ReqBuf : public std::streambuf
    {
    public:
        /// Parses the request head from `stream`.
        /// @throws ClientError on too many header lines, an over-long line,
        ///         a malformed request line or header field, or end of input
        ///         before the blank line that terminates the head.
        explicit ReqBuf(std::istream& stream) : _stream(stream)
        {
            constexpr int header_max = 64; // counts the request line too
            for (int n = 0;; n++)
            {
                if (n == header_max)
                    throw ClientError("Too many headers");
                const std::string line = read_line();
                if (line.empty())
                    break; // the empty line ends the head
                if (n == 0)
                    parse_request_line(line);
                else
                    parse_header_field(line);
            }
        }

        std::string method, url, http_version;
        std::map<std::string, std::string, CaseInsensitiveLess> headers;

    protected:
        /// Fetches the next body byte. The byte is kept in a one-character get
        /// area so that peeking (sgetc) does not consume it and the default
        /// uflow()/sbumpc() have a valid gptr() to read from.
        int_type underflow() override
        {
            const int_type c = _stream.get();
            if (traits_type::eq_int_type(c, traits_type::eof()))
                return traits_type::eof();
            _current = traits_type::to_char_type(c);
            setg(&_current, &_current, &_current + 1);
            return traits_type::to_int_type(_current);
        }

    private:
        /// Reads one CRLF-terminated line and returns it without the CRLF.
        std::string read_line()
        {
            constexpr std::string::size_type line_max = 512;
            std::string line;
            for (;;)
            {
                const int byte = _stream.get();
                if (byte == traits_type::eof())
                    throw ClientError("Connection closed prematurely");
                line += static_cast<char>(byte);
                const std::string::size_type size = line.size();
                if (size >= 2 && line[size - 2] == '\r' && line[size - 1] == '\n')
                {
                    line.resize(size - 2);
                    return line;
                }
                if (size > line_max)
                    throw ClientError("Header field too long");
            }
        }

        /// Splits "METHOD target HTTP/x.y" on spaces; the method is upper-cased.
        void parse_request_line(const std::string& line)
        {
            const std::vector<std::string> parts = StringUtil::splitClean(line, ' ');
            if (parts.size() < 3)
                throw ClientError("Malformed request line");
            method = parts[0];
            url = parts[1];
            http_version = parts[2];
            // unsigned char: passing a negative char to toupper is undefined.
            std::transform(method.begin(), method.end(), method.begin(),
                [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
        }

        /// Splits "Name: value" at the first colon and strips both sides.
        /// A repeated field name keeps the last value.
        void parse_header_field(const std::string& line)
        {
            const std::string::size_type sep = line.find(':');
            if (sep == std::string::npos)
                throw ClientError("Malformed header field");
            std::string name = StringUtil::strip(line.substr(0, sep));
            std::string value = StringUtil::strip(line.substr(sep + 1));
            headers[name] = value;
        }

        std::istream& _stream;
        char _current = 0; // one-byte get area used by underflow()
    };
}
#define buf ((ReqBuf*)rdbuf())
Request::Request(std::istream& stream) : std::istream(new ReqBuf(stream))
{
}
Request::~Request()
{
delete rdbuf();
}
std::string Request::url() const
{
return buf->url;
}
std::string Request::method() const
{
return buf->method;
}
const std::string* Request::header(std::string name) const
{
auto field = buf->headers.find(name);
if (field == buf->headers.end())
return nullptr;
return &field->second;
}
#undef buf
namespace
{
    /// Output stream buffer behind Response. It keeps no buffer of its own:
    /// every write goes straight to the client stream. The status line and the
    /// header fields are emitted lazily: just before the first body byte, or
    /// on sync() if no body was written. Until then they can still be changed.
    class ResBuf : public std::streambuf
    {
    public:
        explicit ResBuf(std::ostream& stream) : _stream(stream)
        {
        }
        /// Writes the status line, every header field and the blank line that
        /// ends the head.
        /// @throws std::logic_error if the head was already sent.
        void send_head()
        {
            if (head_sent)
                throw std::logic_error("Head has already been sent");
            head_sent = true;
            // find(), not operator[]: for a code missing from the table,
            // operator[] would insert a null pointer and then stream it (UB).
            // The reason phrase is optional in HTTP/1.1, so send an empty one.
            const auto reason = STATUS_STR.find(status);
            const char* phrase = reason != STATUS_STR.end() ? reason->second : "";
            _stream << "HTTP/1.1 " << status << ' ' << phrase << "\r\n";
            for (const auto& field : headers)
                _stream << field.first << ": " << field.second << "\r\n";
            _stream << "\r\n";
        }
        int status = 200;
        bool head_sent = false;
        std::map<std::string, std::string> headers = {
            {"Connection", "Close"},
        };
    protected:
        /// Single-character write path.
        int_type overflow(int_type c) override
        {
            // overflow(eof) carries no character to write; just report success.
            if (traits_type::eq_int_type(c, traits_type::eof()))
                return traits_type::not_eof(c);
            if (!head_sent)
                send_head();
            _stream.put(traits_type::to_char_type(c));
            return _stream ? c : traits_type::eof();
        }
        /// Bulk write path (e.g. write()): forwards the whole block instead of
        /// one overflow() call per byte. A zero-length write sends nothing.
        std::streamsize xsputn(const char_type* s, std::streamsize n) override
        {
            if (n <= 0)
                return 0;
            if (!head_sent)
                send_head();
            _stream.write(s, n);
            return _stream ? n : 0;
        }
        /// Flushing commits the response, so a response with an empty body
        /// still gets its head sent.
        int sync() override
        {
            if (!head_sent)
                send_head();
            _stream.flush();
            return _stream ? 0 : -1;
        }
    private:
        std::ostream& _stream;
    };
}
#define buf ((ResBuf*)rdbuf())
Response::Response(std::ostream& stream) : std::ostream(new ResBuf(stream))
{
}
Response::~Response()
{
delete rdbuf();
}
void Response::status(int status)
{
buf->status = status;
}
void Response::header(std::string name, std::string value)
{
if (buf->head_sent)
throw std::exception("Headers already sent.");
buf->headers[name] = value;
}
void Response::send(std::string text)
{
if (buf->headers.find("Content-Type") == buf->headers.end())
header("Content-Type", "text/plain");
header("Content-Length", text.size());
*this << text;
flush();
}
void Response::send(const void* data, size_t size)
{
if (buf->headers.find("Content-Type") == buf->headers.end())
header("Content-Type", "application/octet-stream");
header("Content-Length", size);
write((const char*)data, size);
flush();
}
void Response::sendFile(std::string filename)
{
std::ifstream fs(filename, std::ifstream::binary | std::ifstream::ate);
if (!fs.is_open())
throw std::exception(("Could not open \""+filename+"\"").c_str());
header("Content-Length", (size_t)fs.tellg());
fs.seekg(0, std::ifstream::beg);
*this << fs.rdbuf();
flush();
}
#undef buf
static void default_404(Request& req, Response& res)
{
std::stringstream ss;
ss << "<html>"
<< "<head>"
<< "<title>Not Found</title>"
<< "</head>"
<< "<body>"
<< "<h1>Not Found</h1>"
<< "<p>"
<< '"' << req.url() << '"'
<< " could not be found."
<< "</p>"
<< "</body>"
<< "</html>";
res.status(404);
res.header("Content-Type", "text/html");
res.send(ss.str());
}
static void default_500(Request& req, Response& res, std::string msg)
{
std::stringstream ss;
ss << "<html>"
<< "<head>"
<< "<title>Error</title>"
<< "</head>"
<< "<body>"
<< "<h1>Error</h1>"
<< "<p>"
<< msg
<< "</p>"
<< "</body>"
<< "</html>";
res.status(500);
res.header("Content-Type", "text/html");
res.send(ss.str());
}
/// Creates a router with no routes registered. Until handlers are added with
/// on()/on_*(), Router::handle answers every request with the 404 page.
Router::Router() = default;
namespace
{
    /// Thrown by the `next` callback that Router::handle gives to handlers.
    /// It deliberately does not derive from std::exception, so a handler's
    /// ordinary `catch (const std::exception&)` clauses do not intercept it.
    struct NextHandlerSignal
    {
    };
}
/// Parses one request from `req_s` and answers it on `res_s`.
///
/// The handlers registered for the request's exact URL run in registration
/// order. A handler passes control on by calling `next()`. That call does not
/// return: it unwinds the current handler normally (running its destructors),
/// and dispatch continues with the following handler. A handler with a
/// `catch (...)` around next() must rethrow, or it swallows the hand-over.
/// When no handler remains, the 404 page is sent. An exception escaping a
/// handler produces the 500 page. A ClientError from parsing the request
/// propagates to the caller, because the Request is built before dispatch.
void Router::handle(std::istream& req_s, std::ostream& res_s)
{
    Request req(req_s);
    Response res(res_s);
    auto route = _handlers.find(req.url());
    if (route == _handlers.end())
    {
        default_404(req, res);
        return;
    }
    auto& handlers = route->second;
    // Index loop: stays valid even if a handler registers more routes.
    for (std::size_t i = 0; i < handlers.size(); ++i)
    {
        try
        {
            handlers[i](req, res, []() { throw NextHandlerSignal(); });
            return; // the handler finished without calling next()
        }
        catch (const NextHandlerSignal&)
        {
            // next() was called: continue with the following handler.
        }
        catch (const std::exception& ex)
        {
            default_500(req, res, std::string(ex.what()));
            return;
        }
        catch (const std::string& str)
        {
            default_500(req, res, str);
            return;
        }
        catch (const char* str)
        {
            default_500(req, res, std::string(str));
            return;
        }
        catch (...)
        {
            default_500(req, res, "Unhandled exception has occurred");
            return;
        }
    }
    default_404(req, res);
}
void Router::on(std::string url, Handler handler)
{
// if vector already existed it is returned.
// If not, a new one is created and returned.
auto pair = _handlers.emplace(std::piecewise_construct, std::make_tuple(url), std::make_tuple());
// First is the key-value pair iterator,
// second is true if new vector was created.
auto& iter = pair.first;
iter
// First is the key, second is the vector.
->second
.push_back(handler);
}
/// Registers `handler` for `url`, restricted to GET requests; a request with
/// any other method is passed straight on via next().
void Router::on_get(std::string url, Handler handler)
{
    on(url, [=](Request& req, Response& res, Next next)
    {
        if (req.method() != "GET")
        {
            next();
            return;
        }
        handler(req, res, next);
    });
}
/// Registers `handler` for `url`, restricted to HEAD requests; a request with
/// any other method is passed straight on via next().
void Router::on_head(std::string url, Handler handler)
{
    on(url, [=](Request& req, Response& res, Next next)
    {
        if (req.method() != "HEAD")
        {
            next();
            return;
        }
        handler(req, res, next);
    });
}
/// Registers `handler` for `url`, restricted to OPTIONS requests; a request
/// with any other method is passed straight on via next().
void Router::on_options(std::string url, Handler handler)
{
    on(url, [=](Request& req, Response& res, Next next)
    {
        if (req.method() != "OPTIONS")
        {
            next();
            return;
        }
        handler(req, res, next);
    });
}
/// Registers `handler` for `url`, restricted to TRACE requests; a request with
/// any other method is passed straight on via next().
void Router::on_trace(std::string url, Handler handler)
{
    on(url, [=](Request& req, Response& res, Next next)
    {
        if (req.method() != "TRACE")
        {
            next();
            return;
        }
        handler(req, res, next);
    });
}
/// Registers `handler` for `url`, restricted to PUT requests; a request with
/// any other method is passed straight on via next().
void Router::on_put(std::string url, Handler handler)
{
    on(url, [=](Request& req, Response& res, Next next)
    {
        if (req.method() != "PUT")
        {
            next();
            return;
        }
        handler(req, res, next);
    });
}
/// Registers `handler` for `url`, restricted to DELETE requests; a request
/// with any other method is passed straight on via next().
void Router::on_delete(std::string url, Handler handler)
{
    on(url, [=](Request& req, Response& res, Next next)
    {
        if (req.method() != "DELETE")
        {
            next();
            return;
        }
        handler(req, res, next);
    });
}
/// Registers `handler` for `url`, restricted to POST requests; a request with
/// any other method is passed straight on via next().
void Router::on_post(std::string url, Handler handler)
{
    on(url, [=](Request& req, Response& res, Next next)
    {
        if (req.method() != "POST")
        {
            next();
            return;
        }
        handler(req, res, next);
    });
}
/// Registers `handler` for `url`, restricted to PATCH requests; a request with
/// any other method is passed straight on via next().
void Router::on_patch(std::string url, Handler handler)
{
    on(url, [=](Request& req, Response& res, Next next)
    {
        if (req.method() != "PATCH")
        {
            next();
            return;
        }
        handler(req, res, next);
    });
}
/// Registers `handler` for `url`, restricted to CONNECT requests; a request
/// with any other method is passed straight on via next().
void Router::on_connect(std::string url, Handler handler)
{
    on(url, [=](Request& req, Response& res, Next next)
    {
        if (req.method() != "CONNECT")
        {
            next();
            return;
        }
        handler(req, res, next);
    });
}
void Router::listen(std::string address, short port)
{
Socket server;
server.bind(address.c_str(), port);
server.listen();
while (true)
{
Socket client = server.accept();
SocketStream stream(client);
handle(stream, stream);
}
}
void Router::listen(short port)
{
Socket server;
server.bind(port);
server.listen();
while (true) try
{
Socket client = server.accept();
SocketStream stream(client);
handle(stream, stream);
}
catch (ClientError& err)
{
// Nothing major, ignore
std::cout << "Client Error: " << err.what() << std::endl;
}
}