#!/usr/bin/python3 -u
'''
ssh http tunnel
version 1.1
(c) 2015-2016,2020 Jan ONDREJ (SAL)

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

Features:
  - simple script (one file only, no dependency, only basic python required)
  - works on any OS (Linux, Windows, Android, ...)
  - automatic client startup (using openssh ProxyCommand)
  - tunnel any TCP connection over HTTP

Server usage:
  sshtunnel.py server localhost:22 listen_port

Client usage (.ssh/config):
  Host test
    HostName yourservername
    ProxyCommand sshtunnel.py client http://server_url/

Client usage (TCP port):
  sshtunnel.py client http://server_url/ listen_port
  ssh -p listen_port localhost

Android client usage:
  - install QPython
  - create start_tunnel.py script with this content:
      from sshtunnel import *
      run_client("http://server_url", 2222)
  - connect with ssh client to login@localhost:2222

Environment variables:
  SSHPROXY_PPS - max packets per second, default 100 (timeout 0.01s)
'''

import sys, os, socket, select, time, threading, random, string
from urllib.request import urlopen
from urllib.error import HTTPError
from socketserver import TCPServer, BaseRequestHandler, ThreadingMixIn
from http.server import HTTPServer, BaseHTTPRequestHandler, HTTPStatus

BUF = 8192
RETRY_COUNT = 600
# Poll interval for every non-blocking read (1 / SSHPROXY_PPS seconds).
TIMEOUT = 1.0/float(os.environ.get("SSHPROXY_PPS", 100))
# A connection with no GET for this many seconds is dropped by the server.
DISCONNECT_TIMEOUT = TIMEOUT*RETRY_COUNT*10 # seconds
DISCONNECT_CODE = 301 # Moved
MAX_CONNECTIONS = 100

# Encoder and decoder applied to every payload; identity by default.
# Both ends must use the same pair (for example base64.b64encode /
# base64.b64decode).
encode = decode = lambda x: x


def randomstring(n=8):
    """Return a random string of *n* ASCII letters and digits.

    Used as the connection id handed out by GET /connect.
    """
    rng = random.SystemRandom()
    alphabet = string.ascii_letters + string.digits
    return ''.join(rng.choice(alphabet) for _ in range(n))


def debug(*args, **kw):
    """print() to stderr (stdout may be carrying tunnel data)."""
    print(*args, **kw, file=sys.stderr)


class STDIO(object):
    """Tunnel endpoint on stdin/stdout (used as ssh ProxyCommand).

    *stdin* / *stdout* default to sys.stdin / sys.stdout, resolved at call
    time.  *stdin* must provide fileno() usable with select.select().
    ``stop`` becomes True once stdin reaches end of file.
    """
    def __init__(self, stdin=None, stdout=None):
        self.stop = False
        self.stdin = stdin
        self.stdout = stdout

    def recv(self, max=BUF):
        """Return up to *max* bytes from stdin.

        Returns b'' if nothing arrived within TIMEOUT, or at end of file.
        At end of file it also sets ``stop``.
        """
        stream = sys.stdin if self.stdin is None else self.stdin
        rlist, _, _ = select.select([stream], [], [], TIMEOUT)
        if not rlist:
            return b''
        data = os.read(stream.fileno(), max)
        if not data:
            # os.read() returns b'' (never '') at EOF on Python 3; without
            # this the pump loop would spin forever on a closed stdin.
            self.stop = True
        return data

    def send(self, data):
        """Write raw bytes to stdout and flush immediately."""
        if not data:
            return
        stream = sys.stdout if self.stdout is None else self.stdout
        # Text streams expose the underlying binary stream as .buffer.
        getattr(stream, "buffer", stream).write(data)
        stream.flush()


