Skip to content

Commit

Permalink
Merge branch 'gliderlabs:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
rawoul authored Dec 13, 2024
2 parents 81194fc + d137aad commit 79230d5
Show file tree
Hide file tree
Showing 12 changed files with 93 additions and 34 deletions.
3 changes: 1 addition & 2 deletions _examples/ssh-sftpserver/sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package main
import (
"fmt"
"io"
"io/ioutil"
"log"

"github.com/gliderlabs/ssh"
Expand All @@ -12,7 +11,7 @@ import (

// SftpHandler handler for SFTP subsystem
func SftpHandler(sess ssh.Session) {
debugStream := ioutil.Discard
debugStream := io.Discard
serverOptions := []sftp.ServerOption{
sftp.WithDebug(debugStream),
}
Expand Down
4 changes: 2 additions & 2 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package ssh

import (
"io"
"io/ioutil"
"net"
"os"
"path"
"sync"

Expand Down Expand Up @@ -36,7 +36,7 @@ func AgentRequested(sess Session) bool {
// NewAgentListener sets up a temporary Unix socket that can be communicated
// to the session environment and used for forwarding connections.
func NewAgentListener() (net.Listener, error) {
dir, err := ioutil.TempDir("", agentTempDir)
dir, err := os.MkdirTemp("", agentTempDir)
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions circle.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ jobs:
- run: go get
- run: go test -v -race

build-go-1.13:
build-go-1.20:
docker:
- image: golang:1.13
- image: golang:1.20
working_directory: /go/src/github.com/gliderlabs/ssh
steps:
- checkout
Expand All @@ -23,4 +23,4 @@ workflows:
build:
jobs:
- build-go-latest
- build-go-1.13
- build-go-1.20
18 changes: 16 additions & 2 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,14 @@ type Context interface {
type sshContext struct {
context.Context
*sync.Mutex

values map[interface{}]interface{}
valuesMu sync.Mutex
}

func newContext(srv *Server) (*sshContext, context.CancelFunc) {
innerCtx, cancel := context.WithCancel(context.Background())
ctx := &sshContext{innerCtx, &sync.Mutex{}}
ctx := &sshContext{Context: innerCtx, Mutex: &sync.Mutex{}, values: make(map[interface{}]interface{})}
ctx.SetValue(ContextKeyServer, srv)
perms := &Permissions{&gossh.Permissions{}}
ctx.SetValue(ContextKeyPermissions, perms)
Expand All @@ -119,8 +122,19 @@ func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) {
ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr())
}

func (ctx *sshContext) Value(key interface{}) interface{} {
ctx.valuesMu.Lock()
defer ctx.valuesMu.Unlock()
if v, ok := ctx.values[key]; ok {
return v
}
return ctx.Context.Value(key)
}

func (ctx *sshContext) SetValue(key, value interface{}) {
ctx.Context = context.WithValue(ctx.Context, key, value)
ctx.valuesMu.Lock()
defer ctx.valuesMu.Unlock()
ctx.values[key] = value
}

func (ctx *sshContext) User() string {
Expand Down
40 changes: 39 additions & 1 deletion context_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package ssh

import "testing"
import (
"testing"
"time"
)

func TestSetPermissions(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -45,3 +48,38 @@ func TestSetValue(t *testing.T) {
t.Fatal(err)
}
}

func TestSetValueConcurrency(t *testing.T) {
ctx, cancel := newContext(nil)
defer cancel()

go func() {
for { // use a loop to access context.Context functions to make sure they are thread-safe with SetValue
_, _ = ctx.Deadline()
_ = ctx.Err()
_ = ctx.Value("foo")
select {
case <-ctx.Done():
break
default:
}
}
}()
ctx.SetValue("bar", -1) // a context value which never changes
now := time.Now()
var cnt int64
go func() {
for time.Since(now) < 100*time.Millisecond {
cnt++
ctx.SetValue("foo", cnt) // a context value which changes a lot
}
cancel()
}()
<-ctx.Done()
if ctx.Value("foo") != cnt {
t.Fatal("context.Value(foo) doesn't match latest SetValue")
}
if ctx.Value("bar") != -1 {
t.Fatal("context.Value(bar) doesn't match latest SetValue")
}
}
4 changes: 2 additions & 2 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package ssh_test

import (
"io"
"io/ioutil"
"os"

"github.com/gliderlabs/ssh"
)
Expand All @@ -28,7 +28,7 @@ func ExampleNoPty() {
func ExamplePublicKeyAuth() {
ssh.ListenAndServe(":2222", nil,
ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
data, _ := ioutil.ReadFile("/path/to/allowed/key.pub")
data, _ := os.ReadFile("/path/to/allowed/key.pub")
allowed, _, _, _, _ := ssh.ParseAuthorizedKey(data)
return ssh.KeysEqual(key, allowed)
}),
Expand Down
7 changes: 4 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
module github.com/gliderlabs/ssh

go 1.12
go 1.20

require (
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be
golang.org/x/crypto v0.0.0-20220826181053-bd7e27e6170d
golang.org/x/term v0.5.0 // indirect
golang.org/x/crypto v0.31.0
)

require golang.org/x/sys v0.28.0 // indirect
18 changes: 5 additions & 13 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
golang.org/x/crypto v0.0.0-20220826181053-bd7e27e6170d h1:3qF+Z8Hkrw9sOhrFHti9TlB1Hkac1x+DNRkv0XQiFjo=
golang.org/x/crypto v0.0.0-20220826181053-bd7e27e6170d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
4 changes: 2 additions & 2 deletions options.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package ssh

import (
"io/ioutil"
"os"

gossh "golang.org/x/crypto/ssh"
)
Expand All @@ -26,7 +26,7 @@ func PublicKeyAuth(fn PublicKeyHandler) Option {
// from a PEM file at filepath.
func HostKeyFile(filepath string) Option {
return func(srv *Server) error {
pemBytes, err := ioutil.ReadFile(filepath)
pemBytes, err := os.ReadFile(filepath)
if err != nil {
return err
}
Expand Down
13 changes: 13 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ type Server struct {
Handler Handler // handler to invoke, ssh.DefaultHandler if nil
HostSigners []Signer // private keys for the host key, must have at least one
Version string // server version to be sent before the initial handshake
Banner string // server banner

BannerHandler BannerHandler // server banner handler, overrides Banner
KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler
PasswordHandler PasswordHandler // password authentication handler
PublicKeyHandler PublicKeyHandler // public key authentication handler
Expand Down Expand Up @@ -132,6 +134,17 @@ func (srv *Server) config(ctx Context) *gossh.ServerConfig {
if srv.Version != "" {
config.ServerVersion = "SSH-2.0-" + srv.Version
}
if srv.Banner != "" {
config.BannerCallback = func(_ gossh.ConnMetadata) string {
return srv.Banner
}
}
if srv.BannerHandler != nil {
config.BannerCallback = func(conn gossh.ConnMetadata) string {
applyConnMetadata(ctx, conn)
return srv.BannerHandler(ctx)
}
}
if srv.PasswordHandler != nil {
config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) {
applyConnMetadata(ctx, conn)
Expand Down
6 changes: 4 additions & 2 deletions ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ type Option func(*Server) error
// Handler is a callback for handling established SSH sessions.
type Handler func(Session)

// BannerHandler is a callback for displaying the server banner.
type BannerHandler func(ctx Context) string

// PublicKeyHandler is a callback for performing public key authentication.
type PublicKeyHandler func(ctx Context, key PublicKey) bool

Expand Down Expand Up @@ -115,8 +118,7 @@ func Handle(handler Handler) {

// KeysEqual is constant time compare of the keys to avoid timing attacks.
func KeysEqual(ak, bk PublicKey) bool {

//avoid panic if one of the keys is nil, return false instead
// avoid panic if one of the keys is nil, return false instead
if ak == nil || bk == nil {
return false
}
Expand Down
4 changes: 2 additions & 2 deletions tcpip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package ssh

import (
"bytes"
"io/ioutil"
"io"
"net"
"strconv"
"strings"
Expand Down Expand Up @@ -58,7 +58,7 @@ func TestLocalPortForwardingWorks(t *testing.T) {
if err != nil {
t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err)
}
result, err := ioutil.ReadAll(conn)
result, err := io.ReadAll(conn)
if err != nil {
t.Fatal(err)
}
Expand Down

0 comments on commit 79230d5

Please sign in to comment.