use async_trait::async_trait;
use base64::Engine as _;
use base64::engine::general_purpose;
use bytes::Bytes;
use http::Method;
use pingora::apps::{HttpPersistentSettings, HttpServerApp, HttpServerOptions, ReusedHttpStream};
use pingora::http::{RequestHeader, ResponseHeader};
use pingora::listeners::tls::TlsSettings;
use pingora::prelude::*;
use pingora::protocols::http::ServerSession as DownstreamSession;
use pingora::protocols::http::v2;
use pingora::server::ShutdownWatch;
use pingora::services::listening::Service;
use rand::rngs::StdRng;
use rand::{Rng, RngExt};
use std::env;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
/// Body of the `200` response served to requests that are not CONNECT.
const HELLO_BODY: &str = "hello, world\n";
/// Body of the `407` response sent when proxy credentials are missing or wrong.
const AUTH_FAIL_BODY: &str = "proxy auth required\n";
/// Body of the `505` response sent when CONNECT does not arrive over HTTP/2.
const HTTP2_ONLY_BODY: &str = "http2 required\n";
/// Text presumably intended for `400 Bad Request` responses.
const BAD_REQUEST_BODY: &str = "bad request\n";
/// Size of the upstream read buffer (64 KiB); a padded frame
/// (3-byte header + payload + padding) never exceeds it.
const MAX_FRAME: usize = 64 * 1024;
/// How many leading frames per direction use padding framing when the client
/// sends a `Padding` request header.
const NUM_FIRST_PADDINGS: usize = 8;
/// Application state for the HTTP/2 CONNECT forward proxy.
struct ForwardProxy {
    /// Username expected in `Proxy-Authorization: Basic` credentials.
    username: String,
    /// Password expected in `Proxy-Authorization: Basic` credentials.
    password: String,
    /// When set, failed authentication is answered with `405` instead of a
    /// `407` challenge carrying `Proxy-Authenticate`.
    resist_407: bool,
    /// Options handed back to pingora from `HttpServerApp::server_options`.
    server_options: HttpServerOptions,
}
#[async_trait]
impl HttpServerApp for ForwardProxy {
    /// Serves a single request read from `session`.
    ///
    /// * Requests other than CONNECT get a `200` response with `HELLO_BODY`.
    /// * CONNECT requests without valid proxy credentials get `405` when
    ///   `resist_407` is set, otherwise `407` with a `Basic` challenge.
    /// * Authenticated CONNECT requests are handed to `handle_connect`.
    ///
    /// After a plain HTTP response the session is finished and its stream may be
    /// returned for reuse. After a tunnel, on errors, or when no request was
    /// read, `None` is returned.
    async fn process_new_http(
        self: &Arc<Self>,
        mut session: DownstreamSession,
        shutdown: &ShutdownWatch,
    ) -> Option<ReusedHttpStream> {
        match session.read_request().await {
            Ok(true) => {}
            // No request was read; nothing to answer.
            Ok(false) => return None,
            Err(err) => {
                log::error!("HTTP server fails to read from downstream: {err}");
                return None;
            }
        }
        // Stop offering keep-alive once shutdown has been signalled.
        if *shutdown.borrow() {
            session.set_keepalive(None);
        }
        if session.req_header().method != Method::CONNECT {
            if let Err(err) = self.respond_text(&mut session, 200, HELLO_BODY, &[]).await {
                log::error!("failed to send hello response: {err}");
            }
            return self.finish_session(session).await;
        }
        if !self.check_auth(session.req_header()) {
            let result = if self.resist_407 {
                session.respond_error(405).await
            } else {
                self.respond_text(
                    &mut session,
                    407,
                    AUTH_FAIL_BODY,
                    &[("Proxy-Authenticate", "Basic realm=\"HKJ Secure Web Proxy\"")],
                )
                .await
            };
            if let Err(err) = result {
                log::error!("failed to send auth response: {err}");
            }
            return self.finish_session(session).await;
        }
        match self.handle_connect(&mut session).await {
            // A complete HTTP response was sent without a tunnel.
            Ok(true) => self.finish_session(session).await,
            // The tunnel ran and closed; the stream is not reusable.
            Ok(false) => None,
            Err(err) => {
                // Tunnel errors are already logged by `handle_connect`. Errors that
                // happen before the tunnel starts (e.g. writing the response header)
                // would otherwise vanish, so keep a low-noise trace of every failure.
                log::debug!("CONNECT handling failed: {err}");
                None
            }
        }
    }

