近期我一直在学习 Zig 编程语言。为了熟练掌握 Zig 0.16 的标准库 API，并体验 Zig 的网络编程，我决定动手实现一个简易的 HTTP/1.1 服务器。
本次实现不追求高性能，也不做复杂封装，只实现最核心的 HTTP 请求解析；同时加入缓冲区大小、请求头数量、请求体大小等基础限制，在保持代码简单易懂的前提下，防范常见的恶意请求（例如超大的请求头或请求体）。
整体思路
本次服务器的实现逻辑通俗易懂，核心设计如下：
- 并发处理：每接受一个客户端连接，就单独开启一个线程处理，各连接之间互不阻塞。
- 内存管理：每个请求单独创建一个 `ArenaAllocator`，该请求的所有内存都从中分配，请求结束时一次性释放，从而简化手动内存管理、避免内存泄漏。
代码实现
zig
const std = @import("std");
/// Size in bytes of the buffer that the request line and all headers
/// (through the terminating "\r\n\r\n") must fit in.
const REQUEST_HEADERS_BUFFER_SIZE = 4096;
/// Maximum number of distinct header names accepted in one request.
const REQUEST_HEADERS_LIMIT = 100;
/// Maximum accepted request body size (Content-Length) in bytes: 64 KiB.
const REQUEST_BODY_LIMIT = 1024 * 64;
/// Entry point: listen on 127.0.0.1:3000 and serve every accepted connection
/// on its own thread.
///
/// Returns the first error produced by `accept`. A failure to spawn a worker
/// thread only closes that one connection; the server keeps accepting.
pub fn main(init: std.process.Init) !void {
    const io = init.io;

    const addr = try std.Io.net.IpAddress.parse("127.0.0.1", 3000);
    var server = try addr.listen(io, .{});
    defer server.deinit(io);
    std.log.info("listen 127.0.0.1:3000", .{});

    while (server.accept(io)) |stream| {
        std.log.info("accept {}", .{stream.socket.address});
        // NOTE(review): every connection thread shares `init.gpa`; confirm the
        // allocator chosen by `std.process.Init` is thread-safe.
        const thread = std.Thread.spawn(.{}, handleConnect, .{ io, init.gpa, stream }) catch |err| {
            stream.close(io);
            std.log.err("Thread.spawn error: {}", .{err});
            continue;
        };
        // The handle is never joined, so detach it. Otherwise each finished
        // connection thread keeps its OS resources (stack, thread handle) forever.
        thread.detach();
    } else |err| {
        return err;
    }
}
/// Serve one client connection: read a single request, log it, echo its body
/// back in the response, and close the stream.
///
/// The stream is closed on every exit path via `defer`, and the request's
/// memory is released once the response has been written. Parse and write
/// errors propagate to the caller.
fn handleConnect(io: std.Io, alloc: std.mem.Allocator, stream: std.Io.net.Stream) !void {
    defer stream.close(io);

    const req = try parseHttpRequest(io, alloc, &stream);
    defer req.deinit();

    printRequest(&req);
    // The response body simply echoes the request body.
    return response(io, &stream, req.body);
}
/// Log a request for debugging: the request line, one line per header, an
/// empty separator line, and the body, framed by two marker lines.
fn printRequest(request: *const Request) void {
    const log = std.log;

    log.info("request: vvvvvvvvvvvvvvvvvvvvvvvvvv", .{});
    log.info(" {s} {s} {s}", .{
        request.method.asStr(),
        request.path,
        request.version.asStr(),
    });

    var it = request.headers.iterator();
    while (it.next()) |entry| {
        const name = entry.key_ptr.*;
        const value = entry.value_ptr.*;
        log.info(" {s}: {s}", .{ name, value });
    }

    // Blank line between headers and body, mirroring the wire format.
    log.info("", .{});
    log.info(" {s}", .{request.body});
    log.info(" ^^^^^^^^^^^^^^^^^^^^^^^^^^", .{});
}
/// Write a minimal `200 OK` response whose body is `message`, then flush it.
///
/// The response carries `Connection: close` and an explicit content length.
/// Closing the stream is left to the caller.
fn response(io: std.Io, stream: *const std.Io.net.Stream, message: []const u8) !void {
    var out_buf: [4096]u8 = undefined;
    var stream_writer = stream.writer(io, &out_buf);
    const out = &stream_writer.interface;

    try out.print(
        "HTTP/1.1 200 OK\r\nConnection: close\r\ncontent-length: {}\r\n\r\n{s}",
        .{ message.len, message },
    );
    // Buffered bytes are not on the wire until flushed.
    try out.flush();
}
/// Read and parse one HTTP/1.1 request from `stream`.
///
/// All request memory (path, header names and values, header map, body) is
/// allocated from a fresh heap-allocated `ArenaAllocator` stored in the
/// returned `Request`. The caller owns the result and must call `deinit` on it.
/// On error the arena is released here.
///
/// Limits enforced on this untrusted input:
/// - request line + headers must fit in `REQUEST_HEADERS_BUFFER_SIZE` bytes,
///   otherwise `error.HeadersTooLarge`;
/// - at most `REQUEST_HEADERS_LIMIT` distinct header names, otherwise
///   `error.TooManyHeaders`;
/// - a body of at most `REQUEST_BODY_LIMIT` bytes, otherwise `error.DataTooLong`.
///
/// Malformed input yields `error.ParseError`. Read failures, including the
/// peer closing the connection early, are propagated from the reader.
fn parseHttpRequest(io: std.Io, alloc: std.mem.Allocator, stream: *const std.Io.net.Stream) !Request {
    const arena = try alloc.create(std.heap.ArenaAllocator);
    errdefer alloc.destroy(arena);
    arena.* = std.heap.ArenaAllocator.init(alloc);
    errdefer arena.deinit();
    const arena_alloc = arena.allocator();

    var reader_buf: [4096]u8 = undefined;
    var reader = stream.reader(io, &reader_buf);

    // Holds the request line and headers. It may also hold the first bytes of
    // the body, which arrive right after the header terminator.
    var buf: [REQUEST_HEADERS_BUFFER_SIZE]u8 = undefined;
    const head = try readHeaderBlock(&reader.interface, &buf);

    var lines = std.mem.splitSequence(u8, buf[0..head.len], "\r\n");

    // Request line: exactly `method SP request-target SP HTTP-version`.
    // splitSequence always yields at least one item; `orelse` is only defensive.
    const request_line = lines.next() orelse return error.ParseError;
    var parts = std.mem.splitScalar(u8, request_line, ' ');
    const method_text = parts.next() orelse return error.ParseError;
    const path_text = parts.next() orelse return error.ParseError;
    const version_text = parts.next() orelse return error.ParseError;
    // Reject extra tokens ("GET / HTTP/1.1 x") and an empty target ("GET  HTTP/1.1").
    if (parts.next() != null or path_text.len == 0) return error.ParseError;

    var request: Request = .{
        .method = try HttpMethod.fromStr(arena_alloc, method_text),
        .path = try arena_alloc.dupe(u8, path_text),
        .version = try HttpVersion.fromStr(version_text),
        .headers = .init(arena_alloc),
        .body = "",
        .arena = arena,
    };

    // The head ends at the first blank line, so no empty lines remain here.
    while (lines.next()) |line| {
        if (request.headers.count() == REQUEST_HEADERS_LIMIT) return error.TooManyHeaders;

        const colon = std.mem.indexOfScalar(u8, line, ':') orelse return error.ParseError;
        const raw_name = line[0..colon];
        // RFC 9112 §5.1: no whitespace in or before the field name. This also
        // rejects obsolete line folding (continuation lines start with whitespace).
        if (raw_name.len == 0 or std.mem.indexOfAny(u8, raw_name, " \t") != null) {
            return error.ParseError;
        }

        const name = try std.ascii.allocLowerString(arena_alloc, raw_name);
        // Optional whitespace around a field value is SP or HTAB.
        const value = try arena_alloc.dupe(u8, std.mem.trim(u8, line[colon + 1 ..], " \t"));

        const entry = try request.headers.getOrPut(name);
        // Repeated Content-Length is a request-smuggling vector; refuse it
        // instead of silently keeping the last value.
        if (entry.found_existing and std.mem.eql(u8, name, "content-length")) {
            return error.ParseError;
        }
        entry.value_ptr.* = value;
    }

    const body_size: usize = if (request.headers.get("content-length")) |text|
        try parseContentLength(text)
    else
        0;
    if (body_size == 0) return request;
    if (body_size > REQUEST_BODY_LIMIT) return error.DataTooLong;

    const body = try arena_alloc.alloc(u8, body_size);
    // Part of the body may already sit in `buf` right after "\r\n\r\n".
    // `head.filled >= head.len + 4` because the terminator was found inside it.
    const body_start = head.len + 4;
    const prefetched = @min(head.filled - body_start, body_size);
    // @memcpy requires equal lengths, so copy into the matching prefix only.
    @memcpy(body[0..prefetched], buf[body_start..][0..prefetched]);
    // The rest is still in the reader (its own buffer or the socket).
    try reader.interface.readSliceAll(body[prefetched..]);
    request.body = body;
    return request;
}

