use async_trait::async_trait;
use base64::Engine as _;
use base64::engine::general_purpose;
use bytes::Bytes;
use http::Method;
use pingora::apps::HttpServerOptions;
use pingora::http::{RequestHeader, ResponseHeader};
use pingora::listeners::tls::TlsSettings;
use pingora::prelude::*;
use pingora::protocols::http::ServerSession as DownstreamSession;
use pingora::protocols::http::v2;
use pingora::proxy::{ProxyHttp, ProxyServiceBuilder, Session as ProxySession};
use rand::rngs::StdRng;
use rand::{Rng, RngExt};
use std::env;
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
/// Body of the decoy page returned for every non-CONNECT request.
const HELLO_BODY: &str = "hello, world\n";
/// Body of the 407 reply sent when probe resistance is turned off.
const AUTH_FAIL_BODY: &str = "proxy auth required\n";
/// Body of the 505 reply for CONNECT requests that did not arrive over HTTP/2.
const HTTP2_ONLY_BODY: &str = "http2 required\n";
/// Body of 400 replies (unparsable CONNECT target, or failed auth under probe resistance).
const BAD_REQUEST_BODY: &str = "bad request\n";
/// Size of the upstream read buffer, i.e. the largest chunk forwarded to the client (64 KiB).
const MAX_FRAME: usize = 64 * 1024;
/// Number of leading frames, in each direction, that use the padding framing.
const NUM_FIRST_PADDINGS: usize = 8;
/// Forward-proxy configuration shared by every request handled by the service.
struct ForwardProxy {
    /// User name expected in the `Proxy-Authorization: Basic` credentials.
    username: String,
    /// Password expected in the `Proxy-Authorization: Basic` credentials.
    password: String,
    /// When set, failed authentication is answered with a plain 400 instead of
    /// a 407 challenge, so the service does not advertise itself as a proxy.
    resist_407: bool,
}
#[async_trait]
impl ProxyHttp for ForwardProxy {
    /// No per-request state: every request is answered inside `request_filter`.
    type CTX = ();

    fn new_ctx(&self) -> Self::CTX {}

    /// Answers every request directly.
    ///
    /// - Non-CONNECT requests get a fixed 200 text page, so the listener looks
    ///   like an ordinary web server.
    /// - CONNECT requests without valid Basic proxy credentials get a 400
    ///   (probe resistance on) or a 407 challenge (probe resistance off).
    /// - Authenticated CONNECT requests are tunnelled by `handle_connect`.
    ///
    /// Always returns `Ok(true)` (response already sent) or an error.
    async fn request_filter(&self, session: &mut ProxySession, _ctx: &mut Self::CTX) -> Result<bool>
    where
        Self::CTX: Send + Sync,
    {
        if session.req_header().method != Method::CONNECT {
            self.respond_text(session, 200, HELLO_BODY, &[]).await?;
            return Ok(true);
        }
        if !self.check_auth(session.req_header()) {
            if self.resist_407 {
                // Do not reveal that this endpoint is a proxy.
                // TODO: make this response look like a normal page
                self.respond_text(session, 400, BAD_REQUEST_BODY, &[])
                    .await?;
            } else {
                self.respond_text(
                    session,
                    407,
                    AUTH_FAIL_BODY,
                    &[("Proxy-Authenticate", "Basic realm=\"HKJ Secure Web Proxy\"")],
                )
                .await?;
            }
            return Ok(true);
        }
        self.handle_connect(session).await
    }