    /// Exposes the configured server options (h2c disabled, CONNECT proxying allowed).
    fn server_options(&self) -> Option<&HttpServerOptions> {
        Some(&self.server_options)
    }
}
impl ForwardProxy {
fn check_auth(&self, req: &RequestHeader) -> bool {
let Some(value) = req.headers.get("Proxy-Authorization") else {
return false;
};
let Ok(raw) = value.to_str() else {
return false;
};
let mut parts = raw.splitn(2, ' ');
let scheme = parts.next().unwrap_or("");
let token = parts.next().unwrap_or("");
if !scheme.eq_ignore_ascii_case("Basic") {
return false;
}
let Ok(decoded) = general_purpose::STANDARD.decode(token) else {
return false;
};
let Ok(decoded) = std::str::from_utf8(&decoded) else {
return false;
};
decoded == format!("{}:{}", self.username, self.password)
}
async fn respond_text(
&self,
session: &mut DownstreamSession,
status: u16,
body: &str,
extra_headers: &[(&'static str, &'static str)],
) -> Result<()> {
let mut resp = ResponseHeader::build(status, None)?;
resp.insert_header("Content-Type", "text/plain; charset=utf-8")?;
resp.insert_header("Content-Length", body.len().to_string())?;
for (name, value) in extra_headers {
resp.insert_header(*name, *value)?;
}
// session.set_keepalive(None);
session.write_response_header(Box::new(resp)).await?;
let body_bytes = Bytes::copy_from_slice(body.as_bytes());
session.write_response_body(body_bytes, true).await?;
Ok(())
}
async fn finish_session(&self, session: DownstreamSession) -> Option<ReusedHttpStream> {
let persistent_settings = HttpPersistentSettings::for_session(&session);
match session.finish().await {
Ok(stream) => stream.map(|s| ReusedHttpStream::new(s, Some(persistent_settings))),
Err(err) => {
log::error!("HTTP server fails to finish the request: {err}");
None
}
}
}
fn parse_connect_target(&self, req: &RequestHeader) -> Result<(String, u16)> {
let authority = req
.uri
.authority()
// .map(|a| a.as_str().to_string())
// .or_else(|| {
// let path = req.uri.path();
// if !path.is_empty() && !path.starts_with('/') {
// Some(path.to_string())
// } else {
// None
// }
// })
// .or_else(|| {
// req.headers
// .get("Host")
// .and_then(|v| v.to_str().ok())
// .map(|s| s.to_string())
// })
.ok_or_else(|| Error::explain(InvalidHTTPHeader, "missing authority"))?;
// let authority: Authority = authority
// .parse()
// .or_err(InvalidHTTPHeader, "invalid authority")?;
let host = authority.host().to_string();
let port = authority
.port_u16()
.ok_or_else(|| Error::explain(InvalidHTTPHeader, "missing port"))?;
Ok((host, port))
}
fn build_padding_header(rng: &mut impl Rng) -> String {
const PAD_CHARS: &[u8] = b"!#$()+<>?@[]^`{}";
let padding_len = rng.random_range(30..62);
let mut padding = vec![b'~'; padding_len];
let mut bits: u64 = rng.random();
for slot in padding.iter_mut().take(16) {
*slot = PAD_CHARS[(bits & 15) as usize];
bits >>= 4;
}
String::from_utf8(padding).unwrap_or_else(|_| "~".repeat(padding_len))
}
async fn handle_connect(&self, session: &mut DownstreamSession) -> Result<bool> {
let (host, port) = match self.parse_connect_target(session.req_header()) {
Ok(v) => v,
Err(_) => {
session.respond_error(400).await;
return Ok(true);
}
};
if !matches!(&*session, DownstreamSession::H2(_)) {
self.respond_text(session, 505, HTTP2_ONLY_BODY, &[])
.await?;
return Ok(true);
}
let padding_enabled = session.get_header("Padding").is_some();
let mut header_rng: StdRng = rand::make_rng();
let padding_header = Self::build_padding_header(&mut header_rng);
let mut resp = ResponseHeader::build(200, None)?;
resp.insert_header("Padding", padding_header)?;
session.write_response_header(Box::new(resp)).await?;
let upstream = match TcpStream::connect((host.as_str(), port)).await {
Ok(stream) => stream,
Err(_) => {
session.write_response_body(Bytes::new(), true).await?;
return Ok(true);
}
};
let DownstreamSession::H2(h2) = session else {
unreachable!("checked H2 session above");
};
let client_addr = h2
.client_addr()
.map(|addr| addr.to_string())
.unwrap_or_else(|| "<unknown>".to_string());
log::info!(
"CONNECT established from {} to {}:{}",
client_addr,
host,
port
);
let mut resp_writer = h2
.take_response_body_writer()
.ok_or_else(|| Error::explain(H2Error, "missing response body writer"))?;
let (mut upstream_reader, mut upstream_writer) = upstream.into_split();
let client_to_upstream = async {
let mut decoder = PaddingDecoder::new(padding_enabled);
loop {
match h2.read_body_bytes().await? {
Some(chunk) => {
decoder
.write_unpadded(chunk.as_ref(), &mut upstream_writer)
.await?;
}
None => {
upstream_writer
.shutdown()
.await
.or_err(WriteError, "while shutting down upstream writer")?;
return Ok::<(), Box<Error>>(());
}
}
}
};
let upstream_to_client = async {
let mut rng: StdRng = rand::make_rng();
let mut buf = vec![0u8; MAX_FRAME];
let mut frames_left = if padding_enabled {
NUM_FIRST_PADDINGS
} else {
0
};
loop {
if frames_left > 0 {
let padding_size: usize = rng.random_range(0..=255);
let max_read = MAX_FRAME - 3 - padding_size;
let n = upstream_reader
.read(&mut buf[3..(3 + max_read)])
.await
.or_err(ReadError, "while reading upstream")?;
if n == 0 {
v2::write_body(&mut resp_writer, Bytes::new(), true, None).await?;
return Ok::<(), Box<Error>>(());
}
buf[0] = (n / 256) as u8;
buf[1] = (n % 256) as u8;
buf[2] = padding_size as u8;
for i in 0..padding_size {
buf[3 + n + i] = 0;
}
let total = 3 + n + padding_size;
v2::write_body(
&mut resp_writer,
Bytes::copy_from_slice(&buf[..total]),
false,
None,
)
.await?;
frames_left -= 1;
} else {
let n = upstream_reader
.read(&mut buf)
.await
.or_err(ReadError, "while reading upstream")?;
if n == 0 {
v2::write_body(&mut resp_writer, Bytes::new(), true, None).await?;
return Ok::<(), Box<Error>>(());
}
v2::write_body(
&mut resp_writer,
Bytes::copy_from_slice(&buf[..n]),
false,
None,
)
.await?;
}
}
};
let result = tokio::try_join!(client_to_upstream, upstream_to_client);
match result {
Ok(_) => {
log::info!("CONNECT closed from {} to {}:{}", client_addr, host, port);
Ok(false)
}
Err(err) => {
log::info!(
"CONNECT closed with error from {} to {}:{}: {:?}",
client_addr,
host,
port,
err
);
Err(err)
}
}
}
}
/// Removes the padding framing from the first `NUM_FIRST_PADDINGS` frames the
/// client sends through the tunnel and forwards the remaining bytes untouched.
///
/// A padded frame is `[payload_len_hi, payload_len_lo, padding_len]`, followed
/// by `payload_len` payload bytes and then `padding_len` padding bytes. Frames may
/// be split across the chunks given to `write_unpadded` at any point, so the
/// decoder keeps its position between calls.
struct PaddingDecoder {
    /// Padded frames not yet fully consumed; zero means everything passes through.
    frames_left: usize,
    /// Part of the current frame that the next input byte belongs to.
    state: DecodeState,
    /// Frame header bytes gathered so far; a header may straddle two chunks.
    header_buf: [u8; 3],
    /// Number of valid bytes in `header_buf`.
    header_filled: usize,
    /// Payload bytes of the current frame still to forward.
    payload_remaining: usize,
    /// Padding bytes of the current frame still to discard.
    padding_remaining: usize,
}