/// Read from `reader` into `buf` until the header terminator "\r\n\r\n" has
/// arrived, which may take several reads because TCP can split the data.
///
/// Returns the header block length (`len`, terminator excluded) and the total
/// number of bytes stored in `buf` (`filled`). Bytes past the terminator are
/// the start of the body. Fails with `error.HeadersTooLarge` if `buf` fills
/// before the terminator is seen.
fn readHeaderBlock(reader: *std.Io.Reader, buf: []u8) !struct { len: usize, filled: usize } {
    var filled: usize = 0;
    var scan_from: usize = 0;
    while (true) {
        if (std.mem.indexOfPos(u8, buf[0..filled], scan_from, "\r\n\r\n")) |len| {
            return .{ .len = len, .filled = filled };
        }
        // A terminator may straddle two reads, so rescan the last 3 bytes next time.
        scan_from = filled -| 3;
        if (filled == buf.len) return error.HeadersTooLarge;
        var vec = [_][]u8{buf[filled..]};
        filled += try reader.readVec(&vec);
    }
}

/// Parse a Content-Length value strictly as `1*DIGIT` (RFC 9110 §8.6).
///
/// `std.fmt.parseInt` alone would also accept a leading sign and `_`
/// separators. Other HTTP implementations reject those, and a disagreement
/// about body length enables request smuggling.
fn parseContentLength(text: []const u8) !usize {
    if (text.len == 0) return error.ParseError;
    for (text) |c| {
        if (!std.ascii.isDigit(c)) return error.ParseError;
    }
    return std.fmt.parseInt(usize, text, 10);
}
/// One parsed HTTP request.
///
/// `path`, `body`, the header map and its keys and values are all allocated
/// from `arena` by `parseHttpRequest`, so a single `deinit` frees the whole
/// request.
const Request = struct {
    version: HttpVersion,
    method: HttpMethod,
    /// Request target copied verbatim from the request line.
    path: []u8,
    /// Headers keyed by lower-cased header name.
    headers: std.StringHashMap([]const u8),
    /// Request body; empty when Content-Length is absent or 0.
    body: []u8,
    /// Heap-allocated arena owning all of the memory above; destroyed by `deinit`.
    arena: *std.heap.ArenaAllocator,

    /// Release every allocation made for this request, then the arena itself.
    pub fn deinit(self: *const Request) void {
        // Capture the child allocator first. Once the arena has been
        // deinitialized its fields must not be relied on.
        const child = self.arena.child_allocator;
        self.arena.deinit();
        child.destroy(self.arena);
    }
};
/// HTTP protocol versions recognised in a request line.
const HttpVersion = enum {
    http1_1,
    http2,
    http3,

    /// Wire spelling of each version; `fromStr` matches it exactly (case-sensitive).
    const TABLE = [_]struct { name: []const u8, version: HttpVersion }{
        .{ .name = "HTTP/1.1", .version = .http1_1 },
        .{ .name = "HTTP/2", .version = .http2 },
        .{ .name = "HTTP/3", .version = .http3 },
    };

    /// Parse a version token. Anything not listed in `TABLE` is `error.ParseError`.
    pub fn fromStr(str: []const u8) !HttpVersion {
        return for (TABLE) |entry| {
            if (std.mem.eql(u8, str, entry.name)) break entry.version;
        } else error.ParseError;
    }

    /// Wire spelling of `self`, e.g. "HTTP/1.1".
    pub fn asStr(self: HttpVersion) []const u8 {
        switch (self) {
            .http1_1 => return "HTTP/1.1",
            .http2 => return "HTTP/2",
            .http3 => return "HTTP/3",
        }
    }
};
/// HTTP request method. Standard methods get their own tag; any other valid
/// method token is kept verbatim in `.other`.
const HttpMethod = union(enum) {
    get,
    post,
    put,
    patch,
    delete,
    connect,
    head,
    options,
    /// Non-standard method name, allocated with the allocator given to `fromStr`.
    other: []const u8,

    /// Parse the method token of a request line.
    ///
    /// Matching is case-sensitive (RFC 9110 §9.1): "GET" is `.get`, while "get"
    /// becomes `.other`. Returns `error.ParseError` for an empty string or one
    /// containing a byte that is not an HTTP token character. An `.other`
    /// payload is duplicated with `alloc`.
    pub fn fromStr(alloc: std.mem.Allocator, str: []const u8) !HttpMethod {
        if (str.len == 0) return error.ParseError;
        // The method comes straight off the network and is later logged, so
        // restrict it to token characters (no control bytes or separators).
        for (str) |c| {
            if (!isTokenChar(c)) return error.ParseError;
        }
        inline for (@typeInfo(HttpMethod).@"union".fields) |field| {
            if (field.type == void) {
                const method = @unionInit(HttpMethod, field.name, {});
                if (std.mem.eql(u8, str, method.asStr())) return method;
            }
        }
        return .{ .other = try alloc.dupe(u8, str) };
    }

    /// Canonical wire spelling of the method; `.other` returns its payload.
    pub fn asStr(self: *const HttpMethod) []const u8 {
        return switch (self.*) {
            .get => "GET",
            .post => "POST",
            .put => "PUT",
            .patch => "PATCH",
            .delete => "DELETE",
            .connect => "CONNECT",
            .head => "HEAD",
            .options => "OPTIONS",
            .other => |method| method,
        };
    }

    /// `tchar` from RFC 9110 §5.6.2.
    fn isTokenChar(c: u8) bool {
        return std.ascii.isAlphanumeric(c) or switch (c) {
            '!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~' => true,
            else => false,
        };
    }
};