    /// This proxy never forwards through pingora's upstream machinery.
    /// `request_filter` answers every request itself, so this is not expected
    /// to be called. If it is, the request fails with an error.
    async fn upstream_peer(
        &self,
        _session: &mut ProxySession,
        _ctx: &mut Self::CTX,
    ) -> Result<Box<HttpPeer>> {
        Err(Error::explain(InvalidHTTPHeader, "upstream not supported"))
    }
}
impl ForwardProxy {
fn check_auth(&self, req: &RequestHeader) -> bool {
let Some(value) = req.headers.get("Proxy-Authorization") else {
return false;
};
let Ok(raw) = value.to_str() else {
return false;
};
let mut parts = raw.splitn(2, ' ');
let scheme = parts.next().unwrap_or("");
let token = parts.next().unwrap_or("");
if !scheme.eq_ignore_ascii_case("Basic") {
return false;
}
let Ok(decoded) = general_purpose::STANDARD.decode(token) else {
return false;
};
let Ok(decoded) = std::str::from_utf8(&decoded) else {
return false;
};
decoded == format!("{}:{}", self.username, self.password)
}
async fn respond_text(
&self,
session: &mut ProxySession,
status: u16,
body: &str,
extra_headers: &[(&'static str, &'static str)],
) -> Result<()> {
let mut resp = ResponseHeader::build(status, None)?;
resp.insert_header("Content-Type", "text/plain; charset=utf-8")?;
resp.insert_header("Content-Length", body.len().to_string())?;
for (name, value) in extra_headers {
resp.insert_header(*name, *value)?;
}
session.set_keepalive(None);
session.write_response_header(Box::new(resp), false).await?;
let body_bytes = Bytes::copy_from_slice(body.as_bytes());
session.write_response_body(Some(body_bytes), true).await?;
Ok(())
}
fn parse_connect_target(&self, req: &RequestHeader) -> Result<(String, u16)> {
let authority = req
.uri
.authority()
// .map(|a| a.as_str().to_string())
// .or_else(|| {
// let path = req.uri.path();
// if !path.is_empty() && !path.starts_with('/') {
// Some(path.to_string())
// } else {
// None
// }
// })
// .or_else(|| {
// req.headers
// .get("Host")
// .and_then(|v| v.to_str().ok())
// .map(|s| s.to_string())
// })
.ok_or_else(|| Error::explain(InvalidHTTPHeader, "missing authority"))?;
// let authority: Authority = authority
// .parse()
// .or_err(InvalidHTTPHeader, "invalid authority")?;
let host = authority.host().to_string();
let port = authority
.port_u16()
.ok_or_else(|| Error::explain(InvalidHTTPHeader, "missing port"))?;
Ok((host, port))
}
fn build_padding_header(rng: &mut impl Rng) -> String {
const PAD_CHARS: &[u8] = b"!#$()+<>?@[]^`{}";
let padding_len = rng.random_range(30..62);
let mut padding = vec![b'~'; padding_len];
let mut bits: u64 = rng.random();
for slot in padding.iter_mut().take(16) {
*slot = PAD_CHARS[(bits & 15) as usize];
bits >>= 4;
}
String::from_utf8(padding).unwrap_or_else(|_| "~".repeat(padding_len))
}
async fn handle_connect(&self, session: &mut ProxySession) -> Result<bool> {
let client_addr = session
.as_downstream()
.client_addr()
.map(|addr| addr.to_string())
.unwrap_or_else(|| "<unknown>".to_string());
let (host, port) = match self.parse_connect_target(session.req_header()) {
Ok(v) => v,
Err(_) => {
self.respond_text(session, 400, BAD_REQUEST_BODY, &[])
.await?;
return Ok(true);
}
};
let padding_enabled = session.get_header("Padding").is_some();
let mut header_rng: StdRng = rand::make_rng();
let padding_header = Self::build_padding_header(&mut header_rng);
let mut resp = ResponseHeader::build(200, None)?;
resp.insert_header("Padding", padding_header)?;
session.write_response_header(Box::new(resp), false).await?;
let upstream = match TcpStream::connect((host.as_str(), port)).await {
Ok(stream) => stream,
Err(_) => {
session.shutdown().await;
return Ok(true);
}
};
log::info!(
"CONNECT established from {} to {}:{}",
client_addr,
host,
port
);
let downstream = session.as_downstream_mut();
let DownstreamSession::H2(h2) = downstream else {
self.respond_text(session, 505, HTTP2_ONLY_BODY, &[])
.await?;
return Ok(true);
};
let mut resp_writer = h2
.take_response_body_writer()
.ok_or_else(|| Error::explain(H2Error, "missing response body writer"))?;
let (mut upstream_reader, mut upstream_writer) = upstream.into_split();
let client_to_upstream = async {
let mut decoder = PaddingDecoder::new(padding_enabled);
loop {
match h2.read_body_bytes().await? {
Some(chunk) => {
decoder
.write_unpadded(chunk.as_ref(), &mut upstream_writer)
.await?;
}
None => {
upstream_writer
.shutdown()
.await
.or_err(WriteError, "while shutting down upstream writer")?;
return Ok::<(), Box<Error>>(());
}
}
}
};
let upstream_to_client = async {
let mut rng: StdRng = rand::make_rng();
let mut buf = vec![0u8; MAX_FRAME];
let mut frames_left = if padding_enabled {
NUM_FIRST_PADDINGS
} else {
0
};
loop {
if frames_left > 0 {
let padding_size: usize = rng.random_range(0..=255);
let max_read = MAX_FRAME - 3 - padding_size;
let n = upstream_reader
.read(&mut buf[3..(3 + max_read)])
.await
.or_err(ReadError, "while reading upstream")?;
if n == 0 {
v2::write_body(&mut resp_writer, Bytes::new(), true, None).await?;
return Ok::<(), Box<Error>>(());
}
buf[0] = (n / 256) as u8;
buf[1] = (n % 256) as u8;
buf[2] = padding_size as u8;
for i in 0..padding_size {
buf[3 + n + i] = 0;
}
let total = 3 + n + padding_size;
v2::write_body(
&mut resp_writer,
Bytes::copy_from_slice(&buf[..total]),
false,
None,
)
.await?;
frames_left -= 1;
} else {
let n = upstream_reader
.read(&mut buf)
.await
.or_err(ReadError, "while reading upstream")?;
if n == 0 {
v2::write_body(&mut resp_writer, Bytes::new(), true, None).await?;
return Ok::<(), Box<Error>>(());
}
v2::write_body(
&mut resp_writer,
Bytes::copy_from_slice(&buf[..n]),
false,
None,
)
.await?;
}
}
};
let result = tokio::try_join!(client_to_upstream, upstream_to_client);
match result {
Ok(_) => {
log::info!("CONNECT closed from {} to {}:{}", client_addr, host, port);
Ok(true)
}
Err(err) => {
log::info!(
"CONNECT closed with error from {} to {}:{}: {:?}",
client_addr,
host,
port,
err
);
Err(err)
}
}
}
}
/// Incremental parser that removes the padding framing from the first frames
/// of the client-to-upstream byte stream.
///
/// The fields record how far into the current frame parsing has progressed,
/// so a frame may be split across any number of input chunks.
struct PaddingDecoder {
    /// Which part of the current frame the next input byte belongs to.
    state: DecodeState,
    /// Padded frames still expected; zero once decoding is pass-through.
    frames_left: usize,
    /// Frame-header bytes collected so far: length high, length low, padding length.
    header_buf: [u8; 3],
    /// Number of valid bytes currently in `header_buf`.
    header_filled: usize,
    /// Payload bytes of the current frame not yet forwarded.
    payload_remaining: usize,
    /// Padding bytes of the current frame not yet skipped.
    padding_remaining: usize,
}
/// Position of the padding decoder within the current padded frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DecodeState {
    /// Collecting the 3-byte frame header.
    Header,
    /// Forwarding the frame's payload bytes.
    Payload,
    /// Discarding the frame's trailing padding bytes.
    Padding,
    /// Padded prefix finished (or never enabled); bytes are forwarded unchanged.
    PassThrough,
}
impl PaddingDecoder {
    /// Creates a decoder for one client-to-upstream byte stream.
    ///
    /// When `enabled` is true, the first `NUM_FIRST_PADDINGS` frames are parsed
    /// as `[len_hi, len_lo, pad_len, payload (len bytes), padding (pad_len bytes)]`
    /// and only the payload is kept. Otherwise the decoder starts, and stays,
    /// in pass-through mode.
    fn new(enabled: bool) -> Self {
        let (frames_left, state) = if enabled {
            (NUM_FIRST_PADDINGS, DecodeState::Header)
        } else {
            (0, DecodeState::PassThrough)
        };
        Self {
            frames_left,
            state,
            header_buf: [0; 3],
            header_filled: 0,
            payload_remaining: 0,
            padding_remaining: 0,
        }
    }

