Zig 0.16 学习笔记 - 编写一个简易的 HTTP/1.1 服务器

近期我一直在学习 Zig 编程语言。为了熟练掌握 Zig 0.16 的标准库 API,并体验一下 Zig 的网络编程,我决定动手实现一个简陋的 HTTP/1.1 服务器。

本次实现不追求高性能、不做复杂封装,只进行最核心的 HTTP 请求解析,同时加入缓冲区限制、请求头数量限制、请求体大小限制等基础安全策略,在简单易懂的前提下,规避常见的恶意请求攻击。

整体思路

本次服务器实现逻辑通俗易懂,核心设计如下:

  • 并发处理:每建立一个客户端连接,单独开启一个线程处理,互不阻塞。
  • 内存管理:每个请求单独创建 ArenaAllocator,该请求下所有内存统一分配,请求结束一次性释放,简化手动内存管理,避免内存泄漏。

代码实现

zig
const std = @import("std");

/// Size in bytes of the stack buffer in which `parseHttpRequest` collects the
/// request line and headers; the header terminator must appear within it.
const REQUEST_HEADERS_BUFFER_SIZE = 4096;
/// Maximum number of header entries stored per request; `parseHttpRequest`
/// returns `error.TooManyHeaders` once this count is reached and another
/// header line follows.
const REQUEST_HEADERS_LIMIT = 100;
/// Largest accepted `content-length` value in bytes (64 KiB); anything larger
/// makes `parseHttpRequest` return `error.DataTooLong`.
const REQUEST_BODY_LIMIT = 1024 * 64;

/// Program entry point: listens on 127.0.0.1:3000 and hands every accepted
/// connection to `handleConnect` on its own thread.
///
/// Returns the error from address parsing, `listen`, or `accept`. A failed
/// thread spawn is logged and only that one connection is dropped.
pub fn main(init: std.process.Init) !void {
    const io = init.io;
    const addr = try std.Io.net.IpAddress.parse("127.0.0.1", 3000);

    var server = try addr.listen(io, .{});
    defer server.deinit(io);

    std.log.info("listen 127.0.0.1:3000", .{});

    while (server.accept(io)) |stream| {
        std.log.info("accept {}", .{stream.socket.address});
        const thread = std.Thread.spawn(.{}, handleConnect, .{ io, init.gpa, stream }) catch |err| {
            // The worker never started, so this loop still owns the stream.
            stream.close(io);
            std.log.err("Thread.spawn error: {}", .{err});
            continue;
        };
        // Nobody joins the worker, so it must be detached; otherwise its
        // resources are never released after it finishes.
        thread.detach();
    } else |err| {
        return err;
    }
}

/// Serves a single connection: parses one request, logs it, and echoes the
/// request body back as the response body.
///
/// The stream is closed when this function returns, whether it succeeds or
/// fails. Errors from parsing or writing are returned to the caller.
fn handleConnect(io: std.Io, alloc: std.mem.Allocator, stream: std.Io.net.Stream) !void {
    defer stream.close(io);

    const parsed = try parseHttpRequest(io, alloc, &stream);
    // Releases everything the parser allocated for this request.
    defer parsed.deinit();

    printRequest(&parsed);

    return response(io, &stream, parsed.body);
}

/// Logs the request line, every stored header, and the body through
/// `std.log.info`, framed by marker lines.
fn printRequest(request: *const Request) void {
    const method_name = request.method.asStr();
    const version_name = request.version.asStr();

    std.log.info("request: vvvvvvvvvvvvvvvvvvvvvvvvvv", .{});
    std.log.info("         {s} {s} {s}", .{ method_name, request.path, version_name });

    var it = request.headers.iterator();
    while (it.next()) |entry| {
        const name = entry.key_ptr.*;
        const value = entry.value_ptr.*;
        std.log.info("         {s}: {s}", .{ name, value });
    }

    // Blank line between headers and body, mirroring the wire format.
    std.log.info("", .{});
    std.log.info("         {s}", .{request.body});
    std.log.info("         ^^^^^^^^^^^^^^^^^^^^^^^^^^", .{});
}

