// hkj/src/main.rs
use async_trait::async_trait;
use base64::Engine as _;
use base64::engine::general_purpose;
use bytes::Bytes;
use http::Method;
use pingora::apps::HttpServerOptions;
use pingora::http::{RequestHeader, ResponseHeader};
use pingora::listeners::tls::TlsSettings;
use pingora::prelude::*;
use pingora::protocols::http::ServerSession as DownstreamSession;
use pingora::protocols::http::v2;
use pingora::proxy::{ProxyHttp, ProxyServiceBuilder, Session as ProxySession};
use rand::rngs::StdRng;
use rand::{Rng, RngExt};
use std::env;
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;

/// Body sent with the `200` responses (non-CONNECT requests and the
/// probe-resistant reply to unauthenticated requests).
const HELLO_BODY: &str = "hello, world\n";
/// Body sent with the `407` response when credentials are missing or wrong.
const AUTH_FAIL_BODY: &str = "proxy auth required\n";
/// Body sent with the `505` response when the downstream is not HTTP/2.
const HTTP2_ONLY_BODY: &str = "http2 required\n";
/// Body sent with the `400` response when the CONNECT target cannot be parsed.
const BAD_REQUEST_BODY: &str = "bad request\n";
/// Size in bytes of the upstream read buffer; also the upper bound for one
/// padded frame (header + payload + padding).
const MAX_FRAME: usize = 64 * 1024;
/// Number of leading frames, in each direction, that carry padding when the
/// client opted in.
const NUM_FIRST_PADDINGS: usize = 8;

/// Configuration of the forward proxy: the single accepted credential pair
/// and the behaviour towards unauthenticated requests.
struct ForwardProxy {
    /// User name expected in the `Proxy-Authorization: Basic` credentials.
    username: String,
    /// Password expected in the `Proxy-Authorization: Basic` credentials.
    password: String,
    /// When `true`, unauthenticated requests receive the plain `200` hello
    /// page instead of a `407` challenge.
    resist_407: bool,
}

#[async_trait]
impl ProxyHttp for ForwardProxy {
    /// No per-request state is kept.
    type CTX = ();

    fn new_ctx(&self) -> Self::CTX {}

    /// Answers every request directly; nothing is ever handed to pingora's
    /// upstream machinery.
    ///
    /// * non-HTTP/2 downstream: `505`
    /// * failed authentication: `200` hello page when `resist_407` is set,
    ///   otherwise a `407` challenge
    /// * authenticated, method other than CONNECT: `200` hello page
    /// * authenticated CONNECT: delegated to `handle_connect`
    ///
    /// Returns `Ok(true)` after writing a response, or whatever
    /// `handle_connect` returns.
    async fn request_filter(&self, session: &mut ProxySession, _ctx: &mut Self::CTX) -> Result<bool>
    where
        Self::CTX: Send + Sync,
    {
        if !session.is_http2() {
            self.respond_text(session, 505, HTTP2_ONLY_BODY, &[])
                .await?;
            return Ok(true);
        }

        let authorized = self.check_auth(session.req_header());
        if authorized && session.req_header().method == Method::CONNECT {
            return self.handle_connect(session).await;
        }

        if authorized || self.resist_407 {
            // TODO: make this response look like a normal page
            self.respond_text(session, 200, HELLO_BODY, &[]).await?;
        } else {
            let challenge = [("Proxy-Authenticate", "Basic realm=\"HKJ Secure Web Proxy\"")];
            self.respond_text(session, 407, AUTH_FAIL_BODY, &challenge)
                .await?;
        }
        Ok(true)
    }

    /// Always fails: `request_filter` never lets a request continue to the
    /// upstream phase, so no peer is ever selected.
    async fn upstream_peer(
        &self,
        _session: &mut ProxySession,
        _ctx: &mut Self::CTX,
    ) -> Result<Box<HttpPeer>> {
        Err(Error::explain(InvalidHTTPHeader, "upstream not supported"))
    }
}

impl ForwardProxy {
    fn check_auth(&self, req: &RequestHeader) -> bool {
        let Some(value) = req.headers.get("Proxy-Authorization") else {
            return false;
        };
        let Ok(raw) = value.to_str() else {
            return false;
        };

        let mut parts = raw.splitn(2, ' ');
        let scheme = parts.next().unwrap_or("");
        let token = parts.next().unwrap_or("");
        if !scheme.eq_ignore_ascii_case("Basic") {
            return false;
        }

        let Ok(decoded) = general_purpose::STANDARD.decode(token) else {
            return false;
        };
        let Ok(decoded) = std::str::from_utf8(&decoded) else {
            return false;
        };
        decoded == format!("{}:{}", self.username, self.password)
    }