    /// Writes `bytes` to `writer` unchanged.
    async fn forward<W>(bytes: &[u8], writer: &mut W) -> Result<()>
    where
        W: AsyncWrite + Unpin,
    {
        writer
            .write_all(bytes)
            .await
            .or_err(WriteError, "while writing upstream")
    }

    /// Feeds one chunk of client data through the decoder and writes the
    /// recovered payload bytes to `writer`.
    ///
    /// Chunks may split frames at arbitrary points; partial progress is kept in
    /// `self` for the next call. Once the padded frames are used up, the rest of
    /// the chunk (and every later chunk) is written through unchanged.
    ///
    /// # Errors
    /// Returns a `WriteError` if writing to `writer` fails.
    async fn write_unpadded<W>(&mut self, mut input: &[u8], writer: &mut W) -> Result<()>
    where
        W: AsyncWrite + Unpin,
    {
        while !input.is_empty() {
            if self.frames_left == 0 {
                return Self::forward(input, writer).await;
            }
            match self.state {
                DecodeState::PassThrough => return Self::forward(input, writer).await,
                DecodeState::Header => {
                    let filled = self.header_filled;
                    let (head, rest) = input.split_at((3 - filled).min(input.len()));
                    self.header_buf[filled..filled + head.len()].copy_from_slice(head);
                    self.header_filled += head.len();
                    input = rest;
                    if self.header_filled == 3 {
                        let [len_hi, len_lo, pad_len] = self.header_buf;
                        self.payload_remaining =
                            usize::from(u16::from_be_bytes([len_hi, len_lo]));
                        self.padding_remaining = usize::from(pad_len);
                        self.state = DecodeState::Payload;
                    }
                }
                DecodeState::Payload => {
                    let (data, rest) = input.split_at(self.payload_remaining.min(input.len()));
                    if !data.is_empty() {
                        Self::forward(data, &mut *writer).await?;
                    }
                    self.payload_remaining -= data.len();
                    input = rest;
                    if self.payload_remaining == 0 {
                        self.state = DecodeState::Padding;
                    }
                }
                DecodeState::Padding => {
                    let skip = self.padding_remaining.min(input.len());
                    self.padding_remaining -= skip;
                    input = &input[skip..];
                    if self.padding_remaining == 0 {
                        // Frame finished: expect another header, or stop decoding.
                        self.frames_left -= 1;
                        self.header_filled = 0;
                        self.state = if self.frames_left == 0 {
                            DecodeState::PassThrough
                        } else {
                            DecodeState::Header
                        };
                    }
                }
            }
        }
        Ok(())
    }
}
/// Interprets a boolean-ish environment value.
///
/// Returns `false` for "0", "false", "no" and "off" (case-insensitive,
/// surrounding whitespace ignored) and `true` for anything else.
fn env_flag_enabled(value: &str) -> bool {
    let value = value.trim();
    !["0", "false", "no", "off"]
        .iter()
        .any(|off| value.eq_ignore_ascii_case(off))
}