/// Writes a `200 OK` response whose body is `message` to `stream`, then
/// flushes the writer.
///
/// The response carries `Connection: close` and a `content-length` equal to
/// `message.len`. Write and flush errors are returned to the caller.
fn response(io: std.Io, stream: *const std.Io.net.Stream, message: []const u8) !void {
    var out_buf: [4096]u8 = undefined;
    var stream_writer = stream.writer(io, &out_buf);
    // Work through a pointer to the interface; it must not be copied out of
    // the writer that owns it.
    const out = &stream_writer.interface;

    try out.print(
        "HTTP/1.1 200 OK\r\nConnection: close\r\ncontent-length: {}\r\n\r\n{s}",
        .{ message.len, message },
    );
    return out.flush();
}

/// Reads one HTTP request from `stream` and parses it into a `Request`.
///
/// A heap-allocated `ArenaAllocator` backed by `alloc` is created for the
/// request; the method, path, headers and body are allocated from it. On
/// success the caller owns the result and must call `Request.deinit`. On error
/// the arena is torn down before returning.
///
/// Errors:
/// - `error.ParseError`: no `\r\n\r\n` within `REQUEST_HEADERS_BUFFER_SIZE`
///   bytes, a malformed request line, or a header line without `:`.
/// - `error.TooManyHeaders`: more than `REQUEST_HEADERS_LIMIT` headers.
/// - `error.DataTooLong`: `content-length` exceeds `REQUEST_BODY_LIMIT`.
/// - Allocation, integer-parsing and stream read errors are propagated.
fn parseHttpRequest(io: std.Io, alloc: std.mem.Allocator, stream: *const std.Io.net.Stream) !Request {
    const terminator = "\r\n\r\n";

    const arena = try alloc.create(std.heap.ArenaAllocator);
    errdefer alloc.destroy(arena);

    arena.* = std.heap.ArenaAllocator.init(alloc);
    // Runs before the `destroy` above, since errdefers execute in reverse order.
    errdefer arena.deinit();

    const arena_alloc = arena.allocator();

    var reader_buf: [4096]u8 = undefined;
    var reader = stream.reader(io, &reader_buf);

    // Collect bytes until the blank line that ends the header block. One read
    // is not enough: the headers may arrive in several segments and a read may
    // return fewer bytes than are eventually available.
    var buf: [REQUEST_HEADERS_BUFFER_SIZE]u8 = undefined;
    var n: usize = 0;
    const headers_len = while (n < buf.len) {
        // Only the new bytes need scanning, plus the last three old ones in
        // case the terminator straddles two reads.
        const scan_from = n -| (terminator.len - 1);
        var vec = [_][]u8{buf[n..]};
        n += try reader.interface.readVec(&vec);
        if (std.mem.indexOf(u8, buf[scan_from..n], terminator)) |rel| {
            break scan_from + rel;
        }
    } else return error.ParseError;

    var lines = std.mem.splitSequence(u8, buf[0..headers_len], "\r\n");

    // Request line: METHOD SP PATH SP VERSION.
    const request_line_text = lines.next() orelse return error.ParseError;
    var request_line = std.mem.splitScalar(u8, request_line_text, ' ');
    const method_text = request_line.next() orelse return error.ParseError;
    const path_text = request_line.next() orelse return error.ParseError;
    const version_text = request_line.next() orelse return error.ParseError;

    var request: Request = .{
        .arena = arena,
        .method = try HttpMethod.fromStr(arena_alloc, method_text),
        .path = try arena_alloc.dupe(u8, path_text),
        .version = try HttpVersion.fromStr(version_text),
        .headers = .init(arena_alloc),
        .body = &.{},
    };

    while (lines.next()) |header| {
        if (header.len == 0) {
            break;
        }

        if (request.headers.count() == REQUEST_HEADERS_LIMIT) {
            return error.TooManyHeaders;
        }

        const index = std.mem.indexOf(u8, header, ":") orelse {
            return error.ParseError;
        };

        // Names are stored lower-cased so lookups are case-insensitive.
        const header_name = try std.ascii.allocLowerString(arena_alloc, header[0..index]);
        // Optional whitespace around a field value is spaces and/or tabs.
        const header_value = if (index + 1 < header.len)
            try arena_alloc.dupe(u8, std.mem.trim(u8, header[index + 1 ..], " \t"))
        else
            "";

        try request.headers.put(header_name, header_value);
    }

    const content_length = request.headers.get("content-length") orelse "0";
    const body_size = try std.fmt.parseInt(usize, content_length, 10);

    if (body_size == 0) {
        return request;
    }
    if (body_size > REQUEST_BODY_LIMIT) {
        return error.DataTooLong;
    }

    // Part of the body may already sit in `buf` behind the headers.
    const buffered = buf[headers_len + terminator.len .. n];
    const prefix_len = @min(buffered.len, body_size);

    const body = try arena_alloc.alloc(u8, body_size);
    // `@memcpy` requires equal lengths, so copy into a prefix of `body` only.
    @memcpy(body[0..prefix_len], buffered[0..prefix_len]);

    // Read whatever the header buffer did not already contain.
    try reader.interface.readSliceAll(body[prefix_len..]);

    request.body = body;

    return request;
}