class TCPClient(object):
    """Server-side connection to the tunnelled TCP service (e.g. sshd).

    ``stop`` is False while connected, True after the peer closed the
    connection, or an error message string after a socket error.
    """
    def __init__(self, hostport):
        self.hostport = hostport
        self.reset()

    def reset(self):
        """(Re)connect to ``hostport`` and clear the counters."""
        old = getattr(self, "s", None)
        if old is not None:
            old.close()
        self.stop = False
        # create_connection closes the socket itself if connect() fails.
        self.s = socket.create_connection(self.hostport)
        self.s.settimeout(TIMEOUT)
        self.bytes_in = 0
        self.bytes_out = 0

    def close(self):
        """Close the upstream socket."""
        self.s.close()

    def send(self, data):
        self.s.sendall(data)
        self.bytes_out += len(data)

    def recv(self, buf=BUF):
        """Return received bytes, or b'' on timeout, EOF or socket error."""
        try:
            data = self.s.recv(buf)
        except socket.timeout:
            return b''
        except OSError as err:
            self.stop = str(err)
            return b''
        if data:
            self.bytes_in += len(data)
        else:
            if self.stop is False:
                debug("Bytes transfered: in=%s, out=%s" \
                    % (self.bytes_in, self.bytes_out))
            self.stop = True
        return data


class TCPServer(object):
    """Client-side listener: accepts a single local TCP connection.

    Note: this deliberately shadows socketserver.TCPServer imported above.
    """
    def __init__(self, port):
        self.stop = False
        self.hostport = ("", port)
        self.s = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
        self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.s.bind(self.hostport)
        self.s.listen(0)
        # Blocks until one local client connects.
        self.c, self.addr = self.s.accept()

    def send(self, data):
        self.c.sendall(data)

    def recv(self, buf=BUF):
        """Return received bytes, or b'' on timeout (1 s), EOF or error."""
        try:
            self.c.settimeout(1)
            data = self.c.recv(buf)
        except socket.timeout:
            return b''
        except OSError as err:
            self.stop = str(err)
            return b''
        if not data:
            # no more data in socket: the local client disconnected
            self.stop = True
        return data


class HTTPServer(ThreadingMixIn, HTTPServer):
    """Threaded HTTP server (one thread per request)."""
    # Request threads may long-poll for seconds; don't let them keep the
    # interpreter alive on shutdown.
    daemon_threads = True