/// Reads `HKJ_*` configuration from the environment and serves the HTTP/2
/// forward proxy over TLS until the process is stopped.
///
/// Environment variables:
/// - `HKJ_PORT` (default 443): TCP port to listen on, on 0.0.0.0.
/// - `HKJ_USERNAME`, `HKJ_PASSWORD` (required): Basic proxy credentials.
/// - `HKJ_PROBE_RESISTANCE` (default on): "0", "false", "no" or "off" disables it.
/// - `HKJ_CERT_PATH`, `HKJ_KEY_PATH` (required): TLS certificate and private key.
///
/// # Panics
/// Panics at startup in any of these cases:
/// - a required variable is missing;
/// - `HKJ_PORT` is not a valid port number;
/// - the TLS cert or key cannot be loaded;
/// - the server cannot be created.
fn main() {
    let _ = env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info"))
        .try_init();
    // Validate eagerly so a typo fails here with a clear message, rather than
    // surfacing later when the listen address is used.
    let port: u16 = match env::var("HKJ_PORT") {
        Ok(raw) => raw
            .trim()
            .parse()
            .expect("HKJ_PORT must be a TCP port number (0-65535)"),
        Err(_) => 443,
    };
    let username = env::var("HKJ_USERNAME").expect("HKJ_USERNAME is required");
    let password = env::var("HKJ_PASSWORD").expect("HKJ_PASSWORD is required");
    let resist_407 = env::var("HKJ_PROBE_RESISTANCE")
        .map(|v| env_flag_enabled(&v))
        .unwrap_or(true);
    let cert_path = env::var("HKJ_CERT_PATH").expect("HKJ_CERT_PATH is required");
    let key_path = env::var("HKJ_KEY_PATH").expect("HKJ_KEY_PATH is required");
    let mut server = Server::new(Some(Opt::parse_args())).unwrap();
    server.bootstrap();
    let proxy = ForwardProxy {
        username,
        password,
        resist_407,
    };
    let mut options = HttpServerOptions::default();
    // TLS-only: no cleartext HTTP/2, and CONNECT requests must reach the proxy logic.
    options.h2c = false;
    options.allow_connect_method_proxying = true;
    let mut service = ProxyServiceBuilder::new(&server.configuration, proxy)
        .name("H2 Forward Proxy")
        .server_options(options)
        .build();
    let mut tls_settings =
        TlsSettings::intermediate(&cert_path, &key_path).expect("invalid TLS cert/key");
    tls_settings.enable_h2();
    service.add_tls_with_settings(&format!("0.0.0.0:{port}"), None, tls_settings);
    server.add_service(service);
    log::info!("HTTP/2 forward proxy listening on https://0.0.0.0:{port}");
    server.run_forever();
}