    /// Writes a complete `text/plain` response and ends the stream.
    ///
    /// * `status` - HTTP status code of the response.
    /// * `body` - response body; its byte length is sent as `Content-Length`.
    /// * `extra_headers` - additional `(name, value)` pairs inserted after
    ///   the two standard headers.
    ///
    /// Returns an error if the header cannot be built or a write fails.
    async fn respond_text(
        &self,
        session: &mut ProxySession,
        status: u16,
        body: &str,
        extra_headers: &[(&'static str, &'static str)],
    ) -> Result<()> {
        let mut header = ResponseHeader::build(status, None)?;
        header.insert_header("Content-Type", "text/plain; charset=utf-8")?;
        header.insert_header("Content-Length", body.len().to_string())?;
        for &(name, value) in extra_headers {
            header.insert_header(name, value)?;
        }

        session.set_keepalive(None);
        session
            .write_response_header(Box::new(header), false)
            .await?;

        // `true` marks this body chunk as the end of the response.
        let payload = Some(Bytes::copy_from_slice(body.as_bytes()));
        session.write_response_body(payload, true).await?;
        Ok(())
    }

    /// Extracts the `(host, port)` to dial from the request URI's authority.
    ///
    /// The host is returned in a form that can be passed straight to a
    /// socket connect call: for IPv6 literals the surrounding `[` `]` that
    /// `Authority::host` keeps are removed, because a bracketed string is
    /// neither a parsable IP address nor a resolvable host name.
    ///
    /// # Errors
    ///
    /// `InvalidHTTPHeader` when the URI has no authority or the authority
    /// has no explicit port.
    fn parse_connect_target(&self, req: &RequestHeader) -> Result<(String, u16)> {
        let authority = req
            .uri
            .authority()
            .ok_or_else(|| Error::explain(InvalidHTTPHeader, "missing authority"))?;
        let port = authority
            .port_u16()
            .ok_or_else(|| Error::explain(InvalidHTTPHeader, "missing port"))?;

        let host = authority.host();
        let host = host
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(host);

        Ok((host.to_string(), port))
    }

    /// Builds the value of the `Padding` response header.
    ///
    /// The value is 30 to 61 characters long. Its first 16 characters are
    /// picked from `PAD_CHARS`, one per 4-bit group of a single random
    /// `u64` (lowest bits first); every remaining character is `~`.
    fn build_padding_header(rng: &mut impl Rng) -> String {
        const PAD_CHARS: &[u8] = b"!#$()+<>?@[]^`{}";

        // Draw order matters for reproducibility: length first, then bits.
        let total_len: usize = rng.random_range(30..62);
        let nibbles: u64 = rng.random();

        let mut header = String::with_capacity(total_len);
        header.extend(
            (0..16).map(|i| char::from(PAD_CHARS[((nibbles >> (4 * i)) & 0xF) as usize])),
        );
        // total_len is at least 30, so the subtraction cannot underflow.
        header.extend(std::iter::repeat('~').take(total_len - 16));
        header
    }