/// Position of the decoder within the padded-frame stream.
enum DecodeState {
    /// Gathering the 3-byte frame header.
    Header,
    /// Forwarding payload bytes to the writer.
    Payload,
    /// Dropping padding bytes.
    Padding,
    /// Padding phase over (or never enabled); bytes are forwarded as-is.
    PassThrough,
}

impl PaddingDecoder {
    /// Creates a decoder. With `enabled == false` every byte is passed through.
    fn new(enabled: bool) -> Self {
        let (frames_left, state) = if enabled {
            (NUM_FIRST_PADDINGS, DecodeState::Header)
        } else {
            (0, DecodeState::PassThrough)
        };
        Self {
            frames_left,
            state,
            header_buf: [0; 3],
            header_filled: 0,
            payload_remaining: 0,
            padding_remaining: 0,
        }
    }

    /// Writes `bytes` to `writer`, tagging I/O failures as `WriteError`.
    async fn forward<W>(writer: &mut W, bytes: &[u8]) -> Result<()>
    where
        W: AsyncWrite + Unpin,
    {
        writer
            .write_all(bytes)
            .await
            .or_err(WriteError, "while writing upstream")
    }

    /// Marks the current padded frame as complete. The decoder then either
    /// expects the next header or, after the last padded frame, passes bytes through.
    fn finish_frame(&mut self) {
        self.frames_left -= 1;
        self.header_filled = 0;
        self.state = match self.frames_left {
            0 => DecodeState::PassThrough,
            _ => DecodeState::Header,
        };
    }

