diff --git a/Cargo.lock b/Cargo.lock index 0cb9921..7cea2ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2373,9 +2373,9 @@ [[package]] name = "serde_json" -version = "1.0.149" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +checksum = "e8014e44b4736ed0538adeecded0fce2a272f22dc9578a7eb6b2d9993c74cfb9" dependencies = [ "itoa", "memchr", diff --git a/Cargo.toml b/Cargo.toml index 5b77f39..a553901 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,23 +11,25 @@ async-trait = "0.1.89" base64 = "0.22.1" bytes = "1.11.1" -env_logger = { version = "0.11.10", default-features = false, features = [ - "auto-color", - "regex" -] } +env_logger = { + version = "0.11.10", + default-features = false, + features = ["auto-color", "regex"] +} http = "1.4.0" -instant-acme = { version = "0.8.5", default-features = false, features = [ - "aws-lc-rs" -], optional = true } +instant-acme = { + version = "0.8.5", + default-features = false, + features = ["aws-lc-rs"], + optional = true +} log = "0.4.29" -pingora = { version = "0.8.0", features = ["proxy", "rustls"] } +pingora = { version = "0.8.0", features = ["rustls"] } rand = "0.10.1" -tokio = { version = "1.52.3", features = [ - "io-util", - "macros", - "net", - "rt-multi-thread" -] } +tokio = { + version = "1.52.3", + features = ["io-util", "macros", "net", "rt-multi-thread"] +} [profile.release] opt-level = 3 diff --git a/src/main.rs b/src/main.rs index 837b42a..2b2739d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,16 +3,18 @@ use base64::engine::general_purpose; use bytes::Bytes; use http::Method; -use pingora::apps::HttpServerOptions; +use pingora::apps::{HttpPersistentSettings, HttpServerApp, HttpServerOptions, ReusedHttpStream}; use pingora::http::{RequestHeader, ResponseHeader}; use pingora::listeners::tls::TlsSettings; use pingora::prelude::*; use pingora::protocols::http::ServerSession as DownstreamSession; use pingora::protocols::http::v2; -use pingora::proxy::{ProxyHttp, ProxyServiceBuilder, Session as ProxySession}; +use pingora::server::ShutdownWatch; +use pingora::services::listening::Service; use rand::rngs::StdRng; use rand::{Rng, RngExt}; use std::env; +use std::sync::Arc; use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; @@ -27,83 +29,72 @@ username: String, password: String, resist_407: bool, + server_options: HttpServerOptions, } #[async_trait] -impl ProxyHttp for ForwardProxy { - type CTX = (); +impl HttpServerApp for ForwardProxy { + async fn process_new_http( + self: &Arc, + mut session: DownstreamSession, + shutdown: &ShutdownWatch, + ) -> Option { + match session.read_request().await { + Ok(true) => {} + Ok(false) => { + return None; + } + Err(err) => { + log::error!("HTTP server fails to read from downstream: {err}"); + return None; + } + } - fn new_ctx(&self) -> Self::CTX {} - - async fn request_filter(&self, session: &mut ProxySession, _ctx: &mut Self::CTX) -> Result - where - Self::CTX: Send + Sync, - { - // if !session.is_http2() { - // self.respond_text(session, 505, HTTP2_ONLY_BODY, &[]) - // .await?; - // return Ok(true); - // } + if *shutdown.borrow() { + session.set_keepalive(None); + } if session.req_header().method != Method::CONNECT { - self.respond_text(session, 200, HELLO_BODY, &[]).await?; - return Ok(true); + if let Err(err) = self.respond_text(&mut session, 200, HELLO_BODY, &[]).await { + log::error!("failed to send hello response: {err}"); + } + return self.finish_session(session).await; } if !self.check_auth(session.req_header()) { - if self.resist_407 { - // return Ok(false); - // TODO: make this response look like a normal page - self.respond_text(session, 400, BAD_REQUEST_BODY, &[]) - .await?; - return Ok(true); - } + let result = if self.resist_407 { + self.respond_text(&mut session, 400, BAD_REQUEST_BODY, &[]) + .await + } else { + self.respond_text( + &mut session, + 407, + AUTH_FAIL_BODY, + &[("Proxy-Authenticate", "Basic realm=\"HKJ Secure Web Proxy\"")], + ) + .await + }; - self.respond_text( - session, - 407, - AUTH_FAIL_BODY, - &[("Proxy-Authenticate", "Basic realm=\"HKJ Secure Web Proxy\"")], - ) - .await?; - return Ok(true); + if let Err(err) = result { + log::error!("failed to send auth response: {err}"); + } + return self.finish_session(session).await; } - self.handle_connect(session).await + match self.handle_connect(&mut session).await { + Ok(should_finish) => { + if should_finish { + return self.finish_session(session).await; + } + None + } + Err(_) => None, + } } - async fn upstream_peer( - &self, - _session: &mut ProxySession, - _ctx: &mut Self::CTX, - ) -> Result> { - // unimplemented!() - Err(Error::explain(InvalidHTTPHeader, "upstream not supported")) + fn server_options(&self) -> Option<&HttpServerOptions> { + Some(&self.server_options) } - - // async fn upstream_request_filter( - // &self, - // _session: &mut ProxySession, - // _upstream_request: &mut RequestHeader, - // _ctx: &mut Self::CTX, - // ) -> Result<()> - // where - // Self::CTX: Send + Sync, - // { - // let Some(target) = ctx.target.as_ref() else { - // return Ok(()); - // }; - - // upstream_request.remove_header("Proxy-Authorization"); - // upstream_request.remove_header("Proxy-Connection"); - - // upstream_request - // .insert_header("Host", target.host_header.as_str()) - // .or_err(InvalidHTTPHeader, "invalid Host header")?; - - // upstream_request.set_raw_path(target.path_and_query.as_bytes())?; - // Ok(()) - // } } impl ForwardProxy { @@ -133,7 +124,7 @@ async fn respond_text( &self, - session: &mut ProxySession, + session: &mut DownstreamSession, status: u16, body: &str, extra_headers: &[(&'static str, &'static str)], @@ -145,13 +136,24 @@ resp.insert_header(*name, *value)?; } - session.set_keepalive(None); - session.write_response_header(Box::new(resp), false).await?; + // session.set_keepalive(None); + session.write_response_header(Box::new(resp)).await?; let body_bytes = Bytes::copy_from_slice(body.as_bytes()); - session.write_response_body(Some(body_bytes), true).await?; + session.write_response_body(body_bytes, true).await?; Ok(()) } + async fn finish_session(&self, session: DownstreamSession) -> Option { + let persistent_settings = HttpPersistentSettings::for_session(&session); + match session.finish().await { + Ok(stream) => stream.map(|s| ReusedHttpStream::new(s, Some(persistent_settings))), + Err(err) => { + log::error!("HTTP server fails to finish the request: {err}"); + None + } + } + } + fn parse_connect_target(&self, req: &RequestHeader) -> Result<(String, u16)> { let authority = req .uri @@ -199,12 +201,7 @@ String::from_utf8(padding).unwrap_or_else(|_| "~".repeat(padding_len)) } - async fn handle_connect(&self, session: &mut ProxySession) -> Result { - let client_addr = session - .as_downstream() - .client_addr() - .map(|addr| addr.to_string()) - .unwrap_or_else(|| "".to_string()); + async fn handle_connect(&self, session: &mut DownstreamSession) -> Result { let (host, port) = match self.parse_connect_target(session.req_header()) { Ok(v) => v, Err(_) => { @@ -214,22 +211,37 @@ } }; + if !matches!(&*session, DownstreamSession::H2(_)) { + self.respond_text(session, 505, HTTP2_ONLY_BODY, &[]) + .await?; + return Ok(true); + } + let padding_enabled = session.get_header("Padding").is_some(); let mut header_rng: StdRng = rand::make_rng(); let padding_header = Self::build_padding_header(&mut header_rng); let mut resp = ResponseHeader::build(200, None)?; resp.insert_header("Padding", padding_header)?; - session.write_response_header(Box::new(resp), false).await?; + session.write_response_header(Box::new(resp)).await?; let upstream = match TcpStream::connect((host.as_str(), port)).await { Ok(stream) => stream, Err(_) => { - session.shutdown().await; + session.write_response_body(Bytes::new(), true).await?; return Ok(true); } }; + let DownstreamSession::H2(h2) = session else { + unreachable!("checked H2 session above"); + }; + + let client_addr = h2 + .client_addr() + .map(|addr| addr.to_string()) + .unwrap_or_else(|| "".to_string()); + log::info!( "CONNECT established from {} to {}:{}", client_addr, @@ -237,13 +249,6 @@ port ); - let downstream = session.as_downstream_mut(); - let DownstreamSession::H2(h2) = downstream else { - self.respond_text(session, 505, HTTP2_ONLY_BODY, &[]) - .await?; - return Ok(true); - }; - let mut resp_writer = h2 .take_response_body_writer() .ok_or_else(|| Error::explain(H2Error, "missing response body writer"))?; @@ -335,7 +340,7 @@ match result { Ok(_) => { log::info!("CONNECT closed from {} to {}:{}", client_addr, host, port); - Ok(true) + Ok(false) } Err(err) => { log::info!( @@ -472,20 +477,18 @@ let mut server = Server::new(Some(Opt::parse_args())).unwrap(); server.bootstrap(); - let proxy = ForwardProxy { - username, - password, - resist_407, - }; - let mut options = HttpServerOptions::default(); options.h2c = false; options.allow_connect_method_proxying = true; - let mut service = ProxyServiceBuilder::new(&server.configuration, proxy) - .name("H2 Forward Proxy") - .server_options(options) - .build(); + let proxy = ForwardProxy { + username, + password, + resist_407, + server_options: options, + }; + + let mut service = Service::new("H2 Forward Proxy".to_string(), proxy); let mut tls_settings = TlsSettings::intermediate(&cert_path, &key_path).expect("invalid TLS cert/key");