    async fn handle_connect(&self, session: &mut ProxySession) -> Result<bool> {
        let client_addr = session
            .as_downstream()
            .client_addr()
            .map(|addr| addr.to_string())
            .unwrap_or_else(|| "<unknown>".to_string());
        let (host, port) = match self.parse_connect_target(session.req_header()) {
            Ok(v) => v,
            Err(_) => {
                self.respond_text(session, 400, BAD_REQUEST_BODY, &[])
                    .await?;
                return Ok(true);
            }
        };

        let padding_enabled = session.get_header("Padding").is_some();
        let mut header_rng: StdRng = rand::make_rng();
        let padding_header = Self::build_padding_header(&mut header_rng);

        let mut resp = ResponseHeader::build(200, None)?;
        resp.insert_header("Padding", padding_header)?;
        session.write_response_header(Box::new(resp), false).await?;

        let upstream = match TcpStream::connect((host.as_str(), port)).await {
            Ok(stream) => stream,
            Err(_) => {
                session.shutdown().await;
                return Ok(true);
            }
        };

        log::info!(
            "CONNECT established from {} to {}:{}",
            client_addr,
            host,
            port
        );

        let downstream = session.as_downstream_mut();
        let DownstreamSession::H2(h2) = downstream else {
            self.respond_text(session, 505, HTTP2_ONLY_BODY, &[])
                .await?;
            return Ok(true);
        };

        let mut resp_writer = h2
            .take_response_body_writer()
            .ok_or_else(|| Error::explain(H2Error, "missing response body writer"))?;

        let (mut upstream_reader, mut upstream_writer) = upstream.into_split();

        let client_to_upstream = async {
            let mut decoder = PaddingDecoder::new(padding_enabled);
            loop {
                match h2.read_body_bytes().await? {
                    Some(chunk) => {
                        decoder
                            .write_unpadded(chunk.as_ref(), &mut upstream_writer)
                            .await?;
                    }
                    None => {
                        upstream_writer
                            .shutdown()
                            .await
                            .or_err(WriteError, "while shutting down upstream writer")?;
                        return Ok::<(), Box<Error>>(());
                    }
                }
            }
        };

        let upstream_to_client = async {
            let mut rng: StdRng = rand::make_rng();
            let mut buf = vec![0u8; MAX_FRAME];
            let mut frames_left = if padding_enabled {
                NUM_FIRST_PADDINGS
            } else {
                0
            };

            loop {
                if frames_left > 0 {
                    let padding_size: usize = rng.random_range(0..=255);
                    let max_read = MAX_FRAME - 3 - padding_size;
                    let n = upstream_reader
                        .read(&mut buf[3..(3 + max_read)])
                        .await
                        .or_err(ReadError, "while reading upstream")?;

                    if n == 0 {
                        v2::write_body(&mut resp_writer, Bytes::new(), true, None).await?;
                        return Ok::<(), Box<Error>>(());
                    }

                    buf[0] = (n / 256) as u8;
                    buf[1] = (n % 256) as u8;
                    buf[2] = padding_size as u8;
                    for i in 0..padding_size {
                        buf[3 + n + i] = 0;
                    }

                    let total = 3 + n + padding_size;
                    v2::write_body(
                        &mut resp_writer,
                        Bytes::copy_from_slice(&buf[..total]),
                        false,
                        None,
                    )
                    .await?;
                    frames_left -= 1;
                } else {
                    let n = upstream_reader
                        .read(&mut buf)
                        .await
                        .or_err(ReadError, "while reading upstream")?;

                    if n == 0 {
                        v2::write_body(&mut resp_writer, Bytes::new(), true, None).await?;
                        return Ok::<(), Box<Error>>(());
                    }

                    v2::write_body(
                        &mut resp_writer,
                        Bytes::copy_from_slice(&buf[..n]),
                        false,
                        None,
                    )
                    .await?;
                }
            }
        };

        let result = tokio::try_join!(client_to_upstream, upstream_to_client);
        match result {
            Ok(_) => {
                log::info!("CONNECT closed from {} to {}:{}", client_addr, host, port);
                Ok(true)
            }
            Err(err) => {
                log::info!(
                    "CONNECT closed with error from {} to {}:{}: {:?}",
                    client_addr,
                    host,
                    port,
                    err
                );
                Err(err)
            }
        }
    }
}

/// Incremental decoder that strips the padding framing from the first
/// frames of the client-to-upstream byte stream.
///
/// Input may be split at arbitrary points; the fields below carry the
/// parse position from one `write_unpadded` call to the next.
struct PaddingDecoder {
    /// Padded frames still expected; `0` means all input is forwarded as is.
    frames_left: usize,
    /// Which part of the current frame is being read.
    state: DecodeState,
    /// Bytes of the 3-byte frame header collected so far.
    header_buf: [u8; 3],
    /// Number of valid bytes in `header_buf`.
    header_filled: usize,
    /// Payload bytes of the current frame not yet forwarded.
    payload_remaining: usize,
    /// Padding bytes of the current frame not yet skipped.
    padding_remaining: usize,
}

/// Position of `PaddingDecoder` within the padded stream.
enum DecodeState {
    /// Collecting the 3-byte frame header.
    Header,
    /// Forwarding the frame's payload bytes.
    Payload,
    /// Discarding the frame's padding bytes.
    Padding,
    /// Padding is finished or disabled; input is forwarded unchanged.
    PassThrough,
}

impl PaddingDecoder {
    /// Creates a decoder. With `enabled` it expects `NUM_FIRST_PADDINGS`
    /// padded frames starting at a header; otherwise it starts in
    /// pass-through mode.
    fn new(enabled: bool) -> Self {
        let (frames_left, state) = if enabled {
            (NUM_FIRST_PADDINGS, DecodeState::Header)
        } else {
            (0, DecodeState::PassThrough)
        };
        Self {
            frames_left,
            state,
            header_buf: [0; 3],
            header_filled: 0,
            payload_remaining: 0,
            padding_remaining: 0,
        }
    }

