diff --git a/daemon/daemon.go b/daemon/daemon.go index a3c307ddad..6ccf8aaf2a 100644 --- a/daemon/daemon.go +++ b/daemon/daemon.go @@ -38,6 +38,7 @@ import ( "github.com/docker/docker/pkg/sysinfo" "github.com/docker/docker/pkg/truncindex" "github.com/docker/docker/runconfig" + "github.com/docker/docker/trust" "github.com/docker/docker/utils" "github.com/docker/docker/volumes" ) @@ -98,6 +99,7 @@ type Daemon struct { containerGraph *graphdb.Database driver graphdriver.Driver execDriver execdriver.Driver + trustStore *trust.TrustStore } // Install installs daemon capabilities to eng. @@ -136,6 +138,9 @@ func (daemon *Daemon) Install(eng *engine.Engine) error { if err := daemon.Repositories().Install(eng); err != nil { return err } + if err := daemon.trustStore.Install(eng); err != nil { + return err + } // FIXME: this hack is necessary for legacy integration tests to access // the daemon object. eng.Hack_SetGlobalVar("httpapi.daemon", daemon) @@ -835,6 +840,15 @@ func NewDaemonFromDirectory(config *Config, eng *engine.Engine) (*Daemon, error) return nil, fmt.Errorf("Couldn't create Tag store: %s", err) } + trustDir := path.Join(config.Root, "trust") + if err := os.MkdirAll(trustDir, 0700); err != nil && !os.IsExist(err) { + return nil, err + } + t, err := trust.NewTrustStore(trustDir) + if err != nil { + return nil, fmt.Errorf("could not create trust store: %s", err) + } + if !config.DisableNetwork { job := eng.Job("init_networkdriver") @@ -899,6 +913,7 @@ func NewDaemonFromDirectory(config *Config, eng *engine.Engine) (*Daemon, error) sysInitPath: sysInitPath, execDriver: ed, eng: eng, + trustStore: t, } if err := daemon.checkLocaldns(); err != nil { return nil, err diff --git a/graph/pull.go b/graph/pull.go index fb1aff7ac0..d419b4f7ca 100644 --- a/graph/pull.go +++ b/graph/pull.go @@ -1,10 +1,14 @@ package graph import ( + "bytes" + "encoding/json" "fmt" "io" + "io/ioutil" "net" "net/url" + "os" "strings" "time" @@ -13,8 +17,60 @@ import ( "github.com/docker/docker/pkg/log" "github.com/docker/docker/registry" "github.com/docker/docker/utils" + "github.com/docker/libtrust" ) +func (s *TagStore) verifyManifest(eng *engine.Engine, manifestBytes []byte) (*registry.ManifestData, bool, error) { + sig, err := libtrust.ParsePrettySignature(manifestBytes, "signatures") + if err != nil { + return nil, false, fmt.Errorf("error parsing payload: %s", err) + } + keys, err := sig.Verify() + if err != nil { + return nil, false, fmt.Errorf("error verifying payload: %s", err) + } + + payload, err := sig.Payload() + if err != nil { + return nil, false, fmt.Errorf("error retrieving payload: %s", err) + } + + var manifest registry.ManifestData + if err := json.Unmarshal(payload, &manifest); err != nil { + return nil, false, fmt.Errorf("error unmarshalling manifest: %s", err) + } + + var verified bool + for _, key := range keys { + job := eng.Job("trust_key_check") + b, err := key.MarshalJSON() + if err != nil { + return nil, false, fmt.Errorf("error marshalling public key: %s", err) + } + namespace := manifest.Name + if namespace[0] != '/' { + namespace = "/" + namespace + } + stdoutBuffer := bytes.NewBuffer(nil) + + job.Args = append(job.Args, namespace) + job.Setenv("PublicKey", string(b)) + // Check key has read/write permission (0x03) + job.SetenvInt("Permission", 0x03) + job.Stdout.Add(stdoutBuffer) + if err = job.Run(); err != nil { + return nil, false, fmt.Errorf("error running key check: %s", err) + } + result := engine.Tail(stdoutBuffer, 1) + log.Debugf("Key check result: %q", result) + if result == "verified" { + verified = true + } + } + + return &manifest, verified, nil +} + func (s *TagStore) CmdPull(job *engine.Job) engine.Status { if n := len(job.Args); n != 1 && n != 2 { return job.Errorf("Usage: %s IMAGE [TAG]", job.Name) @@ -52,7 +108,7 @@ func (s *TagStore) CmdPull(job *engine.Job) engine.Status { return job.Error(err) } - endpoint, err := registry.ExpandAndVerifyRegistryUrl(hostname) + endpoint, err := registry.NewEndpoint(hostname) if err != nil { return job.Error(err) } @@ -62,14 +118,32 @@ func (s *TagStore) CmdPull(job *engine.Job) engine.Status { return job.Error(err) } - if endpoint == registry.IndexServerAddress() { + var isOfficial bool + if endpoint.VersionString(1) == registry.IndexServerAddress() { // If pull "index.docker.io/foo/bar", it's stored locally under "foo/bar" localName = remoteName + isOfficial = isOfficialName(remoteName) + if isOfficial && strings.IndexRune(remoteName, '/') == -1 { + remoteName = "library/" + remoteName + } + // Use provided mirrors, if any mirrors = s.mirrors } + if isOfficial || endpoint.Version == registry.APIVersion2 { + j := job.Eng.Job("trust_update_base") + if err = j.Run(); err != nil { + return job.Errorf("error updating trust base graph: %s", err) + } + + if err := s.pullV2Repository(job.Eng, r, job.Stdout, localName, remoteName, tag, sf, job.GetenvBool("parallel")); err == nil { + return engine.StatusOK + } else if err != registry.ErrDoesNotExist { + log.Errorf("Error from V2 registry: %s", err) + } + } if err = s.pullRepository(r, job.Stdout, localName, remoteName, tag, sf, job.GetenvBool("parallel"), mirrors); err != nil { return job.Error(err) } @@ -337,3 +411,169 @@ func WriteStatus(requestedTag string, out io.Writer, sf *utils.StreamFormatter, out.Write(sf.FormatStatus("", "Status: Image is up to date for %s", requestedTag)) } } + +// downloadInfo is used to pass information from download to extractor +type downloadInfo struct { + imgJSON []byte + img *image.Image + tmpFile *os.File + length int64 + downloaded bool + err chan error +} + +func (s *TagStore) pullV2Repository(eng *engine.Engine, r *registry.Session, out io.Writer, localName, remoteName, tag string, sf *utils.StreamFormatter, parallel bool) error { + if tag == "" { + log.Debugf("Pulling tag list from V2 registry for %s", remoteName) + tags, err := r.GetV2RemoteTags(remoteName, nil) + if err != nil { + return err + } + for _, t := range tags { + if err := s.pullV2Tag(eng, r, out, localName, remoteName, t, sf, parallel); err != nil { + return err + } + } + } else { + if err := s.pullV2Tag(eng, r, out, localName, remoteName, tag, sf, parallel); err != nil { + return err + } + } + + return nil +} + +func (s *TagStore) pullV2Tag(eng *engine.Engine, r *registry.Session, out io.Writer, localName, remoteName, tag string, sf *utils.StreamFormatter, parallel bool) error { + log.Debugf("Pulling tag from V2 registry: %q", tag) + manifestBytes, err := r.GetV2ImageManifest(remoteName, tag, nil) + if err != nil { + return err + } + + manifest, verified, err := s.verifyManifest(eng, manifestBytes) + if err != nil { + return fmt.Errorf("error verifying manifest: %s", err) + } + + if len(manifest.BlobSums) != len(manifest.History) { + return fmt.Errorf("length of history not equal to number of layers") + } + + if verified { + out.Write(sf.FormatStatus("", "The image you are pulling has been digitally signed by Docker, Inc.")) + } + out.Write(sf.FormatStatus(tag, "Pulling from %s", localName)) + + downloads := make([]downloadInfo, len(manifest.BlobSums)) + + for i := len(manifest.BlobSums) - 1; i >= 0; i-- { + var ( + sumStr = manifest.BlobSums[i] + imgJSON = []byte(manifest.History[i]) + ) + + img, err := image.NewImgJSON(imgJSON) + if err != nil { + return fmt.Errorf("failed to parse json: %s", err) + } + downloads[i].img = img + + // Check if exists + if s.graph.Exists(img.ID) { + log.Debugf("Image already exists: %s", img.ID) + continue + } + + chunks := strings.SplitN(sumStr, ":", 2) + if len(chunks) < 2 { + return fmt.Errorf("expected 2 parts in the sumStr, got %#v", chunks) + } + sumType, checksum := chunks[0], chunks[1] + out.Write(sf.FormatProgress(utils.TruncateID(img.ID), "Pulling fs layer", nil)) + + downloadFunc := func(di *downloadInfo) error { + log.Infof("pulling blob %q to V1 img %s", sumStr, img.ID) + + if c, err := s.poolAdd("pull", "img:"+img.ID); err != nil { + if c != nil { + out.Write(sf.FormatProgress(utils.TruncateID(img.ID), "Layer already being pulled by another client. Waiting.", nil)) + <-c + out.Write(sf.FormatProgress(utils.TruncateID(img.ID), "Download complete", nil)) + } else { + log.Debugf("Image (id: %s) pull is already running, skipping: %v", img.ID, err) + } + } else { + tmpFile, err := ioutil.TempFile("", "GetV2ImageBlob") + if err != nil { + return err + } + + r, l, err := r.GetV2ImageBlobReader(remoteName, sumType, checksum, nil) + if err != nil { + return err + } + defer r.Close() + io.Copy(tmpFile, utils.ProgressReader(r, int(l), out, sf, false, utils.TruncateID(img.ID), "Downloading")) + + out.Write(sf.FormatProgress(utils.TruncateID(img.ID), "Download complete", nil)) + + log.Debugf("Downloaded %s to tempfile %s", img.ID, tmpFile.Name()) + di.tmpFile = tmpFile + di.length = l + di.downloaded = true + } + di.imgJSON = imgJSON + defer s.poolRemove("pull", "img:"+img.ID) + + return nil + } + + if parallel { + downloads[i].err = make(chan error) + go func(di *downloadInfo) { + di.err <- downloadFunc(di) + }(&downloads[i]) + } else { + err := downloadFunc(&downloads[i]) + if err != nil { + return err + } + } + } + + for i := len(downloads) - 1; i >= 0; i-- { + d := &downloads[i] + if d.err != nil { + err := <-d.err + if err != nil { + return err + } + } + if d.downloaded { + // if tmpFile is empty assume download and extracted elsewhere + defer os.Remove(d.tmpFile.Name()) + defer d.tmpFile.Close() + d.tmpFile.Seek(0, 0) + if d.tmpFile != nil { + err = s.graph.Register(d.img, d.imgJSON, + utils.ProgressReader(d.tmpFile, int(d.length), out, sf, false, utils.TruncateID(d.img.ID), "Extracting")) + if err != nil { + return err + } + + // FIXME: Pool release here for parallel tag pull (ensures any downloads block until fully extracted) + } + out.Write(sf.FormatProgress(utils.TruncateID(d.img.ID), "Pull complete", nil)) + + } else { + out.Write(sf.FormatProgress(utils.TruncateID(d.img.ID), "Already exists", nil)) + } + + } + + if err = s.Set(localName, tag, downloads[0].img.ID, true); err != nil { + return err + } + + return nil +} diff --git a/graph/push.go b/graph/push.go index 657e147d10..3511245b30 100644 --- a/graph/push.go +++ b/graph/push.go @@ -214,7 +214,7 @@ func (s *TagStore) CmdPush(job *engine.Job) engine.Status { return job.Error(err) } - endpoint, err := registry.ExpandAndVerifyRegistryUrl(hostname) + endpoint, err := registry.NewEndpoint(hostname) if err != nil { return job.Error(err) } @@ -243,7 +243,7 @@ func (s *TagStore) CmdPush(job *engine.Job) engine.Status { var token []string job.Stdout.Write(sf.FormatStatus("", "The push refers to an image: [%s]", localName)) - if _, err := s.pushImage(r, job.Stdout, remoteName, img.ID, endpoint, token, sf); err != nil { + if _, err := s.pushImage(r, job.Stdout, remoteName, img.ID, endpoint.String(), token, sf); err != nil { return job.Error(err) } return engine.StatusOK diff --git a/graph/tags.go b/graph/tags.go index 6643ebceac..150bdb9640 100644 --- a/graph/tags.go +++ b/graph/tags.go @@ -276,6 +276,20 @@ func (store *TagStore) GetRepoRefs() map[string][]string { return reporefs } +// isOfficialName returns whether a repo name is considered an official +// repository. Official repositories are repos with names within +// the library namespace or which default to the library namespace +// by not providing one. +func isOfficialName(name string) bool { + if strings.HasPrefix(name, "library/") { + return true + } + if strings.IndexRune(name, '/') == -1 { + return true + } + return false +} + // Validate the name of a repository func validateRepoName(name string) error { if name == "" { diff --git a/graph/tags_unit_test.go b/graph/tags_unit_test.go index d3111cfd65..e4f1fb809f 100644 --- a/graph/tags_unit_test.go +++ b/graph/tags_unit_test.go @@ -2,15 +2,16 @@ package graph import ( "bytes" + "io" + "os" + "path" + "testing" + "github.com/docker/docker/daemon/graphdriver" _ "github.com/docker/docker/daemon/graphdriver/vfs" // import the vfs driver so it is used in the tests "github.com/docker/docker/image" "github.com/docker/docker/utils" "github.com/docker/docker/vendor/src/code.google.com/p/go/src/pkg/archive/tar" - "io" - "os" - "path" - "testing" ) const ( @@ -132,3 +133,18 @@ func TestInvalidTagName(t *testing.T) { } } } + +func TestOfficialName(t *testing.T) { + names := map[string]bool{ + "library/ubuntu": true, + "nonlibrary/ubuntu": false, + "ubuntu": true, + "other/library": false, + } + for name, isOfficial := range names { + result := isOfficialName(name) + if result != isOfficial { + t.Errorf("Unexpected result for %s\n\tExpecting: %v\n\tActual: %v", name, isOfficial, result) + } + } +} diff --git a/hack/vendor.sh b/hack/vendor.sh index 5bf8b601f5..cbdef927a2 100755 --- a/hack/vendor.sh +++ b/hack/vendor.sh @@ -51,7 +51,7 @@ clone hg code.google.com/p/go.net 84a4013f96e0 clone hg code.google.com/p/gosqlite 74691fb6f837 -clone git github.com/docker/libtrust 136d534cc940 +clone git github.com/docker/libtrust d273ef2565ca # get Go tip's archive/tar, for xattr support and improved performance # TODO after Go 1.4 drops, bump our minimum supported version and drop this vendored dep diff --git a/registry/endpoint.go b/registry/endpoint.go new file mode 100644 index 0000000000..5313a8079f --- /dev/null +++ b/registry/endpoint.go @@ -0,0 +1,129 @@ +package registry + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "strings" + + "github.com/docker/docker/pkg/log" +) + +// scans string for api version in the URL path. returns the trimmed hostname, if version found, string and API version. +func scanForApiVersion(hostname string) (string, APIVersion) { + var ( + chunks []string + apiVersionStr string + ) + if strings.HasSuffix(hostname, "/") { + chunks = strings.Split(hostname[:len(hostname)-1], "/") + apiVersionStr = chunks[len(chunks)-1] + } else { + chunks = strings.Split(hostname, "/") + apiVersionStr = chunks[len(chunks)-1] + } + for k, v := range apiVersions { + if apiVersionStr == v { + hostname = strings.Join(chunks[:len(chunks)-1], "/") + return hostname, k + } + } + return hostname, DefaultAPIVersion +} + +func NewEndpoint(hostname string) (*Endpoint, error) { + var ( + endpoint Endpoint + trimmedHostname string + err error + ) + if !strings.HasPrefix(hostname, "http") { + hostname = "https://" + hostname + } + trimmedHostname, endpoint.Version = scanForApiVersion(hostname) + endpoint.URL, err = url.Parse(trimmedHostname) + if err != nil { + return nil, err + } + + endpoint.URL.Scheme = "https" + if _, err := endpoint.Ping(); err != nil { + log.Debugf("Registry %s does not work (%s), falling back to http", endpoint, err) + // TODO: Check if http fallback is enabled + endpoint.URL.Scheme = "http" + if _, err = endpoint.Ping(); err != nil { + return nil, errors.New("Invalid Registry endpoint: " + err.Error()) + } + } + + return &endpoint, nil +} + +type Endpoint struct { + URL *url.URL + Version APIVersion +} + +// Get the formated URL for the root of this registry Endpoint +func (e Endpoint) String() string { + return fmt.Sprintf("%s/v%d/", e.URL.String(), e.Version) +} + +func (e Endpoint) VersionString(version APIVersion) string { + return fmt.Sprintf("%s/v%d/", e.URL.String(), version) +} + +func (e Endpoint) Ping() (RegistryInfo, error) { + if e.String() == IndexServerAddress() { + // Skip the check, we now this one is valid + // (and we never want to fallback to http in case of error) + return RegistryInfo{Standalone: false}, nil + } + + req, err := http.NewRequest("GET", e.String()+"_ping", nil) + if err != nil { + return RegistryInfo{Standalone: false}, err + } + + resp, _, err := doRequest(req, nil, ConnectTimeout) + if err != nil { + return RegistryInfo{Standalone: false}, err + } + + defer resp.Body.Close() + + jsonString, err := ioutil.ReadAll(resp.Body) + if err != nil { + return RegistryInfo{Standalone: false}, fmt.Errorf("Error while reading the http response: %s", err) + } + + // If the header is absent, we assume true for compatibility with earlier + // versions of the registry. default to true + info := RegistryInfo{ + Standalone: true, + } + if err := json.Unmarshal(jsonString, &info); err != nil { + log.Debugf("Error unmarshalling the _ping RegistryInfo: %s", err) + // don't stop here. Just assume sane defaults + } + if hdr := resp.Header.Get("X-Docker-Registry-Version"); hdr != "" { + log.Debugf("Registry version header: '%s'", hdr) + info.Version = hdr + } + log.Debugf("RegistryInfo.Version: %q", info.Version) + + standalone := resp.Header.Get("X-Docker-Registry-Standalone") + log.Debugf("Registry standalone header: '%s'", standalone) + // Accepted values are "true" (case-insensitive) and "1". + if strings.EqualFold(standalone, "true") || standalone == "1" { + info.Standalone = true + } else if len(standalone) > 0 { + // there is a header set, and it is not "true" or "1", so assume fails + info.Standalone = false + } + log.Debugf("RegistryInfo.Standalone: %t", info.Standalone) + return info, nil +} diff --git a/registry/registry.go b/registry/registry.go index a2e1fbdd1d..fd74b7514e 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -3,7 +3,6 @@ package registry import ( "crypto/tls" "crypto/x509" - "encoding/json" "errors" "fmt" "io/ioutil" @@ -15,13 +14,13 @@ import ( "strings" "time" - "github.com/docker/docker/pkg/log" "github.com/docker/docker/utils" ) var ( ErrAlreadyExists = errors.New("Image already exists") ErrInvalidRepositoryName = errors.New("Invalid repository name (ex: \"registry.domain.tld/myrepos\")") + ErrDoesNotExist = errors.New("Image does not exist") errLoginRequired = errors.New("Authentication is required.") validHex = regexp.MustCompile(`^([a-f0-9]{64})$`) validNamespace = regexp.MustCompile(`^([a-z0-9_]{4,30})$`) @@ -152,55 +151,6 @@ func doRequest(req *http.Request, jar http.CookieJar, timeout TimeoutType) (*htt return nil, nil, nil } -func pingRegistryEndpoint(endpoint string) (RegistryInfo, error) { - if endpoint == IndexServerAddress() { - // Skip the check, we now this one is valid - // (and we never want to fallback to http in case of error) - return RegistryInfo{Standalone: false}, nil - } - - req, err := http.NewRequest("GET", endpoint+"_ping", nil) - if err != nil { - return RegistryInfo{Standalone: false}, err - } - - resp, _, err := doRequest(req, nil, ConnectTimeout) - if err != nil { - return RegistryInfo{Standalone: false}, err - } - - defer resp.Body.Close() - - jsonString, err := ioutil.ReadAll(resp.Body) - if err != nil { - return RegistryInfo{Standalone: false}, fmt.Errorf("Error while reading the http response: %s", err) - } - - // If the header is absent, we assume true for compatibility with earlier - // versions of the registry. default to true - info := RegistryInfo{ - Standalone: true, - } - if err := json.Unmarshal(jsonString, &info); err != nil { - log.Debugf("Error unmarshalling the _ping RegistryInfo: %s", err) - // don't stop here. Just assume sane defaults - } - if hdr := resp.Header.Get("X-Docker-Registry-Version"); hdr != "" { - log.Debugf("Registry version header: '%s'", hdr) - info.Version = hdr - } - log.Debugf("RegistryInfo.Version: %q", info.Version) - - standalone := resp.Header.Get("X-Docker-Registry-Standalone") - log.Debugf("Registry standalone header: '%s'", standalone) - if !strings.EqualFold(standalone, "true") && standalone != "1" && len(standalone) > 0 { - // there is a header set, and it is not "true" or "1", so assume fails - info.Standalone = false - } - log.Debugf("RegistryInfo.Standalone: %q", info.Standalone) - return info, nil -} - func validateRepositoryName(repositoryName string) error { var ( namespace string @@ -252,33 +202,6 @@ func ResolveRepositoryName(reposName string) (string, string, error) { return hostname, reposName, nil } -// this method expands the registry name as used in the prefix of a repo -// to a full url. if it already is a url, there will be no change. -// The registry is pinged to test if it http or https -func ExpandAndVerifyRegistryUrl(hostname string) (string, error) { - if strings.HasPrefix(hostname, "http:") || strings.HasPrefix(hostname, "https:") { - // if there is no slash after https:// (8 characters) then we have no path in the url - if strings.LastIndex(hostname, "/") < 9 { - // there is no path given. Expand with default path - hostname = hostname + "/v1/" - } - if _, err := pingRegistryEndpoint(hostname); err != nil { - return "", errors.New("Invalid Registry endpoint: " + err.Error()) - } - return hostname, nil - } - endpoint := fmt.Sprintf("https://%s/v1/", hostname) - if _, err := pingRegistryEndpoint(endpoint); err != nil { - log.Debugf("Registry %s does not work (%s), falling back to http", endpoint, err) - endpoint = fmt.Sprintf("http://%s/v1/", hostname) - if _, err = pingRegistryEndpoint(endpoint); err != nil { - //TODO: triggering highland build can be done there without "failing" - return "", errors.New("Invalid Registry endpoint: " + err.Error()) - } - } - return endpoint, nil -} - func trustedLocation(req *http.Request) bool { var ( trusteds = []string{"docker.com", "docker.io"} diff --git a/registry/registry_mock_test.go b/registry/registry_mock_test.go index 8851dcbe39..379dc78f47 100644 --- a/registry/registry_mock_test.go +++ b/registry/registry_mock_test.go @@ -83,6 +83,8 @@ var ( func init() { r := mux.NewRouter() + + // /v1/ r.HandleFunc("/v1/_ping", handlerGetPing).Methods("GET") r.HandleFunc("/v1/images/{image_id:[^/]+}/{action:json|layer|ancestry}", handlerGetImage).Methods("GET") r.HandleFunc("/v1/images/{image_id:[^/]+}/{action:json|layer|checksum}", handlerPutImage).Methods("PUT") @@ -93,6 +95,10 @@ func init() { r.HandleFunc("/v1/repositories/{repository:.+}{action:/images|/}", handlerImages).Methods("GET", "PUT", "DELETE") r.HandleFunc("/v1/repositories/{repository:.+}/auth", handlerAuth).Methods("PUT") r.HandleFunc("/v1/search", handlerSearch).Methods("GET") + + // /v2/ + r.HandleFunc("/v2/version", handlerGetPing).Methods("GET") + testHttpServer = httptest.NewServer(handlerAccessLog(r)) } diff --git a/registry/registry_test.go b/registry/registry_test.go index 8a95221dc7..ab4178126a 100644 --- a/registry/registry_test.go +++ b/registry/registry_test.go @@ -18,7 +18,11 @@ var ( func spawnTestRegistrySession(t *testing.T) *Session { authConfig := &AuthConfig{} - r, err := NewSession(authConfig, utils.NewHTTPRequestFactory(), makeURL("/v1/"), true) + endpoint, err := NewEndpoint(makeURL("/v1/")) + if err != nil { + t.Fatal(err) + } + r, err := NewSession(authConfig, utils.NewHTTPRequestFactory(), endpoint, true) if err != nil { t.Fatal(err) } @@ -26,7 +30,11 @@ func spawnTestRegistrySession(t *testing.T) *Session { } func TestPingRegistryEndpoint(t *testing.T) { - regInfo, err := pingRegistryEndpoint(makeURL("/v1/")) + ep, err := NewEndpoint(makeURL("/v1/")) + if err != nil { + t.Fatal(err) + } + regInfo, err := ep.Ping() if err != nil { t.Fatal(err) } @@ -197,7 +205,7 @@ func TestPushImageJSONIndex(t *testing.T) { if repoData == nil { t.Fatal("Expected RepositoryData object") } - repoData, err = r.PushImageJSONIndex("foo42/bar", imgData, true, []string{r.indexEndpoint}) + repoData, err = r.PushImageJSONIndex("foo42/bar", imgData, true, []string{r.indexEndpoint.String()}) if err != nil { t.Fatal(err) } diff --git a/registry/service.go b/registry/service.go index 0e6f1bda93..f7b353000e 100644 --- a/registry/service.go +++ b/registry/service.go @@ -40,11 +40,14 @@ func (s *Service) Auth(job *engine.Job) engine.Status { job.GetenvJson("authConfig", authConfig) // TODO: this is only done here because auth and registry need to be merged into one pkg if addr := authConfig.ServerAddress; addr != "" && addr != IndexServerAddress() { - addr, err = ExpandAndVerifyRegistryUrl(addr) + endpoint, err := NewEndpoint(addr) if err != nil { return job.Error(err) } - authConfig.ServerAddress = addr + if _, err := endpoint.Ping(); err != nil { + return job.Error(err) + } + authConfig.ServerAddress = endpoint.String() } status, err := Login(authConfig, HTTPRequestFactory(nil)) if err != nil { @@ -86,11 +89,11 @@ func (s *Service) Search(job *engine.Job) engine.Status { if err != nil { return job.Error(err) } - hostname, err = ExpandAndVerifyRegistryUrl(hostname) + endpoint, err := NewEndpoint(hostname) if err != nil { return job.Error(err) } - r, err := NewSession(authConfig, HTTPRequestFactory(metaHeaders), hostname, true) + r, err := NewSession(authConfig, HTTPRequestFactory(metaHeaders), endpoint, true) if err != nil { return job.Error(err) } diff --git a/registry/session.go b/registry/session.go index 58263ef6e3..5067b8d5de 100644 --- a/registry/session.go +++ b/registry/session.go @@ -25,15 +25,15 @@ import ( type Session struct { authConfig *AuthConfig reqFactory *utils.HTTPRequestFactory - indexEndpoint string + indexEndpoint *Endpoint jar *cookiejar.Jar timeout TimeoutType } -func NewSession(authConfig *AuthConfig, factory *utils.HTTPRequestFactory, indexEndpoint string, timeout bool) (r *Session, err error) { +func NewSession(authConfig *AuthConfig, factory *utils.HTTPRequestFactory, endpoint *Endpoint, timeout bool) (r *Session, err error) { r = &Session{ authConfig: authConfig, - indexEndpoint: indexEndpoint, + indexEndpoint: endpoint, } if timeout { @@ -47,13 +47,13 @@ func NewSession(authConfig *AuthConfig, factory *utils.HTTPRequestFactory, index // If we're working with a standalone private registry over HTTPS, send Basic Auth headers // alongside our requests. - if indexEndpoint != IndexServerAddress() && strings.HasPrefix(indexEndpoint, "https://") { - info, err := pingRegistryEndpoint(indexEndpoint) + if r.indexEndpoint.VersionString(1) != IndexServerAddress() && r.indexEndpoint.URL.Scheme == "https" { + info, err := r.indexEndpoint.Ping() if err != nil { return nil, err } if info.Standalone { - log.Debugf("Endpoint %s is eligible for private registry registry. Enabling decorator.", indexEndpoint) + log.Debugf("Endpoint %s is eligible for private registry registry. Enabling decorator.", r.indexEndpoint.String()) dec := utils.NewHTTPAuthDecorator(authConfig.Username, authConfig.Password) factory.AddDecorator(dec) } @@ -261,8 +261,7 @@ func buildEndpointsList(headers []string, indexEp string) ([]string, error) { } func (r *Session) GetRepositoryData(remote string) (*RepositoryData, error) { - indexEp := r.indexEndpoint - repositoryTarget := fmt.Sprintf("%srepositories/%s/images", indexEp, remote) + repositoryTarget := fmt.Sprintf("%srepositories/%s/images", r.indexEndpoint.VersionString(1), remote) log.Debugf("[registry] Calling GET %s", repositoryTarget) @@ -296,17 +295,13 @@ func (r *Session) GetRepositoryData(remote string) (*RepositoryData, error) { var endpoints []string if res.Header.Get("X-Docker-Endpoints") != "" { - endpoints, err = buildEndpointsList(res.Header["X-Docker-Endpoints"], indexEp) + endpoints, err = buildEndpointsList(res.Header["X-Docker-Endpoints"], r.indexEndpoint.VersionString(1)) if err != nil { return nil, err } } else { // Assume the endpoint is on the same host - u, err := url.Parse(indexEp) - if err != nil { - return nil, err - } - endpoints = append(endpoints, fmt.Sprintf("%s://%s/v1/", u.Scheme, req.URL.Host)) + endpoints = append(endpoints, fmt.Sprintf("%s://%s/v1/", r.indexEndpoint.URL.Scheme, req.URL.Host)) } checksumsJSON, err := ioutil.ReadAll(res.Body) @@ -474,7 +469,6 @@ func (r *Session) PushRegistryTag(remote, revision, tag, registry string, token func (r *Session) PushImageJSONIndex(remote string, imgList []*ImgData, validate bool, regs []string) (*RepositoryData, error) { cleanImgList := []*ImgData{} - indexEp := r.indexEndpoint if validate { for _, elem := range imgList { @@ -494,7 +488,7 @@ func (r *Session) PushImageJSONIndex(remote string, imgList []*ImgData, validate if validate { suffix = "images" } - u := fmt.Sprintf("%srepositories/%s/%s", indexEp, remote, suffix) + u := fmt.Sprintf("%srepositories/%s/%s", r.indexEndpoint.VersionString(1), remote, suffix) log.Debugf("[registry] PUT %s", u) log.Debugf("Image list pushed to index:\n%s", imgListJSON) req, err := r.reqFactory.NewRequest("PUT", u, bytes.NewReader(imgListJSON)) @@ -552,7 +546,7 @@ func (r *Session) PushImageJSONIndex(remote string, imgList []*ImgData, validate } if res.Header.Get("X-Docker-Endpoints") != "" { - endpoints, err = buildEndpointsList(res.Header["X-Docker-Endpoints"], indexEp) + endpoints, err = buildEndpointsList(res.Header["X-Docker-Endpoints"], r.indexEndpoint.VersionString(1)) if err != nil { return nil, err } @@ -578,7 +572,7 @@ func (r *Session) PushImageJSONIndex(remote string, imgList []*ImgData, validate func (r *Session) SearchRepositories(term string) (*SearchResults, error) { log.Debugf("Index server: %s", r.indexEndpoint) - u := r.indexEndpoint + "search?q=" + url.QueryEscape(term) + u := r.indexEndpoint.VersionString(1) + "search?q=" + url.QueryEscape(term) req, err := r.reqFactory.NewRequest("GET", u, nil) if err != nil { return nil, err diff --git a/registry/session_v2.go b/registry/session_v2.go new file mode 100644 index 0000000000..2e0de49bc2 --- /dev/null +++ b/registry/session_v2.go @@ -0,0 +1,386 @@ +package registry + +import ( + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/url" + "strconv" + + "github.com/docker/docker/pkg/log" + "github.com/docker/docker/utils" + "github.com/gorilla/mux" +) + +func newV2RegistryRouter() *mux.Router { + router := mux.NewRouter() + + v2Router := router.PathPrefix("/v2/").Subrouter() + + // Version Info + v2Router.Path("/version").Name("version") + + // Image Manifests + v2Router.Path("/manifest/{imagename:[a-z0-9-._/]+}/{tagname:[a-zA-Z0-9-._]+}").Name("manifests") + + // List Image Tags + v2Router.Path("/tags/{imagename:[a-z0-9-._/]+}").Name("tags") + + // Download a blob + v2Router.Path("/blob/{imagename:[a-z0-9-._/]+}/{sumtype:[a-z0-9_+-]+}/{sum:[a-fA-F0-9]{4,}}").Name("downloadBlob") + + // Upload a blob + v2Router.Path("/blob/{imagename:[a-z0-9-._/]+}/{sumtype:[a-z0-9_+-]+}").Name("uploadBlob") + + // Mounting a blob in an image + v2Router.Path("/mountblob/{imagename:[a-z0-9-._/]+}/{sumtype:[a-z0-9_+-]+}/{sum:[a-fA-F0-9]{4,}}").Name("mountBlob") + + return router +} + +// APIVersion2 /v2/ +var v2HTTPRoutes = newV2RegistryRouter() + +func getV2URL(e *Endpoint, routeName string, vars map[string]string) (*url.URL, error) { + route := v2HTTPRoutes.Get(routeName) + if route == nil { + return nil, fmt.Errorf("unknown regisry v2 route name: %q", routeName) + } + + varReplace := make([]string, 0, len(vars)*2) + for key, val := range vars { + varReplace = append(varReplace, key, val) + } + + routePath, err := route.URLPath(varReplace...) + if err != nil { + return nil, fmt.Errorf("unable to make registry route %q with vars %v: %s", routeName, vars, err) + } + + return &url.URL{ + Scheme: e.URL.Scheme, + Host: e.URL.Host, + Path: routePath.Path, + }, nil +} + +// V2 Provenance POC + +func (r *Session) GetV2Version(token []string) (*RegistryInfo, error) { + routeURL, err := getV2URL(r.indexEndpoint, "version", nil) + if err != nil { + return nil, err + } + + method := "GET" + log.Debugf("[registry] Calling %q %s", method, routeURL.String()) + + req, err := r.reqFactory.NewRequest(method, routeURL.String(), nil) + if err != nil { + return nil, err + } + setTokenAuth(req, token) + res, _, err := r.doRequest(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != 200 { + return nil, utils.NewHTTPRequestError(fmt.Sprintf("Server error: %d fetching Version", res.StatusCode), res) + } + + decoder := json.NewDecoder(res.Body) + versionInfo := new(RegistryInfo) + + err = decoder.Decode(versionInfo) + if err != nil { + return nil, fmt.Errorf("unable to decode GetV2Version JSON response: %s", err) + } + + return versionInfo, nil +} + +// +// 1) Check if TarSum of each layer exists /v2/ +// 1.a) if 200, continue +// 1.b) if 300, then push the +// 1.c) if anything else, err +// 2) PUT the created/signed manifest +// +func (r *Session) GetV2ImageManifest(imageName, tagName string, token []string) ([]byte, error) { + vars := map[string]string{ + "imagename": imageName, + "tagname": tagName, + } + + routeURL, err := getV2URL(r.indexEndpoint, "manifests", vars) + if err != nil { + return nil, err + } + + method := "GET" + log.Debugf("[registry] Calling %q %s", method, routeURL.String()) + + req, err := r.reqFactory.NewRequest(method, routeURL.String(), nil) + if err != nil { + return nil, err + } + setTokenAuth(req, token) + res, _, err := r.doRequest(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != 200 { + if res.StatusCode == 401 { + return nil, errLoginRequired + } else if res.StatusCode == 404 { + return nil, ErrDoesNotExist + } + return nil, utils.NewHTTPRequestError(fmt.Sprintf("Server error: %d trying to fetch for %s:%s", res.StatusCode, imageName, tagName), res) + } + + buf, err := ioutil.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("Error while reading the http response: %s", err) + } + return buf, nil +} + +// - Succeeded to mount for this image scope +// - Failed with no error (So continue to Push the Blob) +// - Failed with error +func (r *Session) PostV2ImageMountBlob(imageName, sumType, sum string, token []string) (bool, error) { + vars := map[string]string{ + "imagename": imageName, + "sumtype": sumType, + "sum": sum, + } + + routeURL, err := getV2URL(r.indexEndpoint, "mountBlob", vars) + if err != nil { + return false, err + } + + method := "POST" + log.Debugf("[registry] Calling %q %s", method, routeURL.String()) + + req, err := r.reqFactory.NewRequest(method, routeURL.String(), nil) + if err != nil { + return false, err + } + setTokenAuth(req, token) + res, _, err := r.doRequest(req) + if err != nil { + return false, err + } + res.Body.Close() // close early, since we're not needing a body on this call .. yet? + switch res.StatusCode { + case 200: + // return something indicating no push needed + return true, nil + case 300: + // return something indicating blob push needed + return false, nil + } + return false, fmt.Errorf("Failed to mount %q - %s:%s : %d", imageName, sumType, sum, res.StatusCode) +} + +func (r *Session) GetV2ImageBlob(imageName, sumType, sum string, blobWrtr io.Writer, token []string) error { + vars := map[string]string{ + "imagename": imageName, + "sumtype": sumType, + "sum": sum, + } + + routeURL, err := getV2URL(r.indexEndpoint, "downloadBlob", vars) + if err != nil { + return err + } + + method := "GET" + log.Debugf("[registry] Calling %q %s", method, routeURL.String()) + req, err := r.reqFactory.NewRequest(method, routeURL.String(), nil) + if err != nil { + return err + } + setTokenAuth(req, token) + res, _, err := r.doRequest(req) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode != 200 { + if res.StatusCode == 401 { + return errLoginRequired + } + return utils.NewHTTPRequestError(fmt.Sprintf("Server error: %d trying to pull %s blob", res.StatusCode, imageName), res) + } + + _, err = io.Copy(blobWrtr, res.Body) + return err +} + +func (r *Session) GetV2ImageBlobReader(imageName, sumType, sum string, token []string) (io.ReadCloser, int64, error) { + vars := map[string]string{ + "imagename": imageName, + "sumtype": sumType, + "sum": sum, + } + + routeURL, err := getV2URL(r.indexEndpoint, "downloadBlob", vars) + if err != nil { + return nil, 0, err + } + + method := "GET" + log.Debugf("[registry] Calling %q %s", method, routeURL.String()) + req, err := r.reqFactory.NewRequest(method, routeURL.String(), nil) + if err != nil { + return nil, 0, err + } + setTokenAuth(req, token) + res, _, err := r.doRequest(req) + if err != nil { + return nil, 0, err + } + if res.StatusCode != 200 { + if res.StatusCode == 401 { + return nil, 0, errLoginRequired + } + return nil, 0, utils.NewHTTPRequestError(fmt.Sprintf("Server error: %d trying to pull %s blob", res.StatusCode, imageName), res) + } + lenStr := res.Header.Get("Content-Length") + l, err := strconv.ParseInt(lenStr, 10, 64) + if err != nil { + return nil, 0, err + } + + return res.Body, l, err +} + +// Push the image to the server for storage. +// 'layer' is an uncompressed reader of the blob to be pushed. +// The server will generate it's own checksum calculation. +func (r *Session) PutV2ImageBlob(imageName, sumType string, blobRdr io.Reader, token []string) (serverChecksum string, err error) { + vars := map[string]string{ + "imagename": imageName, + "sumtype": sumType, + } + + routeURL, err := getV2URL(r.indexEndpoint, "uploadBlob", vars) + if err != nil { + return "", err + } + + method := "PUT" + log.Debugf("[registry] Calling %q %s", method, routeURL.String()) + req, err := r.reqFactory.NewRequest(method, routeURL.String(), blobRdr) + if err != nil { + return "", err + } + setTokenAuth(req, token) + res, _, err := r.doRequest(req) + if err != nil { + return "", err + } + defer res.Body.Close() + if res.StatusCode != 201 { + if res.StatusCode == 401 { + return "", errLoginRequired + } + return "", utils.NewHTTPRequestError(fmt.Sprintf("Server error: %d trying to push %s blob", res.StatusCode, imageName), res) + } + + type sumReturn struct { + Checksum string `json:"checksum"` + } + + decoder := json.NewDecoder(res.Body) + var sumInfo sumReturn + + err = decoder.Decode(&sumInfo) + if err != nil { + return "", fmt.Errorf("unable to decode PutV2ImageBlob JSON response: %s", err) + } + + // XXX this is a json struct from the registry, with its checksum + return sumInfo.Checksum, nil +} + +// Finally Push the (signed) manifest of the blobs we've just pushed +func (r *Session) PutV2ImageManifest(imageName, tagName string, manifestRdr io.Reader, token []string) error { + vars := map[string]string{ + "imagename": imageName, + "tagname": tagName, + } + + routeURL, err := getV2URL(r.indexEndpoint, "manifests", vars) + if err != nil { + return err + } + + method := "PUT" + log.Debugf("[registry] Calling %q %s", method, routeURL.String()) + req, err := r.reqFactory.NewRequest(method, routeURL.String(), manifestRdr) + if err != nil { + return err + } + setTokenAuth(req, token) + res, _, err := r.doRequest(req) + if err != nil { + return err + } + res.Body.Close() + if res.StatusCode != 201 { + if res.StatusCode == 401 { + return errLoginRequired + } + return utils.NewHTTPRequestError(fmt.Sprintf("Server error: %d trying to push %s:%s manifest", res.StatusCode, imageName, tagName), res) + } + + return nil +} + +// Given a repository name, returns a json array of string tags +func (r *Session) GetV2RemoteTags(imageName string, token []string) ([]string, error) { + vars := map[string]string{ + "imagename": imageName, + } + + routeURL, err := getV2URL(r.indexEndpoint, "tags", vars) + if err != nil { + return nil, err + } + + method := "GET" + log.Debugf("[registry] Calling %q %s", method, routeURL.String()) + + req, err := r.reqFactory.NewRequest(method, routeURL.String(), nil) + if err != nil { + return nil, err + } + setTokenAuth(req, token) + res, _, err := r.doRequest(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != 200 { + if res.StatusCode == 401 { + return nil, errLoginRequired + } else if res.StatusCode == 404 { + return nil, ErrDoesNotExist + } + return nil, utils.NewHTTPRequestError(fmt.Sprintf("Server error: %d trying to fetch for %s", res.StatusCode, imageName), res) + } + + decoder := json.NewDecoder(res.Body) + var tags []string + err = decoder.Decode(&tags) + if err != nil { + return nil, fmt.Errorf("Error while decoding the http response: %s", err) + } + return tags, nil +} diff --git a/registry/types.go b/registry/types.go index 70d55e42fd..2ba5af0da4 100644 --- a/registry/types.go +++ b/registry/types.go @@ -31,3 +31,29 @@ type RegistryInfo struct { Version string `json:"version"` Standalone bool `json:"standalone"` } + +type ManifestData struct { + Name string `json:"name"` + Tag string `json:"tag"` + Architecture string `json:"architecture"` + BlobSums []string `json:"blobSums"` + History []string `json:"history"` + SchemaVersion int `json:"schemaVersion"` +} + +type APIVersion int + +func (av APIVersion) String() string { + return apiVersions[av] +} + +var DefaultAPIVersion APIVersion = APIVersion1 +var apiVersions = map[APIVersion]string{ + 1: "v1", + 2: "v2", +} + +const ( + APIVersion1 = iota + 1 + APIVersion2 +) diff --git a/trust/service.go b/trust/service.go new file mode 100644 index 0000000000..c056ac7191 --- /dev/null +++ b/trust/service.go @@ -0,0 +1,74 @@ +package trust + +import ( + "fmt" + "time" + + "github.com/docker/docker/engine" + "github.com/docker/docker/pkg/log" + "github.com/docker/libtrust" +) + +func (t *TrustStore) Install(eng *engine.Engine) error { + for name, handler := range map[string]engine.Handler{ + "trust_key_check": t.CmdCheckKey, + "trust_update_base": t.CmdUpdateBase, + } { + if err := eng.Register(name, handler); err != nil { + return fmt.Errorf("Could not register %q: %v", name, err) + } + } + return nil +} + +func (t *TrustStore) CmdCheckKey(job *engine.Job) engine.Status { + if n := len(job.Args); n != 1 { + return job.Errorf("Usage: %s NAMESPACE", job.Name) + } + var ( + namespace = job.Args[0] + keyBytes = job.Getenv("PublicKey") + ) + + if keyBytes == "" { + return job.Errorf("Missing PublicKey") + } + pk, err := libtrust.UnmarshalPublicKeyJWK([]byte(keyBytes)) + if err != nil { + return job.Errorf("Error unmarshalling public key: %s", err) + } + + permission := uint16(job.GetenvInt("Permission")) + if permission == 0 { + permission = 0x03 + } + + t.RLock() + defer t.RUnlock() + if t.graph == nil { + job.Stdout.Write([]byte("no graph")) + return engine.StatusOK + } + + // Check if any expired grants + verified, err := t.graph.Verify(pk, namespace, permission) + if err != nil { + return job.Errorf("Error verifying key to namespace: %s", namespace) + } + if !verified { + log.Debugf("Verification failed for %s using key %s", namespace, pk.KeyID()) + job.Stdout.Write([]byte("not verified")) + } else if t.expiration.Before(time.Now()) { + job.Stdout.Write([]byte("expired")) + } else { + job.Stdout.Write([]byte("verified")) + } + + return engine.StatusOK +} + +func (t *TrustStore) CmdUpdateBase(job *engine.Job) engine.Status { + t.fetch() + + return engine.StatusOK +} diff --git a/trust/trusts.go b/trust/trusts.go new file mode 100644 index 0000000000..a3c0f5f548 --- /dev/null +++ b/trust/trusts.go @@ -0,0 +1,199 @@ +package trust + +import ( + "crypto/x509" + "errors" + "io/ioutil" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "sync" + "time" + + "github.com/docker/docker/pkg/log" + "github.com/docker/libtrust/trustgraph" +) + +type TrustStore struct { + path string + caPool *x509.CertPool + graph trustgraph.TrustGraph + expiration time.Time + fetcher *time.Timer + fetchTime time.Duration + autofetch bool + httpClient *http.Client + baseEndpoints map[string]*url.URL + + sync.RWMutex +} + +// defaultFetchtime represents the starting duration to wait between +// fetching sections of the graph. Unsuccessful fetches should +// increase time between fetching. +const defaultFetchtime = 45 * time.Second + +var baseEndpoints = map[string]string{"official": "https://dvjy3tqbc323p.cloudfront.net/trust/official.json"} + +func NewTrustStore(path string) (*TrustStore, error) { + abspath, err := filepath.Abs(path) + if err != nil { + return nil, err + } + + // Create base graph url map + endpoints := map[string]*url.URL{} + for name, endpoint := range baseEndpoints { + u, err := url.Parse(endpoint) + if err != nil { + return nil, err + } + endpoints[name] = u + } + + // Load grant files + t := &TrustStore{ + path: abspath, + caPool: nil, + httpClient: &http.Client{}, + fetchTime: time.Millisecond, + baseEndpoints: endpoints, + } + + err = t.reload() + if err != nil { + return nil, err + } + + return t, nil +} + +func (t *TrustStore) reload() error { + t.Lock() + defer t.Unlock() + + matches, err := filepath.Glob(filepath.Join(t.path, "*.json")) + if err != nil { + return err + } + statements := make([]*trustgraph.Statement, len(matches)) + for i, match := range matches { + f, err := os.Open(match) + if err != nil { + return err + } + statements[i], err = trustgraph.LoadStatement(f, nil) + if err != nil { + f.Close() + return err + } + f.Close() + } + if len(statements) == 0 { + if t.autofetch { + log.Debugf("No grants, fetching") + t.fetcher = time.AfterFunc(t.fetchTime, t.fetch) + } + return nil + } + + grants, expiration, err := trustgraph.CollapseStatements(statements, true) + if err != nil { + return err + } + + t.expiration = expiration + t.graph = trustgraph.NewMemoryGraph(grants) + log.Debugf("Reloaded graph with %d grants expiring at %s", len(grants), expiration) + + if t.autofetch { + nextFetch := expiration.Sub(time.Now()) + if nextFetch < 0 { + nextFetch = defaultFetchtime + } else { + nextFetch = time.Duration(0.8 * (float64)(nextFetch)) + } + t.fetcher = time.AfterFunc(nextFetch, t.fetch) + } + + return nil +} + +func (t *TrustStore) fetchBaseGraph(u *url.URL) (*trustgraph.Statement, error) { + req := &http.Request{ + Method: "GET", + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Body: nil, + Host: u.Host, + } + + resp, err := t.httpClient.Do(req) + if err != nil { + return nil, err + } + if resp.StatusCode == 404 { + return nil, errors.New("base graph does not exist") + } + + defer resp.Body.Close() + + return trustgraph.LoadStatement(resp.Body, t.caPool) +} + +// fetch retrieves updated base graphs. This function cannot error, it +// should only log errors +func (t *TrustStore) fetch() { + t.Lock() + defer t.Unlock() + + if t.autofetch && t.fetcher == nil { + // Do nothing ?? + return + } + + fetchCount := 0 + for bg, ep := range t.baseEndpoints { + statement, err := t.fetchBaseGraph(ep) + if err != nil { + log.Infof("Trust graph fetch failed: %s", err) + continue + } + b, err := statement.Bytes() + if err != nil { + log.Infof("Bad trust graph statement: %s", err) + continue + } + // TODO check if value differs + err = ioutil.WriteFile(path.Join(t.path, bg+".json"), b, 0600) + if err != nil { + log.Infof("Error writing trust graph statement: %s", err) + } + fetchCount++ + } + log.Debugf("Fetched %d base graphs at %s", fetchCount, time.Now()) + + if fetchCount > 0 { + go func() { + err := t.reload() + if err != nil { + // TODO log + log.Infof("Reload of trust graph failed: %s", err) + } + }() + t.fetchTime = defaultFetchtime + t.fetcher = nil + } else if t.autofetch { + maxTime := 10 * defaultFetchtime + t.fetchTime = time.Duration(1.5 * (float64)(t.fetchTime+time.Second)) + if t.fetchTime > maxTime { + t.fetchTime = maxTime + } + t.fetcher = time.AfterFunc(t.fetchTime, t.fetch) + } +} diff --git a/vendor/src/github.com/docker/libtrust/trustgraph/graph.go b/vendor/src/github.com/docker/libtrust/trustgraph/graph.go new file mode 100644 index 0000000000..72b0fc3664 --- /dev/null +++ b/vendor/src/github.com/docker/libtrust/trustgraph/graph.go @@ -0,0 +1,50 @@ +package trustgraph + +import "github.com/docker/libtrust" + +// TrustGraph represents a graph of authorization mapping +// public keys to nodes and grants between nodes. +type TrustGraph interface { + // Verifies that the given public key is allowed to perform + // the given action on the given node according to the trust + // graph. + Verify(libtrust.PublicKey, string, uint16) (bool, error) + + // GetGrants returns an array of all grant chains which are used to + // allow the requested permission. + GetGrants(libtrust.PublicKey, string, uint16) ([][]*Grant, error) +} + +// Grant represents a transfer of permission from one part of the +// trust graph to another. This is the only way to delegate +// permission between two different sub trees in the graph. +type Grant struct { + // Subject is the namespace being granted + Subject string + + // Permissions is a bit map of permissions + Permission uint16 + + // Grantee represents the node being granted + // a permission scope. The grantee can be + // either a namespace item or a key id where namespace + // items will always start with a '/'. + Grantee string + + // statement represents the statement used to create + // this object. + statement *Statement +} + +// Permissions +// Read node 0x01 (can read node, no sub nodes) +// Write node 0x02 (can write to node object, cannot create subnodes) +// Read subtree 0x04 (delegates read to each sub node) +// Write subtree 0x08 (delegates write to each sub node, included create on the subject) +// +// Permission shortcuts +// ReadItem = 0x01 +// WriteItem = 0x03 +// ReadAccess = 0x07 +// WriteAccess = 0x0F +// Delegate = 0x0F diff --git a/vendor/src/github.com/docker/libtrust/trustgraph/memory_graph.go b/vendor/src/github.com/docker/libtrust/trustgraph/memory_graph.go new file mode 100644 index 0000000000..247bfa7aa6 --- /dev/null +++ b/vendor/src/github.com/docker/libtrust/trustgraph/memory_graph.go @@ -0,0 +1,133 @@ +package trustgraph + +import ( + "strings" + + "github.com/docker/libtrust" +) + +type grantNode struct { + grants []*Grant + children map[string]*grantNode +} + +type memoryGraph struct { + roots map[string]*grantNode +} + +func newGrantNode() *grantNode { + return &grantNode{ + grants: []*Grant{}, + children: map[string]*grantNode{}, + } +} + +// NewMemoryGraph returns a new in memory trust graph created from +// a static list of grants. This graph is immutable after creation +// and any alterations should create a new instance. +func NewMemoryGraph(grants []*Grant) TrustGraph { + roots := map[string]*grantNode{} + for _, grant := range grants { + parts := strings.Split(grant.Grantee, "/") + nodes := roots + var node *grantNode + var nodeOk bool + for _, part := range parts { + node, nodeOk = nodes[part] + if !nodeOk { + node = newGrantNode() + nodes[part] = node + } + if part != "" { + node.grants = append(node.grants, grant) + } + nodes = node.children + } + } + return &memoryGraph{roots} +} + +func (g *memoryGraph) getGrants(name string) []*Grant { + nameParts := strings.Split(name, "/") + nodes := g.roots + var node *grantNode + var nodeOk bool + for _, part := range nameParts { + node, nodeOk = nodes[part] + if !nodeOk { + return nil + } + nodes = node.children + } + return node.grants +} + +func isSubName(name, sub string) bool { + if strings.HasPrefix(name, sub) { + if len(name) == len(sub) || name[len(sub)] == '/' { + return true + } + } + return false +} + +type walkFunc func(*Grant, []*Grant) bool + +func foundWalkFunc(*Grant, []*Grant) bool { + return true +} + +func (g *memoryGraph) walkGrants(start, target string, permission uint16, f walkFunc, chain []*Grant, visited map[*Grant]bool, collect bool) bool { + if visited == nil { + visited = map[*Grant]bool{} + } + grants := g.getGrants(start) + subGrants := make([]*Grant, 0, len(grants)) + for _, grant := range grants { + if visited[grant] { + continue + } + visited[grant] = true + if grant.Permission&permission == permission { + if isSubName(target, grant.Subject) { + if f(grant, chain) { + return true + } + } else { + subGrants = append(subGrants, grant) + } + } + } + for _, grant := range subGrants { + var chainCopy []*Grant + if collect { + chainCopy = make([]*Grant, len(chain)+1) + copy(chainCopy, chain) + chainCopy[len(chainCopy)-1] = grant + } else { + chainCopy = nil + } + + if g.walkGrants(grant.Subject, target, permission, f, chainCopy, visited, collect) { + return true + } + } + return false +} + +func (g *memoryGraph) Verify(key libtrust.PublicKey, node string, permission uint16) (bool, error) { + return g.walkGrants(key.KeyID(), node, permission, foundWalkFunc, nil, nil, false), nil +} + +func (g *memoryGraph) GetGrants(key libtrust.PublicKey, node string, permission uint16) ([][]*Grant, error) { + grants := [][]*Grant{} + collect := func(grant *Grant, chain []*Grant) bool { + grantChain := make([]*Grant, len(chain)+1) + copy(grantChain, chain) + grantChain[len(grantChain)-1] = grant + grants = append(grants, grantChain) + return false + } + g.walkGrants(key.KeyID(), node, permission, collect, nil, nil, true) + return grants, nil +} diff --git a/vendor/src/github.com/docker/libtrust/trustgraph/memory_graph_test.go b/vendor/src/github.com/docker/libtrust/trustgraph/memory_graph_test.go new file mode 100644 index 0000000000..49fd0f3b54 --- /dev/null +++ b/vendor/src/github.com/docker/libtrust/trustgraph/memory_graph_test.go @@ -0,0 +1,174 @@ +package trustgraph + +import ( + "fmt" + "testing" + + "github.com/docker/libtrust" +) + +func createTestKeysAndGrants(count int) ([]*Grant, []libtrust.PrivateKey) { + grants := make([]*Grant, count) + keys := make([]libtrust.PrivateKey, count) + for i := 0; i < count; i++ { + pk, err := libtrust.GenerateECP256PrivateKey() + if err != nil { + panic(err) + } + grant := &Grant{ + Subject: fmt.Sprintf("/user-%d", i+1), + Permission: 0x0f, + Grantee: pk.KeyID(), + } + keys[i] = pk + grants[i] = grant + } + return grants, keys +} + +func testVerified(t *testing.T, g TrustGraph, k libtrust.PublicKey, keyName, target string, permission uint16) { + if ok, err := g.Verify(k, target, permission); err != nil { + t.Fatalf("Unexpected error during verification: %s", err) + } else if !ok { + t.Errorf("key failed verification\n\tKey: %s(%s)\n\tNamespace: %s", keyName, k.KeyID(), target) + } +} + +func testNotVerified(t *testing.T, g TrustGraph, k libtrust.PublicKey, keyName, target string, permission uint16) { + if ok, err := g.Verify(k, target, permission); err != nil { + t.Fatalf("Unexpected error during verification: %s", err) + } else if ok { + t.Errorf("key should have failed verification\n\tKey: %s(%s)\n\tNamespace: %s", keyName, k.KeyID(), target) + } +} + +func TestVerify(t *testing.T) { + grants, keys := createTestKeysAndGrants(4) + extraGrants := make([]*Grant, 3) + extraGrants[0] = &Grant{ + Subject: "/user-3", + Permission: 0x0f, + Grantee: "/user-2", + } + extraGrants[1] = &Grant{ + Subject: "/user-3/sub-project", + Permission: 0x0f, + Grantee: "/user-4", + } + extraGrants[2] = &Grant{ + Subject: "/user-4", + Permission: 0x07, + Grantee: "/user-1", + } + grants = append(grants, extraGrants...) + + g := NewMemoryGraph(grants) + + testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f) + testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1/some-project/sub-value", 0x0f) + testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-4", 0x07) + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2/", 0x0f) + testVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-3/sub-value", 0x0f) + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3/sub-value", 0x0f) + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3", 0x0f) + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3/", 0x0f) + testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3/sub-project", 0x0f) + testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3/sub-project/app", 0x0f) + testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-4", 0x0f) + + testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-2", 0x0f) + testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-3/sub-value", 0x0f) + testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-4", 0x0f) + testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-1/", 0x0f) + testNotVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-2", 0x0f) + testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-4", 0x0f) + testNotVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3", 0x0f) +} + +func TestCircularWalk(t *testing.T) { + grants, keys := createTestKeysAndGrants(3) + user1Grant := &Grant{ + Subject: "/user-2", + Permission: 0x0f, + Grantee: "/user-1", + } + user2Grant := &Grant{ + Subject: "/user-1", + Permission: 0x0f, + Grantee: "/user-2", + } + grants = append(grants, user1Grant, user2Grant) + + g := NewMemoryGraph(grants) + + testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f) + testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-2", 0x0f) + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x0f) + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-1", 0x0f) + testVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-3", 0x0f) + + testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-3", 0x0f) + testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3", 0x0f) +} + +func assertGrantSame(t *testing.T, actual, expected *Grant) { + if actual != expected { + t.Fatalf("Unexpected grant retrieved\n\tExpected: %v\n\tActual: %v", expected, actual) + } +} + +func TestGetGrants(t *testing.T) { + grants, keys := createTestKeysAndGrants(5) + extraGrants := make([]*Grant, 4) + extraGrants[0] = &Grant{ + Subject: "/user-3/friend-project", + Permission: 0x0f, + Grantee: "/user-2/friends", + } + extraGrants[1] = &Grant{ + Subject: "/user-3/sub-project", + Permission: 0x0f, + Grantee: "/user-4", + } + extraGrants[2] = &Grant{ + Subject: "/user-2/friends", + Permission: 0x0f, + Grantee: "/user-5/fun-project", + } + extraGrants[3] = &Grant{ + Subject: "/user-5/fun-project", + Permission: 0x0f, + Grantee: "/user-1", + } + grants = append(grants, extraGrants...) + + g := NewMemoryGraph(grants) + + grantChains, err := g.GetGrants(keys[3], "/user-3/sub-project/specific-app", 0x0f) + if err != nil { + t.Fatalf("Error getting grants: %s", err) + } + if len(grantChains) != 1 { + t.Fatalf("Expected number of grant chains returned, expected %d, received %d", 1, len(grantChains)) + } + if len(grantChains[0]) != 2 { + t.Fatalf("Unexpected number of grants retrieved\n\tExpected: %d\n\tActual: %d", 2, len(grantChains[0])) + } + assertGrantSame(t, grantChains[0][0], grants[3]) + assertGrantSame(t, grantChains[0][1], extraGrants[1]) + + grantChains, err = g.GetGrants(keys[0], "/user-3/friend-project/fun-app", 0x0f) + if err != nil { + t.Fatalf("Error getting grants: %s", err) + } + if len(grantChains) != 1 { + t.Fatalf("Expected number of grant chains returned, expected %d, received %d", 1, len(grantChains)) + } + if len(grantChains[0]) != 4 { + t.Fatalf("Unexpected number of grants retrieved\n\tExpected: %d\n\tActual: %d", 2, len(grantChains[0])) + } + assertGrantSame(t, grantChains[0][0], grants[0]) + assertGrantSame(t, grantChains[0][1], extraGrants[3]) + assertGrantSame(t, grantChains[0][2], extraGrants[2]) + assertGrantSame(t, grantChains[0][3], extraGrants[0]) +} diff --git a/vendor/src/github.com/docker/libtrust/trustgraph/statement.go b/vendor/src/github.com/docker/libtrust/trustgraph/statement.go new file mode 100644 index 0000000000..7a74b553cd --- /dev/null +++ b/vendor/src/github.com/docker/libtrust/trustgraph/statement.go @@ -0,0 +1,227 @@ +package trustgraph + +import ( + "crypto/x509" + "encoding/json" + "io" + "io/ioutil" + "sort" + "strings" + "time" + + "github.com/docker/libtrust" +) + +type jsonGrant struct { + Subject string `json:"subject"` + Permission uint16 `json:"permission"` + Grantee string `json:"grantee"` +} + +type jsonRevocation struct { + Subject string `json:"subject"` + Revocation uint16 `json:"revocation"` + Grantee string `json:"grantee"` +} + +type jsonStatement struct { + Revocations []*jsonRevocation `json:"revocations"` + Grants []*jsonGrant `json:"grants"` + Expiration time.Time `json:"expiration"` + IssuedAt time.Time `json:"issuedAt"` +} + +func (g *jsonGrant) Grant(statement *Statement) *Grant { + return &Grant{ + Subject: g.Subject, + Permission: g.Permission, + Grantee: g.Grantee, + statement: statement, + } +} + +// Statement represents a set of grants made from a verifiable +// authority. A statement has an expiration associated with it +// set by the authority. +type Statement struct { + jsonStatement + + signature *libtrust.JSONSignature +} + +// IsExpired returns whether the statement has expired +func (s *Statement) IsExpired() bool { + return s.Expiration.Before(time.Now().Add(-10 * time.Second)) +} + +// Bytes returns an indented json representation of the statement +// in a byte array. This value can be written to a file or stream +// without alteration. +func (s *Statement) Bytes() ([]byte, error) { + return s.signature.PrettySignature("signatures") +} + +// LoadStatement loads and verifies a statement from an input stream. +func LoadStatement(r io.Reader, authority *x509.CertPool) (*Statement, error) { + b, err := ioutil.ReadAll(r) + if err != nil { + return nil, err + } + js, err := libtrust.ParsePrettySignature(b, "signatures") + if err != nil { + return nil, err + } + payload, err := js.Payload() + if err != nil { + return nil, err + } + var statement Statement + err = json.Unmarshal(payload, &statement.jsonStatement) + if err != nil { + return nil, err + } + + if authority == nil { + _, err = js.Verify() + if err != nil { + return nil, err + } + } else { + _, err = js.VerifyChains(authority) + if err != nil { + return nil, err + } + } + statement.signature = js + + return &statement, nil +} + +// CreateStatements creates and signs a statement from a stream of grants +// and revocations in a JSON array. +func CreateStatement(grants, revocations io.Reader, expiration time.Duration, key libtrust.PrivateKey, chain []*x509.Certificate) (*Statement, error) { + var statement Statement + err := json.NewDecoder(grants).Decode(&statement.jsonStatement.Grants) + if err != nil { + return nil, err + } + err = json.NewDecoder(revocations).Decode(&statement.jsonStatement.Revocations) + if err != nil { + return nil, err + } + statement.jsonStatement.Expiration = time.Now().UTC().Add(expiration) + statement.jsonStatement.IssuedAt = time.Now().UTC() + + b, err := json.MarshalIndent(&statement.jsonStatement, "", " ") + if err != nil { + return nil, err + } + + statement.signature, err = libtrust.NewJSONSignature(b) + if err != nil { + return nil, err + } + err = statement.signature.SignWithChain(key, chain) + if err != nil { + return nil, err + } + + return &statement, nil +} + +type statementList []*Statement + +func (s statementList) Len() int { + return len(s) +} + +func (s statementList) Less(i, j int) bool { + return s[i].IssuedAt.Before(s[j].IssuedAt) +} + +func (s statementList) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +// CollapseStatements returns a single list of the valid statements as well as the +// time when the next grant will expire. +func CollapseStatements(statements []*Statement, useExpired bool) ([]*Grant, time.Time, error) { + sorted := make(statementList, 0, len(statements)) + for _, statement := range statements { + if useExpired || !statement.IsExpired() { + sorted = append(sorted, statement) + } + } + sort.Sort(sorted) + + var minExpired time.Time + var grantCount int + roots := map[string]*grantNode{} + for i, statement := range sorted { + if statement.Expiration.Before(minExpired) || i == 0 { + minExpired = statement.Expiration + } + for _, grant := range statement.Grants { + parts := strings.Split(grant.Grantee, "/") + nodes := roots + g := grant.Grant(statement) + grantCount = grantCount + 1 + + for _, part := range parts { + node, nodeOk := nodes[part] + if !nodeOk { + node = newGrantNode() + nodes[part] = node + } + node.grants = append(node.grants, g) + nodes = node.children + } + } + + for _, revocation := range statement.Revocations { + parts := strings.Split(revocation.Grantee, "/") + nodes := roots + + var node *grantNode + var nodeOk bool + for _, part := range parts { + node, nodeOk = nodes[part] + if !nodeOk { + break + } + nodes = node.children + } + if node != nil { + for _, grant := range node.grants { + if isSubName(grant.Subject, revocation.Subject) { + grant.Permission = grant.Permission &^ revocation.Revocation + } + } + } + } + } + + retGrants := make([]*Grant, 0, grantCount) + for _, rootNodes := range roots { + retGrants = append(retGrants, rootNodes.grants...) + } + + return retGrants, minExpired, nil +} + +// FilterStatements filters the statements to statements including the given grants. +func FilterStatements(grants []*Grant) ([]*Statement, error) { + statements := map[*Statement]bool{} + for _, grant := range grants { + if grant.statement != nil { + statements[grant.statement] = true + } + } + retStatements := make([]*Statement, len(statements)) + var i int + for statement := range statements { + retStatements[i] = statement + i++ + } + return retStatements, nil +} diff --git a/vendor/src/github.com/docker/libtrust/trustgraph/statement_test.go b/vendor/src/github.com/docker/libtrust/trustgraph/statement_test.go new file mode 100644 index 0000000000..d9c3c1a1ea --- /dev/null +++ b/vendor/src/github.com/docker/libtrust/trustgraph/statement_test.go @@ -0,0 +1,417 @@ +package trustgraph + +import ( + "bytes" + "crypto/x509" + "encoding/json" + "testing" + "time" + + "github.com/docker/libtrust" + "github.com/docker/libtrust/testutil" +) + +const testStatementExpiration = time.Hour * 5 + +func generateStatement(grants []*Grant, key libtrust.PrivateKey, chain []*x509.Certificate) (*Statement, error) { + var statement Statement + + statement.Grants = make([]*jsonGrant, len(grants)) + for i, grant := range grants { + statement.Grants[i] = &jsonGrant{ + Subject: grant.Subject, + Permission: grant.Permission, + Grantee: grant.Grantee, + } + } + statement.IssuedAt = time.Now() + statement.Expiration = time.Now().Add(testStatementExpiration) + statement.Revocations = make([]*jsonRevocation, 0) + + marshalled, err := json.MarshalIndent(statement.jsonStatement, "", " ") + if err != nil { + return nil, err + } + + sig, err := libtrust.NewJSONSignature(marshalled) + if err != nil { + return nil, err + } + err = sig.SignWithChain(key, chain) + if err != nil { + return nil, err + } + statement.signature = sig + + return &statement, nil +} + +func generateTrustChain(t *testing.T, chainLen int) (libtrust.PrivateKey, *x509.CertPool, []*x509.Certificate) { + caKey, err := libtrust.GenerateECP256PrivateKey() + if err != nil { + t.Fatalf("Error generating key: %s", err) + } + ca, err := testutil.GenerateTrustCA(caKey.CryptoPublicKey(), caKey.CryptoPrivateKey()) + if err != nil { + t.Fatalf("Error generating ca: %s", err) + } + + parent := ca + parentKey := caKey + chain := make([]*x509.Certificate, chainLen) + for i := chainLen - 1; i > 0; i-- { + intermediatekey, err := libtrust.GenerateECP256PrivateKey() + if err != nil { + t.Fatalf("Error generate key: %s", err) + } + chain[i], err = testutil.GenerateIntermediate(intermediatekey.CryptoPublicKey(), parentKey.CryptoPrivateKey(), parent) + if err != nil { + t.Fatalf("Error generating intermdiate certificate: %s", err) + } + parent = chain[i] + parentKey = intermediatekey + } + trustKey, err := libtrust.GenerateECP256PrivateKey() + if err != nil { + t.Fatalf("Error generate key: %s", err) + } + chain[0], err = testutil.GenerateTrustCert(trustKey.CryptoPublicKey(), parentKey.CryptoPrivateKey(), parent) + if err != nil { + t.Fatalf("Error generate trust cert: %s", err) + } + + caPool := x509.NewCertPool() + caPool.AddCert(ca) + + return trustKey, caPool, chain +} + +func TestLoadStatement(t *testing.T) { + grantCount := 4 + grants, _ := createTestKeysAndGrants(grantCount) + + trustKey, caPool, chain := generateTrustChain(t, 6) + + statement, err := generateStatement(grants, trustKey, chain) + if err != nil { + t.Fatalf("Error generating statement: %s", err) + } + + statementBytes, err := statement.Bytes() + if err != nil { + t.Fatalf("Error getting statement bytes: %s", err) + } + + s2, err := LoadStatement(bytes.NewReader(statementBytes), caPool) + if err != nil { + t.Fatalf("Error loading statement: %s", err) + } + if len(s2.Grants) != grantCount { + t.Fatalf("Unexpected grant length\n\tExpected: %d\n\tActual: %d", grantCount, len(s2.Grants)) + } + + pool := x509.NewCertPool() + _, err = LoadStatement(bytes.NewReader(statementBytes), pool) + if err == nil { + t.Fatalf("No error thrown verifying without an authority") + } else if _, ok := err.(x509.UnknownAuthorityError); !ok { + t.Fatalf("Unexpected error verifying without authority: %s", err) + } + + s2, err = LoadStatement(bytes.NewReader(statementBytes), nil) + if err != nil { + t.Fatalf("Error loading statement: %s", err) + } + if len(s2.Grants) != grantCount { + t.Fatalf("Unexpected grant length\n\tExpected: %d\n\tActual: %d", grantCount, len(s2.Grants)) + } + + badData := make([]byte, len(statementBytes)) + copy(badData, statementBytes) + badData[0] = '[' + _, err = LoadStatement(bytes.NewReader(badData), nil) + if err == nil { + t.Fatalf("No error thrown parsing bad json") + } + + alteredData := make([]byte, len(statementBytes)) + copy(alteredData, statementBytes) + alteredData[30] = '0' + _, err = LoadStatement(bytes.NewReader(alteredData), nil) + if err == nil { + t.Fatalf("No error thrown from bad data") + } +} + +func TestCollapseGrants(t *testing.T) { + grantCount := 8 + grants, keys := createTestKeysAndGrants(grantCount) + linkGrants := make([]*Grant, 4) + linkGrants[0] = &Grant{ + Subject: "/user-3", + Permission: 0x0f, + Grantee: "/user-2", + } + linkGrants[1] = &Grant{ + Subject: "/user-3/sub-project", + Permission: 0x0f, + Grantee: "/user-4", + } + linkGrants[2] = &Grant{ + Subject: "/user-6", + Permission: 0x0f, + Grantee: "/user-7", + } + linkGrants[3] = &Grant{ + Subject: "/user-6/sub-project/specific-app", + Permission: 0x0f, + Grantee: "/user-5", + } + trustKey, pool, chain := generateTrustChain(t, 3) + + statements := make([]*Statement, 3) + var err error + statements[0], err = generateStatement(grants[0:4], trustKey, chain) + if err != nil { + t.Fatalf("Error generating statement: %s", err) + } + statements[1], err = generateStatement(grants[4:], trustKey, chain) + if err != nil { + t.Fatalf("Error generating statement: %s", err) + } + statements[2], err = generateStatement(linkGrants, trustKey, chain) + if err != nil { + t.Fatalf("Error generating statement: %s", err) + } + + statementsCopy := make([]*Statement, len(statements)) + for i, statement := range statements { + b, err := statement.Bytes() + if err != nil { + t.Fatalf("Error getting statement bytes: %s", err) + } + verifiedStatement, err := LoadStatement(bytes.NewReader(b), pool) + if err != nil { + t.Fatalf("Error loading statement: %s", err) + } + // Force sort by reversing order + statementsCopy[len(statementsCopy)-i-1] = verifiedStatement + } + statements = statementsCopy + + collapsedGrants, expiration, err := CollapseStatements(statements, false) + if len(collapsedGrants) != 12 { + t.Fatalf("Unexpected number of grants\n\tExpected: %d\n\tActual: %s", 12, len(collapsedGrants)) + } + if expiration.After(time.Now().Add(time.Hour*5)) || expiration.Before(time.Now()) { + t.Fatalf("Unexpected expiration time: %s", expiration.String()) + } + g := NewMemoryGraph(collapsedGrants) + + testVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f) + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x0f) + testVerified(t, g, keys[2].PublicKey(), "user-key-3", "/user-3", 0x0f) + testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-4", 0x0f) + testVerified(t, g, keys[4].PublicKey(), "user-key-5", "/user-5", 0x0f) + testVerified(t, g, keys[5].PublicKey(), "user-key-6", "/user-6", 0x0f) + testVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-7", 0x0f) + testVerified(t, g, keys[7].PublicKey(), "user-key-8", "/user-8", 0x0f) + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3", 0x0f) + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-3/sub-project/specific-app", 0x0f) + testVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3/sub-project", 0x0f) + testVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-6", 0x0f) + testVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-6/sub-project/specific-app", 0x0f) + testVerified(t, g, keys[4].PublicKey(), "user-key-5", "/user-6/sub-project/specific-app", 0x0f) + + testNotVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-3", 0x0f) + testNotVerified(t, g, keys[3].PublicKey(), "user-key-4", "/user-6/sub-project", 0x0f) + testNotVerified(t, g, keys[4].PublicKey(), "user-key-5", "/user-6/sub-project", 0x0f) + + // Add revocation grant + statements = append(statements, &Statement{ + jsonStatement{ + IssuedAt: time.Now(), + Expiration: time.Now().Add(testStatementExpiration), + Grants: []*jsonGrant{}, + Revocations: []*jsonRevocation{ + &jsonRevocation{ + Subject: "/user-1", + Revocation: 0x0f, + Grantee: keys[0].KeyID(), + }, + &jsonRevocation{ + Subject: "/user-2", + Revocation: 0x08, + Grantee: keys[1].KeyID(), + }, + &jsonRevocation{ + Subject: "/user-6", + Revocation: 0x0f, + Grantee: "/user-7", + }, + &jsonRevocation{ + Subject: "/user-9", + Revocation: 0x0f, + Grantee: "/user-10", + }, + }, + }, + nil, + }) + + collapsedGrants, expiration, err = CollapseStatements(statements, false) + if len(collapsedGrants) != 12 { + t.Fatalf("Unexpected number of grants\n\tExpected: %d\n\tActual: %s", 12, len(collapsedGrants)) + } + if expiration.After(time.Now().Add(time.Hour*5)) || expiration.Before(time.Now()) { + t.Fatalf("Unexpected expiration time: %s", expiration.String()) + } + g = NewMemoryGraph(collapsedGrants) + + testNotVerified(t, g, keys[0].PublicKey(), "user-key-1", "/user-1", 0x0f) + testNotVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x0f) + testNotVerified(t, g, keys[6].PublicKey(), "user-key-7", "/user-6/sub-project/specific-app", 0x0f) + + testVerified(t, g, keys[1].PublicKey(), "user-key-2", "/user-2", 0x07) +} + +func TestFilterStatements(t *testing.T) { + grantCount := 8 + grants, keys := createTestKeysAndGrants(grantCount) + linkGrants := make([]*Grant, 3) + linkGrants[0] = &Grant{ + Subject: "/user-3", + Permission: 0x0f, + Grantee: "/user-2", + } + linkGrants[1] = &Grant{ + Subject: "/user-5", + Permission: 0x0f, + Grantee: "/user-4", + } + linkGrants[2] = &Grant{ + Subject: "/user-7", + Permission: 0x0f, + Grantee: "/user-6", + } + + trustKey, _, chain := generateTrustChain(t, 3) + + statements := make([]*Statement, 5) + var err error + statements[0], err = generateStatement(grants[0:2], trustKey, chain) + if err != nil { + t.Fatalf("Error generating statement: %s", err) + } + statements[1], err = generateStatement(grants[2:4], trustKey, chain) + if err != nil { + t.Fatalf("Error generating statement: %s", err) + } + statements[2], err = generateStatement(grants[4:6], trustKey, chain) + if err != nil { + t.Fatalf("Error generating statement: %s", err) + } + statements[3], err = generateStatement(grants[6:], trustKey, chain) + if err != nil { + t.Fatalf("Error generating statement: %s", err) + } + statements[4], err = generateStatement(linkGrants, trustKey, chain) + if err != nil { + t.Fatalf("Error generating statement: %s", err) + } + collapsed, _, err := CollapseStatements(statements, false) + if err != nil { + t.Fatalf("Error collapsing grants: %s", err) + } + + // Filter 1, all 5 statements + filter1, err := FilterStatements(collapsed) + if err != nil { + t.Fatalf("Error filtering statements: %s", err) + } + if len(filter1) != 5 { + t.Fatalf("Wrong number of statements, expected %d, received %d", 5, len(filter1)) + } + + // Filter 2, one statement + filter2, err := FilterStatements([]*Grant{collapsed[0]}) + if err != nil { + t.Fatalf("Error filtering statements: %s", err) + } + if len(filter2) != 1 { + t.Fatalf("Wrong number of statements, expected %d, received %d", 1, len(filter2)) + } + + // Filter 3, 2 statements, from graph lookup + g := NewMemoryGraph(collapsed) + lookupGrants, err := g.GetGrants(keys[1], "/user-3", 0x0f) + if err != nil { + t.Fatalf("Error looking up grants: %s", err) + } + if len(lookupGrants) != 1 { + t.Fatalf("Wrong numberof grant chains returned from lookup, expected %d, received %d", 1, len(lookupGrants)) + } + if len(lookupGrants[0]) != 2 { + t.Fatalf("Wrong number of grants looked up, expected %d, received %d", 2, len(lookupGrants)) + } + filter3, err := FilterStatements(lookupGrants[0]) + if err != nil { + t.Fatalf("Error filtering statements: %s", err) + } + if len(filter3) != 2 { + t.Fatalf("Wrong number of statements, expected %d, received %d", 2, len(filter3)) + } + +} + +func TestCreateStatement(t *testing.T) { + grantJSON := bytes.NewReader([]byte(`[ + { + "subject": "/user-2", + "permission": 15, + "grantee": "/user-1" + }, + { + "subject": "/user-7", + "permission": 1, + "grantee": "/user-9" + }, + { + "subject": "/user-3", + "permission": 15, + "grantee": "/user-2" + } +]`)) + revocationJSON := bytes.NewReader([]byte(`[ + { + "subject": "user-8", + "revocation": 12, + "grantee": "user-9" + } +]`)) + + trustKey, pool, chain := generateTrustChain(t, 3) + + statement, err := CreateStatement(grantJSON, revocationJSON, testStatementExpiration, trustKey, chain) + if err != nil { + t.Fatalf("Error creating statement: %s", err) + } + + b, err := statement.Bytes() + if err != nil { + t.Fatalf("Error retrieving bytes: %s", err) + } + + verified, err := LoadStatement(bytes.NewReader(b), pool) + if err != nil { + t.Fatalf("Error loading statement: %s", err) + } + + if len(verified.Grants) != 3 { + t.Errorf("Unexpected number of grants, expected %d, received %d", 3, len(verified.Grants)) + } + + if len(verified.Revocations) != 1 { + t.Errorf("Unexpected number of revocations, expected %d, received %d", 1, len(verified.Revocations)) + } +}