class HTTPHandler(BaseHTTPRequestHandler):
    """HTTP side of the tunnel server.

    Requests (only the first path component is interpreted):
      GET  /connect            -> open a TCPClient to ``client_addr``; the
                                  body is the new connection id
      GET  /<id>/...           -> long-poll up to ``retry_count`` reads; 200
                                  with upstream bytes, or DISCONNECT_CODE once
                                  the upstream connection has stopped
      POST /<id>/...           -> request body is sent upstream
    Subclasses must define ``client_addr`` (see run_server).  ``clients`` and
    ``connections`` are class-level dicts shared by all request threads.
    """
    retry_count = RETRY_COUNT
    clients = {}      # connection id -> TCPClient
    connections = {}  # connection id -> time.time() of the last GET

    def send_no_cache(self):
        self.send_header("Cache-Control", "no-cache, no-store, must-revalidate")
        self.send_header("Pragma", "no-cache")
        self.send_header("Expires", "0")

    def _drop(self, conn_id):
        """Forget connection *conn_id* and close its upstream socket."""
        self.connections.pop(conn_id, None)
        client = self.clients.pop(conn_id, None)
        if client is not None:
            client.close()

    def clean_connections(self):
        """Drop idle connections, then enforce MAX_CONNECTIONS (oldest first)."""
        # NOTE(review): the idle-expiry branch was lost in the source paste
        # and is reconstructed from DISCONNECT_TIMEOUT; verify against upstream.
        disconnect_time = time.time()-DISCONNECT_TIMEOUT
        # Iterate over snapshots: other request threads mutate these dicts.
        for conn_id, last in list(self.connections.items()):
            if last < disconnect_time:
                debug('Removing idle connection "%s"' % conn_id)
                self._drop(conn_id)
        while len(self.connections) > MAX_CONNECTIONS:
            # remove oldest
            oldest_id = min(list(self.connections.items()),
                            key=lambda item: item[1])[0]
            debug('Max connections reached, removing connection "%s"' \
                % oldest_id)
            self._drop(oldest_id)

    def do_GET(self):
        conn_id = self.path.lstrip("/").split("/")[0]
        data = b""
        self.clean_connections()
        if conn_id == 'connect':
            try:
                client = TCPClient(self.client_addr)
            except OSError as err:
                debug("Unable to connect to %s:%s:" % self.client_addr, err)
                self.send_response(503)  # Service Unavailable
            else:
                new_id = randomstring()
                self.clients[new_id] = client
                self.connections[new_id] = time.time()
                data = new_id.encode("ascii")
                self.send_response(200)
        else:
            # Single lookup: another thread may drop the id meanwhile.
            client = self.clients.get(conn_id)
            if client is None:
                self.send_response(404) # Not found
            else:
                self.connections[conn_id] = time.time()
                for _ in range(self.retry_count):
                    try:
                        data = client.recv()
                    except socket.timeout:
                        continue
                    except OSError as err:
                        client.stop = str(err)
                        break
                    if data or client.stop:
                        break
                data = encode(data or b"")
                if client.stop and not data:
                    # Signal the disconnect only once no data is pending.
                    self.send_response(DISCONNECT_CODE, str(client.stop))
                    self._drop(conn_id)
                else:
                    self.send_response(200)
        self.send_header("Content-length", len(data))
        self.send_no_cache()
        self.end_headers()
        if data:
            self.wfile.write(data)

    def do_POST(self):
        conn_id = self.path.lstrip("/").split("/")[0]
        try:
            data_len = int(self.headers.get('content-length', 0))
        except (TypeError, ValueError):
            data_len = 0
        data = decode(self.rfile.read(data_len)) if data_len > 0 else b""
        client = self.clients.get(conn_id)
        if client is not None and data:
            try:
                client.send(data)
            except OSError as err:
                # Reported to the peer as a disconnect on its next GET.
                client.stop = str(err)
        self.send_response(200)
        self.send_header("Content-length", 0)
        self.send_no_cache()
        self.end_headers()

    def handle_one_request(self):
        try:
            BaseHTTPRequestHandler.handle_one_request(self)
        except socket.error:
            err = sys.exc_info()[1]
            debug("handle_request:", err)

    def finish(self):
        try:
            BaseHTTPRequestHandler.finish(self)
        except socket.error:
            err = sys.exc_info()[1]
            debug("finish_request", err)

    def log_request(self, code='-', size='-'):
        """Log only non-200 responses (the long-poll traffic is noisy)."""
        if isinstance(code, HTTPStatus):
            code = code.value
        if str(code)!="200":
            self.log_message('"%s" %s %s',
                self.requestline, str(code), str(size))


