diff --git a/client/client_root_validation_test.go b/client/client_root_validation_test.go index 68612fb863..8aa3511d4c 100644 --- a/client/client_root_validation_test.go +++ b/client/client_root_validation_test.go @@ -36,7 +36,10 @@ const signedRSARootTemplate = `{"signed":{"_type":"Root","consistent_snapshot":f // We test this with both an RSA and ECDSA root key func TestValidateRoot(t *testing.T) { logrus.SetLevel(logrus.DebugLevel) - validateRootSuccessfully(t, data.RSAKey) + validateRootSuccessfully(t, data.ECDSAKey) + if !testing.Short() { + validateRootSuccessfully(t, data.RSAKey) + } } func validateRootSuccessfully(t *testing.T, rootType data.KeyAlgorithm) { diff --git a/cryptoservice/crypto_service.go b/cryptoservice/crypto_service.go index 061c0541d4..3a61cb95da 100644 --- a/cryptoservice/crypto_service.go +++ b/cryptoservice/crypto_service.go @@ -68,7 +68,7 @@ func (ccs *CryptoService) Create(role string, algorithm data.KeyAlgorithm) (data // GetKey returns a key by ID func (ccs *CryptoService) GetKey(keyID string) data.PublicKey { - key, err := ccs.keyStore.GetKey(keyID) + key, _, err := ccs.keyStore.GetKey(keyID) if err != nil { return nil } @@ -92,7 +92,7 @@ func (ccs *CryptoService) Sign(keyIDs []string, payload []byte) ([]data.Signatur var privKey data.PrivateKey var err error - privKey, err = ccs.keyStore.GetKey(keyName) + privKey, _, err = ccs.keyStore.GetKey(keyName) if err != nil { // Note that GetKey always fails on InitRepo. // InitRepo gets a signer that doesn't have access to diff --git a/keystoremanager/import_export.go b/keystoremanager/import_export.go index 3c42ab8263..bd87244681 100644 --- a/keystoremanager/import_export.go +++ b/keystoremanager/import_export.go @@ -81,12 +81,7 @@ func (km *KeyStoreManager) ImportRootKey(source io.Reader, keyID string) error { func moveKeys(oldKeyStore, newKeyStore *trustmanager.KeyFileStore) error { // List all files but no symlinks for _, f := range oldKeyStore.ListKeys() { - pemBytes, err := oldKeyStore.GetKey(f) - if err != nil { - return err - } - - alias, err := oldKeyStore.GetKeyAlias(f) + pemBytes, alias, err := oldKeyStore.GetKey(f) if err != nil { return err } @@ -259,12 +254,7 @@ func moveKeysByGUN(oldKeyStore, newKeyStore *trustmanager.KeyFileStore, gun stri continue } - privKey, err := oldKeyStore.GetKey(relKeyPath) - if err != nil { - return err - } - - alias, err := oldKeyStore.GetKeyAlias(relKeyPath) + privKey, alias, err := oldKeyStore.GetKey(relKeyPath) if err != nil { return err } diff --git a/keystoremanager/import_export_test.go b/keystoremanager/import_export_test.go index 47247f95c9..4a62995796 100644 --- a/keystoremanager/import_export_test.go +++ b/keystoremanager/import_export_test.go @@ -85,7 +85,7 @@ func TestImportExportZip(t *testing.T) { // because the passwords were chosen by the newPassphraseRetriever. privKeyList := repo.KeyStoreManager.NonRootKeyStore().ListKeys() for _, privKeyName := range privKeyList { - alias, err := repo.KeyStoreManager.NonRootKeyStore().GetKeyAlias(privKeyName) + _, alias, err := repo.KeyStoreManager.NonRootKeyStore().GetKey(privKeyName) assert.NoError(t, err, "privKey %s has no alias", privKeyName) relKeyPath := filepath.Join("private", "tuf_keys", privKeyName+"_"+alias+".key") @@ -156,7 +156,7 @@ func TestImportExportZip(t *testing.T) { // Look for keys in private. The filenames should match the key IDs // in the repo's private key store. for _, privKeyName := range privKeyList { - alias, err := repo.KeyStoreManager.NonRootKeyStore().GetKeyAlias(privKeyName) + _, alias, err := repo.KeyStoreManager.NonRootKeyStore().GetKey(privKeyName) assert.NoError(t, err, "privKey %s has no alias", privKeyName) relKeyPath := filepath.Join("private", "tuf_keys", privKeyName+"_"+alias+".key") @@ -221,7 +221,7 @@ func TestImportExportGUN(t *testing.T) { // because they were formerly unencrypted. privKeyList := repo.KeyStoreManager.NonRootKeyStore().ListKeys() for _, privKeyName := range privKeyList { - alias, err := repo.KeyStoreManager.NonRootKeyStore().GetKeyAlias(privKeyName) + _, alias, err := repo.KeyStoreManager.NonRootKeyStore().GetKey(privKeyName) if err != nil { t.Fatalf("privKey %s has no alias", privKeyName) } @@ -290,7 +290,7 @@ func TestImportExportGUN(t *testing.T) { // Look for keys in private. The filenames should match the key IDs // in the repo's private key store. for _, privKeyName := range privKeyList { - alias, err := repo.KeyStoreManager.NonRootKeyStore().GetKeyAlias(privKeyName) + _, alias, err := repo.KeyStoreManager.NonRootKeyStore().GetKey(privKeyName) if err != nil { t.Fatalf("privKey %s has no alias", privKeyName) } diff --git a/keystoremanager/keystoremanager.go b/keystoremanager/keystoremanager.go index 6f4c1d6b20..ecc4c1bcec 100644 --- a/keystoremanager/keystoremanager.go +++ b/keystoremanager/keystoremanager.go @@ -173,7 +173,7 @@ func (km *KeyStoreManager) GenRootKey(algorithm string) (string, error) { // GetRootCryptoService retrieves a root key and a cryptoservice to use with it // TODO(mccauley): remove this as its no longer needed once we have key caching in the keystores func (km *KeyStoreManager) GetRootCryptoService(rootKeyID string) (*cryptoservice.UnlockedCryptoService, error) { - privKey, err := km.rootKeyStore.GetKey(rootKeyID) + privKey, _, err := km.rootKeyStore.GetKey(rootKeyID) if err != nil { return nil, fmt.Errorf("could not get decrypted root key with keyID: %s, %v", rootKeyID, err) } diff --git a/trustmanager/keyfilestore.go b/trustmanager/keyfilestore.go index ae941088bd..f92396508b 100644 --- a/trustmanager/keyfilestore.go +++ b/trustmanager/keyfilestore.go @@ -3,6 +3,7 @@ package trustmanager import ( "path/filepath" "strings" + "sync" "errors" "fmt" @@ -20,22 +21,36 @@ type KeyStore interface { LimitedFileStore AddKey(name, alias string, privKey data.PrivateKey) error - GetKey(name string) (data.PrivateKey, error) - GetKeyAlias(name string) (string, error) + GetKey(name string) (data.PrivateKey, string, error) ListKeys() []string RemoveKey(name string) error } +type cachedKey struct { + alias string + key data.PrivateKey +} + +// PassphraseRetriever is a callback function that should retrieve a passphrase +// for a given named key. If it should be treated as new passphrase (e.g. with +// confirmation), createNew will be true. Attempts is passed in so that implementers +// decide how many chances to give to a human, for example. +type PassphraseRetriever func(keyId, alias string, createNew bool, attempts int) (passphrase string, giveup bool, err error) + // KeyFileStore persists and manages private keys on disk type KeyFileStore struct { + sync.Mutex SimpleFileStore - PassphraseRetriever passphrase.Retriever + passphrase.Retriever + cachedKeys map[string]*cachedKey } // KeyMemoryStore manages private keys in memory type KeyMemoryStore struct { + sync.Mutex MemoryFileStore - PassphraseRetriever passphrase.Retriever + passphrase.Retriever + cachedKeys map[string]*cachedKey } // NewKeyFileStore returns a new KeyFileStore creating a private directory to @@ -45,23 +60,25 @@ func NewKeyFileStore(baseDir string, passphraseRetriever passphrase.Retriever) ( if err != nil { return nil, err } + cachedKeys := make(map[string]*cachedKey) - return &KeyFileStore{*fileStore, passphraseRetriever}, nil + return &KeyFileStore{SimpleFileStore: *fileStore, + Retriever: passphraseRetriever, + cachedKeys: cachedKeys}, nil } // AddKey stores the contents of a PEM-encoded private key as a PEM block func (s *KeyFileStore) AddKey(name, alias string, privKey data.PrivateKey) error { - return addKey(s, s.PassphraseRetriever, name, alias, privKey) + s.Lock() + defer s.Unlock() + return addKey(s, s.Retriever, s.cachedKeys, name, alias, privKey) } // GetKey returns the PrivateKey given a KeyID -func (s *KeyFileStore) GetKey(name string) (data.PrivateKey, error) { - return getKey(s, s.PassphraseRetriever, name) -} - -// GetKeyAlias returns the PrivateKey's alias given a KeyID -func (s *KeyFileStore) GetKeyAlias(name string) (string, error) { - return getKeyAlias(s, name) +func (s *KeyFileStore) GetKey(name string) (data.PrivateKey, string, error) { + s.Lock() + defer s.Unlock() + return getKey(s, s.Retriever, s.cachedKeys, name) } // ListKeys returns a list of unique PublicKeys present on the KeyFileStore. @@ -73,29 +90,33 @@ func (s *KeyFileStore) ListKeys() []string { // RemoveKey removes the key from the keyfilestore func (s *KeyFileStore) RemoveKey(name string) error { - return removeKey(s, name) + s.Lock() + defer s.Unlock() + return removeKey(s, s.cachedKeys, name) } // NewKeyMemoryStore returns a new KeyMemoryStore which holds keys in memory func NewKeyMemoryStore(passphraseRetriever passphrase.Retriever) *KeyMemoryStore { memStore := NewMemoryFileStore() + cachedKeys := make(map[string]*cachedKey) - return &KeyMemoryStore{*memStore, passphraseRetriever} + return &KeyMemoryStore{MemoryFileStore: *memStore, + Retriever: passphraseRetriever, + cachedKeys: cachedKeys} } // AddKey stores the contents of a PEM-encoded private key as a PEM block func (s *KeyMemoryStore) AddKey(name, alias string, privKey data.PrivateKey) error { - return addKey(s, s.PassphraseRetriever, name, alias, privKey) + s.Lock() + defer s.Unlock() + return addKey(s, s.Retriever, s.cachedKeys, name, alias, privKey) } // GetKey returns the PrivateKey given a KeyID -func (s *KeyMemoryStore) GetKey(name string) (data.PrivateKey, error) { - return getKey(s, s.PassphraseRetriever, name) -} - -// GetKeyAlias returns the PrivateKey's alias given a KeyID -func (s *KeyMemoryStore) GetKeyAlias(name string) (string, error) { - return getKeyAlias(s, name) +func (s *KeyMemoryStore) GetKey(name string) (data.PrivateKey, string, error) { + s.Lock() + defer s.Unlock() + return getKey(s, s.Retriever, s.cachedKeys, name) } // ListKeys returns a list of unique PublicKeys present on the KeyFileStore. @@ -107,10 +128,12 @@ func (s *KeyMemoryStore) ListKeys() []string { // RemoveKey removes the key from the keystore func (s *KeyMemoryStore) RemoveKey(name string) error { - return removeKey(s, name) + s.Lock() + defer s.Unlock() + return removeKey(s, s.cachedKeys, name) } -func addKey(s LimitedFileStore, passphraseRetriever passphrase.Retriever, name, alias string, privKey data.PrivateKey) error { +func addKey(s LimitedFileStore, passphraseRetriever passphrase.Retriever, cachedKeys map[string]*cachedKey, name, alias string, privKey data.PrivateKey) error { pemPrivKey, err := KeyToPEM(privKey) if err != nil { return err @@ -141,6 +164,7 @@ func addKey(s LimitedFileStore, passphraseRetriever passphrase.Retriever, name, } } + cachedKeys[name] = &cachedKey{alias: alias, key: privKey} return s.Add(name+"_"+alias, pemPrivKey) } @@ -162,15 +186,19 @@ func getKeyAlias(s LimitedFileStore, keyID string) (string, error) { } // GetKey returns the PrivateKey given a KeyID -func getKey(s LimitedFileStore, passphraseRetriever passphrase.Retriever, name string) (data.PrivateKey, error) { +func getKey(s LimitedFileStore, passphraseRetriever passphrase.Retriever, cachedKeys map[string]*cachedKey, name string) (data.PrivateKey, string, error) { + cachedKeyEntry, ok := cachedKeys[name] + if ok { + return cachedKeyEntry.key, cachedKeyEntry.alias, nil + } keyAlias, err := getKeyAlias(s, name) if err != nil { - return nil, err + return nil, "", err } keyBytes, err := s.Get(name + "_" + keyAlias) if err != nil { - return nil, err + return nil, "", err } // See if the key is encrypted. If its encrypted we'll fail to parse the private key @@ -181,10 +209,10 @@ func getKey(s LimitedFileStore, passphraseRetriever passphrase.Retriever, name s passphrase, giveup, err := passphraseRetriever(name, string(keyAlias), false, attempts) // Check if the passphrase retriever got an error or if it is telling us to give up if giveup || err != nil { - return nil, errors.New("obtaining passphrase failed") + return nil, "", errors.New("obtaining passphrase failed") } if attempts > 10 { - return nil, errors.New("maximum number of passphrase attempts exceeded") + return nil, "", errors.New("maximum number of passphrase attempts exceeded") } // Try to convert PEM encoded bytes back to a PrivateKey using the passphrase @@ -195,7 +223,8 @@ func getKey(s LimitedFileStore, passphraseRetriever passphrase.Retriever, name s } } } - return privKey, nil + cachedKeys[name] = &cachedKey{alias: keyAlias, key: privKey} + return privKey, keyAlias, nil } // ListKeys returns a list of unique PublicKeys present on the KeyFileStore. @@ -213,11 +242,13 @@ func listKeys(s LimitedFileStore) []string { } // RemoveKey removes the key from the keyfilestore -func removeKey(s LimitedFileStore, name string) error { +func removeKey(s LimitedFileStore, cachedKeys map[string]*cachedKey, name string) error { keyAlias, err := getKeyAlias(s, name) if err != nil { return err } + delete(cachedKeys, name) + return s.Remove(name + "_" + keyAlias) } diff --git a/trustmanager/keyfilestore_test.go b/trustmanager/keyfilestore_test.go index a04f026183..1205282c1c 100644 --- a/trustmanager/keyfilestore_test.go +++ b/trustmanager/keyfilestore_test.go @@ -1,14 +1,15 @@ package trustmanager import ( - "bytes" "crypto/rand" "errors" "io/ioutil" "os" "path/filepath" - "strings" "testing" + + "github.com/docker/notary/pkg/passphrase" + "github.com/stretchr/testify/assert" ) var passphraseRetriever = func(keyID string, alias string, createNew bool, numAttempts int) (string, bool, error) { @@ -26,9 +27,7 @@ func TestAddKey(t *testing.T) { // Temporary directory where test files will be created tempBaseDir, err := ioutil.TempDir("", "notary-test-") - if err != nil { - t.Fatalf("failed to create a temporary directory: %v", err) - } + assert.NoError(t, err, "failed to create a temporary directory") defer os.RemoveAll(tempBaseDir) // Since we're generating this manually we need to add the extension '.' @@ -36,30 +35,19 @@ func TestAddKey(t *testing.T) { // Create our store store, err := NewKeyFileStore(tempBaseDir, passphraseRetriever) - if err != nil { - t.Fatalf("failed to create new key filestore: %v", err) - } + assert.NoError(t, err, "failed to create new key filestore") - privKey, err := GenerateRSAKey(rand.Reader, 512) - if err != nil { - t.Fatalf("could not generate private key: %v", err) - } + privKey, err := GenerateECDSAKey(rand.Reader) + assert.NoError(t, err, "could not generate private key") // Call the AddKey function err = store.AddKey(testName, "root", privKey) - if err != nil { - t.Fatalf("failed to add file to store: %v", err) - } + assert.NoError(t, err, "failed to add key to store") // Check to see if file exists b, err := ioutil.ReadFile(expectedFilePath) - if err != nil { - t.Fatalf("expected file not found: %v", err) - } - - if !strings.Contains(string(b), "-----BEGIN RSA PRIVATE KEY-----") { - t.Fatalf("expected private key content in the file: %s", expectedFilePath) - } + assert.NoError(t, err, "expected file not found") + assert.Contains(t, string(b), "-----BEGIN EC PRIVATE KEY-----") } func TestGet(t *testing.T) { @@ -100,39 +88,27 @@ EMl3eFOJXjIch/wIesRSN+2dGOsl7neercjMh1i9RvpCwHDx/E0= // Temporary directory where test files will be created tempBaseDir, err := ioutil.TempDir("", "notary-test-") - if err != nil { - t.Fatalf("failed to create a temporary directory: %v", err) - } + assert.NoError(t, err, "failed to create a temporary directory") defer os.RemoveAll(tempBaseDir) // Since we're generating this manually we need to add the extension '.' filePath := filepath.Join(tempBaseDir, testName+"_"+testAlias+"."+testExt) os.MkdirAll(filepath.Dir(filePath), perms) - if err = ioutil.WriteFile(filePath, testData, perms); err != nil { - t.Fatalf("Failed to write test file: %v", err) - } + err = ioutil.WriteFile(filePath, testData, perms) + assert.NoError(t, err, "failed to write test file") // Create our store store, err := NewKeyFileStore(tempBaseDir, emptyPassphraseRetriever) - if err != nil { - t.Fatalf("failed to create new key filestore: %v", err) - } + assert.NoError(t, err, "failed to create new key filestore") // Call the GetKey function - privKey, err := store.GetKey(testName) - if err != nil { - t.Fatalf("failed to get file from store: %v", err) - } + privKey, _, err := store.GetKey(testName) + assert.NoError(t, err, "failed to get key from store") pemPrivKey, err := KeyToPEM(privKey) - if err != nil { - t.Fatalf("failed to convert key to PEM: %v", err) - } - - if !bytes.Equal(testData, pemPrivKey) { - t.Fatalf("unexpected content in the file: %s", filePath) - } + assert.NoError(t, err, "failed to convert key to PEM") + assert.Equal(t, testData, pemPrivKey) } func TestAddGetKeyMemStore(t *testing.T) { @@ -142,37 +118,20 @@ func TestAddGetKeyMemStore(t *testing.T) { // Create our store store := NewKeyMemoryStore(passphraseRetriever) - privKey, err := GenerateRSAKey(rand.Reader, 512) - if err != nil { - t.Fatalf("could not generate private key: %v", err) - } + privKey, err := GenerateECDSAKey(rand.Reader) + assert.NoError(t, err, "could not generate private key") // Call the AddKey function err = store.AddKey(testName, testAlias, privKey) - if err != nil { - t.Fatalf("failed to add file to store: %v", err) - } + assert.NoError(t, err, "failed to add key to store") // Check to see if file exists - retrievedKey, err := store.GetKey(testName) - if err != nil { - t.Fatalf("failed to get key from store: %v", err) - } + retrievedKey, retrievedAlias, err := store.GetKey(testName) + assert.NoError(t, err, "failed to get key from store") - // Check to see if alias exists - retrievedAlias, err := store.GetKeyAlias(testName) - if err != nil { - t.Fatalf("failed to get key from store: %v", err) - } - - if retrievedAlias != testAlias { - t.Fatalf("retrievedAlias differs getAlias") - } - - if !bytes.Equal(retrievedKey.Public(), privKey.Public()) || - !bytes.Equal(retrievedKey.Private(), privKey.Private()) { - t.Fatalf("key contents differs after add/get") - } + assert.Equal(t, retrievedAlias, testAlias) + assert.Equal(t, retrievedKey.Public(), privKey.Public()) + assert.Equal(t, retrievedKey.Private(), privKey.Private()) } func TestGetDecryptedWithTamperedCipherText(t *testing.T) { testExt := "key" @@ -180,46 +139,38 @@ func TestGetDecryptedWithTamperedCipherText(t *testing.T) { // Temporary directory where test files will be created tempBaseDir, err := ioutil.TempDir("", "notary-test-") - if err != nil { - t.Fatalf("failed to create a temporary directory: %v", err) - } + assert.NoError(t, err, "failed to create a temporary directory") defer os.RemoveAll(tempBaseDir) // Create our FileStore store, err := NewKeyFileStore(tempBaseDir, passphraseRetriever) - if err != nil { - t.Fatalf("failed to create new key filestore: %v", err) - } + assert.NoError(t, err, "failed to create new key filestore") // Generate a new Private Key - privKey, err := GenerateRSAKey(rand.Reader, 512) - if err != nil { - t.Fatalf("could not generate private key: %v", err) - } + privKey, err := GenerateECDSAKey(rand.Reader) + assert.NoError(t, err, "could not generate private key") // Call the AddEncryptedKey function err = store.AddKey(privKey.ID(), testAlias, privKey) - if err != nil { - t.Fatalf("failed to add file to store: %v", err) - } + assert.NoError(t, err, "failed to add key to store") // Since we're generating this manually we need to add the extension '.' expectedFilePath := filepath.Join(tempBaseDir, privKey.ID()+"_"+testAlias+"."+testExt) // Get file description, open file fp, err := os.OpenFile(expectedFilePath, os.O_WRONLY, 0600) - if err != nil { - t.Fatalf("expected file not found: %v", err) - } + assert.NoError(t, err, "expected file not found") // Tamper the file fp.WriteAt([]byte("a"), int64(1)) + // Recreate the KeyFileStore to avoid caching + store, err = NewKeyFileStore(tempBaseDir, passphraseRetriever) + assert.NoError(t, err, "failed to create new key filestore") + // Try to decrypt the file - _, err = store.GetKey(privKey.ID()) - if err == nil { - t.Fatalf("expected error while decrypting the content due to invalid cipher text") - } + _, _, err = store.GetKey(privKey.ID()) + assert.Error(t, err, "expected error while decrypting the content due to invalid cipher text") } func TestGetDecryptedWithInvalidPassphrase(t *testing.T) { @@ -238,26 +189,20 @@ func TestGetDecryptedWithInvalidPassphrase(t *testing.T) { // Temporary directory where test files will be created tempBaseDir, err := ioutil.TempDir("", "notary-test-") - if err != nil { - t.Fatalf("failed to create a temporary directory: %v", err) - } + assert.NoError(t, err, "failed to create a temporary directory") defer os.RemoveAll(tempBaseDir) // Test with KeyFileStore fileStore, err := NewKeyFileStore(tempBaseDir, invalidPassphraseRetriever) - if err != nil { - t.Fatalf("failed to create new key filestore: %v", err) - } + assert.NoError(t, err, "failed to create new key filestore") - testGetDecryptedWithInvalidPassphrase(t, fileStore) + newFileStore, err := NewKeyFileStore(tempBaseDir, invalidPassphraseRetriever) + assert.NoError(t, err, "failed to create new key filestore") - // Test with KeyMemoryStore - memStore := NewKeyMemoryStore(invalidPassphraseRetriever) - if err != nil { - t.Fatalf("failed to create new key memorystore: %v", err) - } - testGetDecryptedWithInvalidPassphrase(t, memStore) + testGetDecryptedWithInvalidPassphrase(t, fileStore, newFileStore) + // Can't test with KeyMemoryStore because we cache the decrypted version of + // the key forever } func TestGetDecryptedWithConsistentlyInvalidPassphrase(t *testing.T) { @@ -271,47 +216,38 @@ func TestGetDecryptedWithConsistentlyInvalidPassphrase(t *testing.T) { // Temporary directory where test files will be created tempBaseDir, err := ioutil.TempDir("", "notary-test-") - if err != nil { - t.Fatalf("failed to create a temporary directory: %v", err) - } + assert.NoError(t, err, "failed to create a temporary directory") defer os.RemoveAll(tempBaseDir) // Test with KeyFileStore fileStore, err := NewKeyFileStore(tempBaseDir, consistentlyInvalidPassphraseRetriever) - if err != nil { - t.Fatalf("failed to create new key filestore: %v", err) - } + assert.NoError(t, err, "failed to create new key filestore") - testGetDecryptedWithInvalidPassphrase(t, fileStore) + newFileStore, err := NewKeyFileStore(tempBaseDir, consistentlyInvalidPassphraseRetriever) + assert.NoError(t, err, "failed to create new key filestore") - // Test with KeyMemoryStore - memStore := NewKeyMemoryStore(consistentlyInvalidPassphraseRetriever) - if err != nil { - t.Fatalf("failed to create new key memorystore: %v", err) - } - testGetDecryptedWithInvalidPassphrase(t, memStore) + testGetDecryptedWithInvalidPassphrase(t, fileStore, newFileStore) + + // Can't test with KeyMemoryStore because we cache the decrypted version of + // the key forever } -func testGetDecryptedWithInvalidPassphrase(t *testing.T, store KeyStore) { +// testGetDecryptedWithInvalidPassphrase takes two keystores so it can add to +// one and get from the other (to work around caching) +func testGetDecryptedWithInvalidPassphrase(t *testing.T, store KeyStore, newStore KeyStore) { testAlias := "root" // Generate a new random RSA Key - privKey, err := GenerateRSAKey(rand.Reader, 512) - if err != nil { - t.Fatalf("could not generate private key: %v", err) - } + privKey, err := GenerateECDSAKey(rand.Reader) + assert.NoError(t, err, "could not generate private key") // Call the AddKey function err = store.AddKey(privKey.ID(), testAlias, privKey) - if err != nil { - t.Fatalf("failed to add file to store: %v", err) - } + assert.NoError(t, err, "failed to add key to store") // Try to decrypt the file with an invalid passphrase - _, err = store.GetKey(privKey.ID()) - if err == nil { - t.Fatalf("expected error while decrypting the content due to invalid passphrase") - } + _, _, err = newStore.GetKey(privKey.ID()) + assert.Error(t, err, "expected error while decrypting the content due to invalid passphrase") } func TestRemoveKey(t *testing.T) { @@ -321,9 +257,7 @@ func TestRemoveKey(t *testing.T) { // Temporary directory where test files will be created tempBaseDir, err := ioutil.TempDir("", "notary-test-") - if err != nil { - t.Fatalf("failed to create a temporary directory: %v", err) - } + assert.NoError(t, err, "failed to create a temporary directory") defer os.RemoveAll(tempBaseDir) // Since we're generating this manually we need to add the extension '.' @@ -331,36 +265,82 @@ func TestRemoveKey(t *testing.T) { // Create our store store, err := NewKeyFileStore(tempBaseDir, passphraseRetriever) - if err != nil { - t.Fatalf("failed to create new key filestore: %v", err) - } + assert.NoError(t, err, "failed to create new key filestore") - privKey, err := GenerateRSAKey(rand.Reader, 512) - if err != nil { - t.Fatalf("could not generate private key: %v", err) - } + privKey, err := GenerateECDSAKey(rand.Reader) + assert.NoError(t, err, "could not generate private key") // Call the AddKey function err = store.AddKey(testName, testAlias, privKey) - if err != nil { - t.Fatalf("failed to add file to store: %v", err) - } + assert.NoError(t, err, "failed to add key to store") // Check to see if file exists _, err = ioutil.ReadFile(expectedFilePath) - if err != nil { - t.Fatalf("expected file not found: %v", err) - } + assert.NoError(t, err, "expected file not found") // Call remove key err = store.RemoveKey(testName) - if err != nil { - t.Fatalf("unable to remove key: %v", err) - } + assert.NoError(t, err, "unable to remove key") // Check to see if file still exists _, err = ioutil.ReadFile(expectedFilePath) - if err == nil { - t.Fatalf("file should not exist %s", expectedFilePath) - } + assert.Error(t, err, "file should not exist") +} + +func TestKeysAreCached(t *testing.T) { + testName := "docker.com/notary/root" + testAlias := "alias" + + // Temporary directory where test files will be created + tempBaseDir, err := ioutil.TempDir("", "notary-test-") + assert.NoError(t, err, "failed to create a temporary directory") + defer os.RemoveAll(tempBaseDir) + + var countingPassphraseRetriever passphrase.Retriever + + numTimesCalled := 0 + countingPassphraseRetriever = func(keyId, alias string, createNew bool, attempts int) (passphrase string, giveup bool, err error) { + numTimesCalled++ + return "password", false, nil + } + + // Create our store + store, err := NewKeyFileStore(tempBaseDir, countingPassphraseRetriever) + assert.NoError(t, err, "failed to create new key filestore") + + privKey, err := GenerateECDSAKey(rand.Reader) + assert.NoError(t, err, "could not generate private key") + + // Call the AddKey function + err = store.AddKey(testName, testAlias, privKey) + assert.NoError(t, err, "failed to add key to store") + + assert.Equal(t, 1, numTimesCalled, "numTimesCalled should have been 1") + + // Call the AddKey function + privKey2, _, err := store.GetKey(testName) + assert.NoError(t, err, "failed to add key to store") + + assert.Equal(t, privKey.Public(), privKey2.Public(), "cachedPrivKey should be the same as the added privKey") + assert.Equal(t, privKey.Private(), privKey2.Private(), "cachedPrivKey should be the same as the added privKey") + assert.Equal(t, 1, numTimesCalled, "numTimesCalled should be 1 -- no additional call to passphraseRetriever") + + // Create a new store + store2, err := NewKeyFileStore(tempBaseDir, countingPassphraseRetriever) + assert.NoError(t, err, "failed to create new key filestore") + + // Call the GetKey function + privKey3, _, err := store2.GetKey(testName) + assert.NoError(t, err, "failed to get key from store") + + assert.Equal(t, privKey2.Private(), privKey3.Private(), "privkey from store1 should be the same as privkey from store2") + assert.Equal(t, privKey2.Public(), privKey3.Public(), "privkey from store1 should be the same as privkey from store2") + assert.Equal(t, 2, numTimesCalled, "numTimesCalled should be 2 -- one additional call to passphraseRetriever") + + // Call the GetKey function a bunch of times + for i := 0; i < 10; i++ { + _, _, err := store2.GetKey(testName) + assert.NoError(t, err, "failed to get key from store") + } + assert.Equal(t, 2, numTimesCalled, "numTimesCalled should be 2 -- no additional call to passphraseRetriever") }