-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Expand file tree
/
Copy pathgpus.go
More file actions
118 lines (104 loc) · 2.49 KB
/
gpus.go
File metadata and controls
118 lines (104 loc) · 2.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
package opts
import (
"encoding/csv"
"errors"
"fmt"
"strconv"
"strings"
"github.com/moby/moby/api/types/container"
)
// GpuOpts collects GPU device requests. It exposes Set, Type and String
// methods (presumably so it can serve as a command-line flag value — confirm
// against callers) and keeps one container.DeviceRequest per successful Set.
type GpuOpts struct {
// values holds the parsed requests in the order Set appended them.
values []container.DeviceRequest
}
// parseCount converts a GPU count argument into an integer.
//
// The literal "all" maps to -1. Any other input is parsed with strconv.Atoi;
// on failure the returned error wraps the underlying cause (for example
// strconv.ErrSyntax or strconv.ErrRange) rather than the full
// *strconv.NumError, and the returned count is 0.
func parseCount(s string) (int, error) {
	if s == "all" {
		return -1, nil
	}

	n, convErr := strconv.Atoi(s)
	if convErr == nil {
		return n, nil
	}

	// Strip the NumError wrapper so the message does not repeat the function
	// name and input that strconv would otherwise include.
	cause := convErr
	var numErr *strconv.NumError
	if errors.As(convErr, &numErr) {
		cause = numErr.Err
	}
	return 0, fmt.Errorf(`invalid count (%s): value must be either "all" or an integer: %w`, s, cause)
}
// Set parses a single GPU request and appends the resulting
// container.DeviceRequest to the option's values.
//
// The value is read as one CSV record. Each field is either "key=value" or a
// bare value, which is shorthand for "count=<value>". Recognized keys:
//
//   - driver: stored as the request's Driver.
//   - count: "all" (stored as -1) or an integer.
//   - device: device IDs, split on commas.
//   - capabilities: capabilities, split on commas; "gpu" is always appended.
//   - options: a nested CSV record converted with ConvertKVStringsToMap.
//
// Every key, including the bare count shorthand, may be given only once; a
// repeated or unknown key is an error. When neither a count nor a device list
// is given the count defaults to 1. Capabilities default to [["gpu"]] and
// Options to an empty map.
//
//nolint:gocyclo
func (o *GpuOpts) Set(value string) error {
	csvReader := csv.NewReader(strings.NewReader(value))
	fields, err := csvReader.Read()
	if err != nil {
		return err
	}

	req := container.DeviceRequest{}
	seen := map[string]struct{}{}
	for _, field := range fields {
		key, val, withValue := strings.Cut(field, "=")
		if !withValue {
			// A bare field is shorthand for "count=<field>". Normalize it
			// before the duplicate check so that inputs such as "2,3" or
			// "count=2,all" are rejected instead of silently letting the
			// last count win.
			key, val = "count", field
		}
		if _, ok := seen[key]; ok {
			return fmt.Errorf("gpu request key '%s' can be specified only once", key)
		}
		seen[key] = struct{}{}

		switch key {
		case "driver":
			req.Driver = val
		case "count":
			req.Count, err = parseCount(val)
			if err != nil {
				return err
			}
		case "device":
			req.DeviceIDs = strings.Split(val, ",")
		case "capabilities":
			// "gpu" is always part of the requested capability set.
			req.Capabilities = [][]string{append(strings.Split(val, ","), "gpu")}
		case "options":
			r := csv.NewReader(strings.NewReader(val))
			optFields, optErr := r.Read()
			if optErr != nil {
				return fmt.Errorf("failed to read gpu options: %w", optErr)
			}
			req.Options = ConvertKVStringsToMap(optFields)
		default:
			return fmt.Errorf("unexpected key '%s' in '%s'", key, field)
		}
	}

	// Request a single GPU unless the caller gave a count or listed devices.
	if _, ok := seen["count"]; !ok && req.DeviceIDs == nil {
		req.Count = 1
	}
	if req.Options == nil {
		req.Options = make(map[string]string)
	}
	if req.Capabilities == nil {
		req.Capabilities = [][]string{{"gpu"}}
	}

	o.values = append(o.values, req)
	return nil
}
// Type returns the name of this option's value type, which is always the
// constant string "gpu-request". It does not read the receiver.
func (*GpuOpts) Type() string {
return "gpu-request"
}
// String formats every accumulated device request with the default %v verb
// and joins the results with ", ". It returns the empty string when no
// request has been stored.
func (o *GpuOpts) String() string {
	var out strings.Builder
	for i, req := range o.values {
		if i > 0 {
			out.WriteString(", ")
		}
		fmt.Fprintf(&out, "%v", req)
	}
	return out.String()
}
// Value returns the device requests accumulated by Set. The returned slice is
// the option's own backing slice, not a copy.
func (o *GpuOpts) Value() []container.DeviceRequest {
return o.values
}