class HTTPClient(object):
    """Client-side tunnel endpoint that talks to an HTTPHandler server.

    Construction performs GET <url>/connect to obtain a connection id.
    ``stop`` is set to True or to an error message when the tunnel ends.
    """
    def __init__(self, url):
        self.stop = False
        self.url = url.rstrip("/")
        self.data = None
        # reset connection
        with urlopen("%s/connect" % self.url) as resp:
            self.id = resp.read().decode()

    def send(self, data=b""):
        """POST *data* (bytes, or str encoded as UTF-8) to the server."""
        if isinstance(data, str):
            data = data.encode()
        url = "%s/%s/%d/%8.6f" % (self.url, self.id, len(data), time.time())
        try:
            with urlopen(url, encode(data)) as resp:  # POST data
                resp.read()
        except OSError as err:  # URLError/HTTPError are OSError subclasses
            debug("send failed: %s [%s/%s]" % (err, self.url, self.id))
            self.stop = str(err)

    def recv(self, **kw):
        """Long-poll the server; return bytes, or b'' if nothing arrived.

        Keyword arguments are passed to urlopen (e.g. timeout=1).
        """
        url = "%s/%s/%8.6f" % (self.url, self.id, time.time())
        try:
            with urlopen(url, **kw) as req:
                if req.code==DISCONNECT_CODE:
                    self.stop = True
                data = req.read() # GET
        except socket.timeout:
            return b''
        except HTTPError as err:
            if err.code!=DISCONNECT_CODE:
                debug("HTTPError: %s [%s/%s]" % (err, self.url, self.id))
            self.stop = str(err)
            return b''
        except OSError as err:
            # Server unreachable: stop instead of killing the pump thread.
            debug("Connection error: %s [%s/%s]" % (err, self.url, self.id))
            self.stop = str(err)
            return b''
        if data:
            return decode(data)
        return b''


# Single and multi-thread

class loop(object):
    """Callable that copies data from t1 to t2 until either side stops."""
    def __init__(self, t1, t2):
        self.t1 = t1
        self.t2 = t2

    def __call__(self):
        while not (self.t1.stop or self.t2.stop):
            data = self.t1.recv()
            if data:
                self.t2.send(data)


def multithread(server, client):
    """Pump server->client in a helper thread and client->server here."""
    # start server as separate thread
    s = threading.Thread(target=loop(server, client))
    s.start()
    try:
        # run client
        loop(client, server)()
    finally:
        # stop both, even if the loop above raised (e.g. KeyboardInterrupt)
        server.stop = True
        client.stop = True
        s.join()


def one_thread(server, client):
    """Single-threaded alternative to multithread() (not used by the CLI)."""
    while not (server.stop or client.stop):
        data = server.recv()
        if data:
            client.send(data)
        data = client.recv()
        if data:
            server.send(data)


def run_server(hostport, listen_port):
    """Serve HTTP on *listen_port*, tunnelling to *hostport* ("host:port")."""
    host, port = hostport.split(":", 1)
    port = int(port)
    debug("serving at port %d, connected to %s:%d" % (listen_port, host, port))

    class handler(HTTPHandler):
        client_addr = (host, port)

    try:
        HTTPServer(("", listen_port), handler).serve_forever()
    except KeyboardInterrupt:
        pass


def run_client(url, listen_port=None):
    """Tunnel stdin/stdout (no *listen_port*) or one local TCP client to *url*."""
    if listen_port is None:
        server = STDIO()
        client = HTTPClient(url)
    else:
        debug("serving at port %d" % listen_port)
        try:
            server = TCPServer(listen_port)
            debug("client connected", server.addr)
            client = HTTPClient(url)
        except KeyboardInterrupt:
            return
    try:
        multithread(server, client)
    except KeyboardInterrupt:
        pass
    if listen_port is not None:
        debug("client disconnect")


if __name__ == "__main__":
    if len(sys.argv)==4 and sys.argv[1] in ["server", "s"]:
        run_server(sys.argv[2], int(sys.argv[3]))
    elif len(sys.argv)==3 and sys.argv[1] in ["client", "c"]:
        run_client(sys.argv[2].rstrip("/"))
    elif len(sys.argv)==4 and sys.argv[1] in ["client", "c"]:
        run_client(sys.argv[2].rstrip("/"), int(sys.argv[3]))
    elif len(sys.argv)==3 and sys.argv[1] in ["test", "t"]:
        client = HTTPClient(sys.argv[2].rstrip("/"))
        for i in range(5):
            data = client.recv(timeout=1)
            if data:
                sys.stderr.write(data.decode(errors="replace"))
                client.send(b"TEST\n")
            else:
                break
    else:
        print(__doc__.lstrip())
        sys.exit()
    # Exit immediately without waiting for any remaining threads.
    os._exit(0)