    /// Consumes `input`, writing only payload bytes to `writer`.
    ///
    /// Each padded frame is `[len_hi, len_lo, pad_len]` followed by
    /// `len` payload bytes and `pad_len` padding bytes. Once the expected
    /// number of padded frames has been consumed, the rest of `input` (and
    /// all later input) is written through unchanged.
    ///
    /// Returns a `WriteError` if writing to `writer` fails.
    async fn write_unpadded<W>(&mut self, mut input: &[u8], writer: &mut W) -> Result<()>
    where
        W: AsyncWrite + Unpin,
    {
        loop {
            if input.is_empty() {
                return Ok(());
            }

            if self.frames_left == 0 {
                writer
                    .write_all(input)
                    .await
                    .or_err(WriteError, "while writing upstream")?;
                return Ok(());
            }

            match self.state {
                DecodeState::PassThrough => {
                    writer
                        .write_all(input)
                        .await
                        .or_err(WriteError, "while writing upstream")?;
                    return Ok(());
                }
                DecodeState::Header => {
                    let filled = self.header_filled;
                    let (head, rest) = input.split_at((3 - filled).min(input.len()));
                    self.header_buf[filled..filled + head.len()].copy_from_slice(head);
                    self.header_filled = filled + head.len();
                    input = rest;

                    if self.header_filled == 3 {
                        let [len_hi, len_lo, pad_len] = self.header_buf;
                        self.payload_remaining =
                            usize::from(u16::from_be_bytes([len_hi, len_lo]));
                        self.padding_remaining = usize::from(pad_len);
                        self.state = DecodeState::Payload;
                    }
                }
                DecodeState::Payload => {
                    let (data, rest) = input.split_at(self.payload_remaining.min(input.len()));
                    if !data.is_empty() {
                        writer
                            .write_all(data)
                            .await
                            .or_err(WriteError, "while writing upstream")?;
                    }
                    // Progress is recorded only after the write succeeded.
                    self.payload_remaining -= data.len();
                    input = rest;
                    if self.payload_remaining == 0 {
                        self.state = DecodeState::Padding;
                    }
                }
                DecodeState::Padding => {
                    let skipped = self.padding_remaining.min(input.len());
                    self.padding_remaining -= skipped;
                    input = &input[skipped..];
                    if self.padding_remaining == 0 {
                        // Frame complete: reset for the next header, or stop
                        // decoding once the last padded frame is done.
                        self.frames_left -= 1;
                        self.header_filled = 0;
                        self.state = match self.frames_left {
                            0 => DecodeState::PassThrough,
                            _ => DecodeState::Header,
                        };
                    }
                }
            }
        }
    }
}

/// Reads the configuration from the environment, builds the TLS + HTTP/2
/// proxy service and runs the server until the process exits.
///
/// Environment variables:
/// * `HKJ_PORT` - listen port, defaults to `443`
/// * `HKJ_USERNAME`, `HKJ_PASSWORD` - required credentials
/// * `HKJ_PROBE_RESISTANCE` - set to `0` to send `407` challenges;
///   any other value, or unset, keeps probe resistance on
/// * `HKJ_CERT_PATH`, `HKJ_KEY_PATH` - required TLS certificate and key
///
/// Panics when a required variable is missing, the server cannot be
/// created, or the certificate/key cannot be loaded.
fn main() {
    // Ignore the error from a logger that was already initialised.
    let log_env = env_logger::Env::default().default_filter_or("info");
    let _ = env_logger::Builder::from_env(log_env).try_init();

    let port = env::var("HKJ_PORT").unwrap_or_else(|_| String::from("443"));
    let username = env::var("HKJ_USERNAME").expect("HKJ_USERNAME is required");
    let password = env::var("HKJ_PASSWORD").expect("HKJ_PASSWORD is required");
    let resist_407 = env::var("HKJ_PROBE_RESISTANCE").map_or(true, |v| v != "0");
    let cert_path = env::var("HKJ_CERT_PATH").expect("HKJ_CERT_PATH is required");
    let key_path = env::var("HKJ_KEY_PATH").expect("HKJ_KEY_PATH is required");

    let mut server = Server::new(Some(Opt::parse_args())).unwrap();
    server.bootstrap();

    let proxy = ForwardProxy {
        username,
        password,
        resist_407,
    };

    // TLS only (no cleartext h2c), with CONNECT requests allowed through.
    let mut options = HttpServerOptions::default();
    options.h2c = false;
    options.allow_connect_method_proxying = true;

    let mut service = ProxyServiceBuilder::new(&server.configuration, proxy)
        .name("H2 Forward Proxy")
        .server_options(options)
        .build();

    let mut tls_settings =
        TlsSettings::intermediate(&cert_path, &key_path).expect("invalid TLS cert/key");
    tls_settings.enable_h2();

    let listen_addr = format!("0.0.0.0:{port}");
    service.add_tls_with_settings(&listen_addr, None, tls_settings);

    server.add_service(service);
    log::info!("HTTP/2 forward proxy listening on https://0.0.0.0:{port}");
    server.run_forever();
}