/// A parsed HTTP request together with the arena it is tied to.
const Request = struct {
    version: HttpVersion,
    method: HttpMethod,
    path: []u8,
    /// Header name to header value.
    headers: std.StringHashMap([]const u8),
    body: []u8,
    /// Heap-allocated arena released by `deinit`.
    arena: *std.heap.ArenaAllocator,

    /// Frees everything allocated from `arena`, then destroys the arena
    /// object itself through the arena's child allocator.
    pub fn deinit(self: *const Request) void {
        const arena_ptr = self.arena;
        const backing = arena_ptr.child_allocator;
        arena_ptr.deinit();
        backing.destroy(arena_ptr);
    }
};

/// HTTP protocol versions recognised in a request line.
const HttpVersion = enum {
    http1_1,
    http2,
    http3,

    /// Wire spelling of each version, used by `fromStr`.
    const TABLE = [_]struct { name: []const u8, version: HttpVersion }{
        .{ .name = "HTTP/1.1", .version = .http1_1 },
        .{ .name = "HTTP/2", .version = .http2 },
        .{ .name = "HTTP/3", .version = .http3 },
    };

    /// Maps an exact version token such as `HTTP/1.1` to its enum value.
    /// Returns `error.ParseError` for any other string.
    pub fn fromStr(str: []const u8) !HttpVersion {
        for (TABLE) |entry| {
            if (std.mem.eql(u8, entry.name, str)) return entry.version;
        }
        return error.ParseError;
    }

    /// Returns the wire spelling of this version.
    pub fn asStr(self: HttpVersion) []const u8 {
        const text: []const u8 = switch (self) {
            .http3 => "HTTP/3",
            .http2 => "HTTP/2",
            .http1_1 => "HTTP/1.1",
        };
        return text;
    }
};

/// An HTTP request method: one of the well-known methods, or `other` carrying
/// the method token exactly as it was received.
const HttpMethod = union(enum) {
    get,
    post,
    put,
    patch,
    delete,
    connect,
    head,
    options,
    other: []const u8,

    /// Parses a method token.
    ///
    /// Method tokens are case-sensitive, so only the exact upper-case spelling
    /// returned by `asStr` selects a well-known variant. Any other non-empty
    /// token becomes `.other`, holding a copy allocated from `alloc` that the
    /// caller owns. An empty token yields `error.ParseError`.
    pub fn fromStr(alloc: std.mem.Allocator, str: []const u8) !HttpMethod {
        if (str.len == 0) return error.ParseError;

        inline for (@typeInfo(HttpMethod).@"union".fields) |field| {
            // Only payload-less variants are well-known methods; `other` is skipped.
            if (field.type == void) {
                const known = @unionInit(HttpMethod, field.name, {});
                // Comparing against `asStr` keeps parsing and printing in agreement.
                if (std.mem.eql(u8, known.asStr(), str)) {
                    return known;
                }
            }
        }

        return .{ .other = try alloc.dupe(u8, str) };
    }

    /// Returns the method token: the canonical upper-case name for well-known
    /// methods, or the stored token for `.other`.
    pub fn asStr(self: *const HttpMethod) []const u8 {
        return switch (self.*) {
            .get => "GET",
            .post => "POST",
            .put => "PUT",
            .patch => "PATCH",
            .delete => "DELETE",
            .connect => "CONNECT",
            .head => "HEAD",
            .options => "OPTIONS",
            .other => |method| method,
        };
    }
};