| // Copyright 2016 The Roughtime Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| // client is a somewhat featured Roughtime client. |
| package main |
| |
| import ( |
| "crypto/rand" |
| "encoding/binary" |
| "encoding/json" |
| "errors" |
| "fmt" |
| "io" |
| "net" |
| "time" |
| |
| "math/big" |
| mathrand "math/rand" |
| |
| "golang.org/x/crypto/ed25519" |
| "roughtime.googlesource.com/go/client/monotime" |
| "roughtime.googlesource.com/go/config" |
| "roughtime.googlesource.com/go/protocol" |
| ) |
| |
// Fallback settings that apply whenever the corresponding Client field is
// left unset.
const (
	// defaultMaxRadius is the largest uncertainty radius that we are
	// willing to accept in a server's reply.
	defaultMaxRadius = 10 * time.Second

	// defaultMaxDifference is how far a server's sample may lie from the
	// quorum-agreed time before we suspect that the server is
	// misbehaving.
	defaultMaxDifference = time.Minute

	// defaultTimeout is how long a server is given to answer one query.
	defaultTimeout = 2 * time.Second

	// defaultNumQueries is how many times we try to query a single server
	// before giving up on it.
	defaultNumQueries = 3
)
| |
// Client represents a Roughtime client and exposes a number of members that
// can be set in order to configure it. The zero value of a Client is always
// ready to use and will set sensible defaults.
type Client struct {
	// Permutation returns a random permutation of [0‥n) that is used to
	// query servers in a random order. If nil, a sensible default is used.
	Permutation func(n int) []int

	// MaxRadius is the maximum interval radius that will be accepted
	// from a server. If zero, a sensible default is used.
	MaxRadius time.Duration

	// MaxDifference is the maximum difference in time between any sample
	// from a server and the quorum-agreed time before that sample is
	// considered suspect. If zero, a sensible default is used.
	MaxDifference time.Duration

	// QueryTimeout is the amount of time a server has to reply to a query.
	// If zero, a sensible default will be used.
	QueryTimeout time.Duration

	// NumQueries is the maximum number of times a query will be sent to a
	// specific server before giving up. If <= zero, a sensible default
	// will be used.
	NumQueries int

	// nowFunc returns a monotonic duration from some unspecified epoch.
	// If nil, the system monotonic time will be used.
	nowFunc func() time.Duration
}
| |
| func (c *Client) now() time.Duration { |
| if c.nowFunc != nil { |
| return c.nowFunc() |
| } |
| return monotime.Now() |
| } |
| |
| func (c *Client) permutation(n int) []int { |
| if c.Permutation != nil { |
| return c.Permutation(n) |
| } |
| |
| var randBuf [8]byte |
| if _, err := io.ReadFull(rand.Reader, randBuf[:]); err != nil { |
| panic(err) |
| } |
| |
| seed := binary.LittleEndian.Uint64(randBuf[:]) |
| rand := mathrand.New(mathrand.NewSource(int64(seed))) |
| |
| return rand.Perm(n) |
| } |
| |
| func (c *Client) maxRadius() time.Duration { |
| if c.MaxRadius != 0 { |
| return c.MaxRadius |
| } |
| |
| return defaultMaxRadius |
| } |
| |
| func (c *Client) maxDifference() time.Duration { |
| if c.MaxDifference != 0 { |
| return c.MaxDifference |
| } |
| |
| return defaultMaxDifference |
| } |
| |
| func (c *Client) queryTimeout() time.Duration { |
| if c.QueryTimeout != 0 { |
| return c.QueryTimeout |
| } |
| |
| return defaultTimeout |
| } |
| |
| func (c *Client) numQueries() int { |
| if c.NumQueries > 0 { |
| return c.NumQueries |
| } |
| |
| return defaultNumQueries |
| } |
| |
| // LoadChain loads a JSON-format chain from the given JSON data. |
| func LoadChain(jsonData []byte) (chain *config.Chain, err error) { |
| chain = new(config.Chain) |
| if err := json.Unmarshal(jsonData, chain); err != nil { |
| return nil, err |
| } |
| |
| for i, link := range chain.Links { |
| if link.PublicKeyType != "ed25519" { |
| return nil, fmt.Errorf("client: link #%d in chain file has unknown public key type %q", i, link.PublicKeyType) |
| } |
| |
| if l := len(link.PublicKey); l != ed25519.PublicKeySize { |
| return nil, fmt.Errorf("client: link #%d in chain file has bad public key of length %d", i, l) |
| } |
| |
| if l := len(link.NonceOrBlind); l != protocol.NonceSize { |
| return nil, fmt.Errorf("client: link #%d in chain file has bad nonce/blind of length %d", i, l) |
| } |
| |
| var nonce [protocol.NonceSize]byte |
| if i == 0 { |
| copy(nonce[:], link.NonceOrBlind[:]) |
| } else { |
| nonce = protocol.CalculateChainNonce(chain.Links[i-1].Reply, link.NonceOrBlind[:]) |
| } |
| |
| if _, _, err := protocol.VerifyReply(link.Reply, link.PublicKey, nonce); err != nil { |
| return nil, fmt.Errorf("client: failed to verify link #%d in chain file", i) |
| } |
| } |
| |
| return chain, nil |
| } |
| |
| // timeSample represents a time sample from the network. |
| type timeSample struct { |
| // server references the server that was queried. |
| server *config.Server |
| |
| // base is a monotonic clock sample that is taken at a time before the |
| // network could have answered the query. |
| base *big.Int |
| |
| // min is the minimum real-time (in Roughtime UTC microseconds) that |
| // could correspond to |base| (i.e. midpoint - radius). |
| min *big.Int |
| |
| // max is the maximum real-time (in Roughtime UTC microseconds) that |
| // could correspond to |base| (i.e. midpoint + radius + query time). |
| max *big.Int |
| } |
| |
| // midpoint returns the average of the min and max times. |
| func (s *timeSample) midpoint() *big.Int { |
| ret := new(big.Int).Add(s.min, s.max) |
| return ret.Rsh(ret, 1) |
| } |
| |
| // alignTo updates s so that its base value matches that from reference. |
| func (s *timeSample) alignTo(reference *timeSample) { |
| delta := new(big.Int).Sub(s.base, reference.base) |
| s.base.Sub(s.base, delta) |
| s.min.Sub(s.min, delta) |
| s.max.Sub(s.max, delta) |
| } |
| |
| // contains returns true iff p belongs to s |
| func (s *timeSample) contains(p *big.Int) bool { |
| return s.max.Cmp(p) >= 0 && s.min.Cmp(p) <= 0 |
| } |
| |
| // overlaps returns true iff s and other have any timespan in common. |
| func (s *timeSample) overlaps(other *timeSample) bool { |
| return s.max.Cmp(other.min) >= 0 && other.max.Cmp(s.min) >= 0 |
| } |
| |
| // query sends a request to s, appends it to chain, and returns the resulting |
| // timeSample. |
| func (c *Client) query(server *config.Server, chain *config.Chain) (*timeSample, error) { |
| var prevReply []byte |
| if len(chain.Links) > 0 { |
| prevReply = chain.Links[len(chain.Links)-1].Reply |
| } |
| |
| var baseTime, replyTime time.Duration |
| var reply []byte |
| var nonce, blind [protocol.NonceSize]byte |
| |
| for attempts := 0; attempts < c.numQueries(); attempts++ { |
| var request []byte |
| var err error |
| if nonce, blind, request, err = protocol.CreateRequest(rand.Reader, prevReply); err != nil { |
| return nil, err |
| } |
| if len(request) < protocol.MinRequestSize { |
| panic("internal error: bad request length") |
| } |
| |
| udpAddr, err := serverUDPAddr(server) |
| if err != nil { |
| panic(err) |
| } |
| |
| conn, err := net.DialUDP("udp", nil, udpAddr) |
| if err != nil { |
| return nil, err |
| } |
| |
| conn.SetReadDeadline(time.Now().Add(c.queryTimeout())) |
| baseTime = c.now() |
| conn.Write(request) |
| |
| var replyBytes [1024]byte |
| n, err := conn.Read(replyBytes[:]) |
| if err == nil { |
| replyTime = c.now() |
| reply = replyBytes[:n] |
| break |
| } |
| |
| if netErr, ok := err.(net.Error); ok { |
| if !netErr.Timeout() { |
| return nil, errors.New("client: error reading from UDP socket: " + err.Error()) |
| } |
| } |
| } |
| |
| if reply == nil { |
| return nil, fmt.Errorf("client: no reply from server %q", server.Name) |
| } |
| |
| if replyTime < baseTime { |
| panic("broken monotonic clock") |
| } |
| queryDuration := replyTime - baseTime |
| |
| midpoint, radius, err := protocol.VerifyReply(reply, server.PublicKey, nonce) |
| if err != nil { |
| return nil, err |
| } |
| |
| if time.Duration(radius)*time.Microsecond > c.maxRadius() { |
| return nil, fmt.Errorf("client: radius (%d) too large", radius) |
| } |
| |
| nonceOrBlind := blind[:] |
| if len(prevReply) == 0 { |
| nonceOrBlind = nonce[:] |
| } |
| |
| chain.Links = append(chain.Links, config.Link{ |
| PublicKeyType: "ed25519", |
| PublicKey: server.PublicKey, |
| NonceOrBlind: nonceOrBlind, |
| Reply: reply, |
| }) |
| |
| queryDurationBig := new(big.Int).SetInt64(int64(queryDuration/time.Microsecond)) |
| bigRadius := new(big.Int).SetUint64(uint64(radius)) |
| min := new(big.Int).SetUint64(midpoint) |
| min.Sub(min, bigRadius) |
| min.Sub(min, queryDurationBig) |
| |
| max := new(big.Int).SetUint64(midpoint) |
| max.Add(max, bigRadius) |
| |
| return &timeSample{ |
| server: server, |
| base: new(big.Int).SetInt64(int64(baseTime)), |
| min: min, |
| max: max, |
| }, nil |
| } |
| |
| func serverUDPAddr(server *config.Server) (*net.UDPAddr, error) { |
| for _, addr := range server.Addresses { |
| if addr.Protocol != "udp" { |
| continue |
| } |
| |
| return net.ResolveUDPAddr("udp", addr.Address) |
| } |
| |
| return nil, nil |
| } |
| |
| // LoadServers loads information about known servers from the given JSON data. |
| // It only extracts information about servers with Ed25519 public keys and UDP |
| // address. The number of servers skipped because of unsupported requirements |
| // is returned in numSkipped. |
| func LoadServers(jsonData []byte) (servers []config.Server, numSkipped int, err error) { |
| var serversJSON config.ServersJSON |
| if err := json.Unmarshal(jsonData, &serversJSON); err != nil { |
| return nil, 0, err |
| } |
| |
| seenNames := make(map[string]struct{}) |
| |
| for _, candidate := range serversJSON.Servers { |
| if _, ok := seenNames[candidate.Name]; ok { |
| return nil, 0, fmt.Errorf("client: duplicate server name %q", candidate.Name) |
| } |
| seenNames[candidate.Name] = struct{}{} |
| |
| if candidate.PublicKeyType != "ed25519" { |
| numSkipped++ |
| continue |
| } |
| |
| udpAddr, err := serverUDPAddr(&candidate) |
| |
| if err != nil { |
| return nil, 0, fmt.Errorf("client: server %q lists invalid UDP address: %s", candidate.Name, err) |
| } |
| |
| if udpAddr == nil { |
| numSkipped++ |
| continue |
| } |
| |
| servers = append(servers, candidate) |
| } |
| |
| if len(servers) == 0 { |
| return nil, 0, errors.New("client: no usable servers found") |
| } |
| |
| return servers, 0, nil |
| } |
| |
| // trimChain drops elements from the beginning of chain, as needed, so that its |
| // length is <= n. |
| func trimChain(chain *config.Chain, n int) { |
| if n <= 0 { |
| chain.Links = nil |
| return |
| } |
| |
| if len(chain.Links) <= n { |
| return |
| } |
| |
| numToTrim := len(chain.Links) - n |
| for i := 0; i < numToTrim; i++ { |
| // The NonceOrBlind of the first element is special because |
| // it's an nonce. All the others are blinds and are combined |
| // with the previous reply to make the nonce. That's not |
| // possible for the first element because there is no previous |
| // reply. Therefore, when removing the first element the blind |
| // of the next element needs to be converted to an nonce. |
| nonce := protocol.CalculateChainNonce(chain.Links[0].Reply, chain.Links[1].NonceOrBlind[:]) |
| chain.Links[1].NonceOrBlind = nonce[:] |
| chain.Links = chain.Links[1:] |
| } |
| } |
| |
| // intersection returns the timespan common to all the elements in samples, |
| // which must be aligned to the same base. The caller must ensure that such a |
| // timespan exists. |
| func intersection(samples []*timeSample) *timeSample { |
| ret := &timeSample{ |
| base: samples[0].base, |
| min: new(big.Int).Set(samples[0].min), |
| max: new(big.Int).Set(samples[0].max), |
| } |
| |
| for _, sample := range samples[1:] { |
| if ret.min.Cmp(sample.min) < 0 { |
| ret.min.Set(sample.min) |
| } |
| if ret.max.Cmp(sample.max) > 0 { |
| ret.max.Set(sample.max) |
| } |
| } |
| |
| return ret |
| } |
| |
| // findNOverlapping finds an n-element subset of samples where all the |
| // members overlap. It returns the intersection if such a subset exists. |
| func findNOverlapping(samples []*timeSample, n int) (sampleIntersection *timeSample, ok bool) { |
| switch { |
| case n <= 0: |
| return nil, false |
| case n == 1: |
| return samples[0], true |
| } |
| |
| overlapping := make([]*timeSample, 0, n) |
| |
| for _, initial := range samples { |
| // An intersection of any subset of intervals will be an interval that contains |
| // the starting point of one of the intervals (possibly as its own starting point). |
| point := initial.min |
| |
| for _, candidate := range samples { |
| if candidate.contains(point) { |
| overlapping = append(overlapping, candidate) |
| } |
| |
| if len(overlapping) == n { |
| return intersection(overlapping), true |
| } |
| } |
| |
| overlapping = overlapping[:0] |
| } |
| |
| return nil, false |
| } |
| |
// TimeResult reports the outcome of an attempt to establish the current time
// from a set of servers.
type TimeResult struct {
	// MonoUTCDelta is nil when no time could be established. Otherwise it
	// holds the offset between the monotonic clock and the agreed
	// Roughtime time, such that adding it to a monotonic clock reading
	// yields the estimated time since the Roughtime epoch.
	MonoUTCDelta *time.Duration

	// ServerErrors records, by server name, the error returned by each
	// server whose query failed. It is nil when no query failed.
	ServerErrors map[string]error

	// OutOfRangeAnswer is set when at least one sample differed from the
	// agreed time by more than MaxDifference. The offending reply is
	// still recorded in the chain.
	OutOfRangeAnswer bool
}
| |
| // EstablishTime queries a number of servers until it has a quorum of |
| // overlapping results, or it runs out of servers. Results from the querying |
| // the servers are appended to chain. |
| func (c *Client) EstablishTime(chain *config.Chain, quorum int, servers []config.Server) (TimeResult, error) { |
| perm := c.permutation(len(servers)) |
| var samples []*timeSample |
| var intersection *timeSample |
| var result TimeResult |
| |
| for len(perm) > 0 { |
| server := &servers[perm[0]] |
| perm = perm[1:] |
| |
| sample, err := c.query(server, chain) |
| if err != nil { |
| if result.ServerErrors == nil { |
| result.ServerErrors = make(map[string]error) |
| } |
| result.ServerErrors[server.Name] = err |
| continue |
| } |
| |
| if len(samples) > 0 { |
| sample.alignTo(samples[0]) |
| } |
| samples = append(samples, sample) |
| |
| var ok bool |
| if intersection, ok = findNOverlapping(samples, quorum); ok { |
| break |
| } |
| intersection = nil |
| } |
| |
| if intersection == nil { |
| return result, nil |
| } |
| midpoint := intersection.midpoint() |
| |
| maxDifference := new(big.Int).SetUint64(uint64(c.maxDifference() / time.Microsecond)) |
| for _, sample := range samples { |
| delta := new(big.Int).Sub(midpoint, sample.midpoint()) |
| delta.Abs(delta) |
| |
| if delta.Cmp(maxDifference) > 0 { |
| result.OutOfRangeAnswer = true |
| break |
| } |
| } |
| |
| midpoint.Mul(midpoint, big.NewInt(1000)) |
| delta := new(big.Int).Sub(midpoint, intersection.base) |
| if delta.BitLen() > 63 { |
| return result, errors.New("client: cannot represent difference between monotonic and UTC time") |
| } |
| monoUTCDelta := time.Duration(delta.Int64()) |
| result.MonoUTCDelta = &monoUTCDelta |
| |
| return result, nil |
| } |