    /// Consumes one chunk of client data and writes only its payload bytes to `writer`.
    ///
    /// Header and padding bytes are swallowed. Once no padded frames remain, the
    /// rest of the chunk is written in a single call.
    ///
    /// # Errors
    /// `WriteError` when writing to `writer` fails.
    async fn write_unpadded<W>(&mut self, mut input: &[u8], writer: &mut W) -> Result<()>
    where
        W: AsyncWrite + Unpin,
    {
        while !input.is_empty() {
            if self.frames_left == 0 {
                return Self::forward(writer, input).await;
            }
            match self.state {
                DecodeState::Header => {
                    let wanted = self.header_buf.len() - self.header_filled;
                    let (head, rest) = input.split_at(wanted.min(input.len()));
                    self.header_buf[self.header_filled..][..head.len()].copy_from_slice(head);
                    self.header_filled += head.len();
                    input = rest;
                    if self.header_filled == self.header_buf.len() {
                        let [len_hi, len_lo, pad_len] = self.header_buf;
                        self.payload_remaining = usize::from(u16::from_be_bytes([len_hi, len_lo]));
                        self.padding_remaining = usize::from(pad_len);
                        self.state = DecodeState::Payload;
                    }
                }
                DecodeState::Payload => {
                    let (payload, rest) = input.split_at(self.payload_remaining.min(input.len()));
                    if !payload.is_empty() {
                        Self::forward(writer, payload).await?;
                    }
                    self.payload_remaining -= payload.len();
                    input = rest;
                    if self.payload_remaining == 0 {
                        self.state = DecodeState::Padding;
                    }
                }
                DecodeState::Padding => {
                    let skipped = self.padding_remaining.min(input.len());
                    self.padding_remaining -= skipped;
                    input = &input[skipped..];
                    if self.padding_remaining == 0 {
                        self.finish_frame();
                    }
                }
                DecodeState::PassThrough => return Self::forward(writer, input).await,
            }
        }
        Ok(())
    }
}
/// Reads configuration from the environment, then runs the TLS-only HTTP/2
/// forward proxy via `server.run_forever()`.
///
/// Environment variables:
/// * `HKJ_PORT`: listening port (default `443`)
/// * `HKJ_USERNAME`, `HKJ_PASSWORD`: required proxy credentials
/// * `HKJ_PROBE_RESISTANCE`: answer failed auth with `405` instead of `407`;
///   enabled unless set to exactly `0`
/// * `HKJ_CERT_PATH`, `HKJ_KEY_PATH`: required TLS certificate and key paths
///
/// # Panics
/// If a required variable is missing, the server cannot be created, or the
/// TLS certificate/key cannot be loaded.
fn main() {
    let log_env = env_logger::Env::default().default_filter_or("info");
    // Initialisation failure (e.g. a logger is already installed) is not fatal.
    let _ = env_logger::Builder::from_env(log_env).try_init();

    let port = env::var("HKJ_PORT").unwrap_or_else(|_| String::from("443"));
    let username = env::var("HKJ_USERNAME").expect("HKJ_USERNAME is required");
    let password = env::var("HKJ_PASSWORD").expect("HKJ_PASSWORD is required");
    // Anything other than an explicit "0" (including unset) keeps probe resistance on.
    let resist_407 = !matches!(env::var("HKJ_PROBE_RESISTANCE").as_deref(), Ok("0"));
    let cert_path = env::var("HKJ_CERT_PATH").expect("HKJ_CERT_PATH is required");
    let key_path = env::var("HKJ_KEY_PATH").expect("HKJ_KEY_PATH is required");

    let mut server = Server::new(Some(Opt::parse_args())).unwrap();
    server.bootstrap();

    // TLS-only listener (no cleartext h2c) that accepts CONNECT requests.
    let mut server_options = HttpServerOptions::default();
    server_options.h2c = false;
    server_options.allow_connect_method_proxying = true;

    let mut service = Service::new(
        "H2 Forward Proxy".to_string(),
        ForwardProxy {
            username,
            password,
            resist_407,
            server_options,
        },
    );
    let listen_addr = format!("0.0.0.0:{port}");
    let mut tls_settings =
        TlsSettings::intermediate(&cert_path, &key_path).expect("invalid TLS cert/key");
    tls_settings.enable_h2();
    service.add_tls_with_settings(&listen_addr, None, tls_settings);
    server.add_service(service);

    log::info!("HTTP/2 forward proxy listening on https://0.0.0.0:{port}");
    server.run_forever();
}