Fri, 05 Feb 2021 12:16:29 +0200
update
1 | 1 | #!/usr/bin/env python3 |
2 | import io | |
3 | import sys | |
4 | import sqlalchemy | |
5 | import sqlalchemy.orm | |
2 | 6 | import datetime |
1 | 7 | from datamodel import * |
8 | ||
# Human-readable transport-mode names for the GTFS `route_type` codes this
# script knows about. The keys are *strings* (the raw text from routes.txt),
# whereas load_gtfs_routes stores GtfsRoute.type as an int, so convert with
# str() before looking a stored route type up here.
ROUTE_TYPES = {
    '0':  'tram',
    '1':  'subway',
    '2':  'rail',
    '3':  'bus',
    '4':  'ferry',
    '5':  'cable-tram',
    '6':  'aerial-lift',
    '7':  'funicular',
    '11': 'trolleybus',
    '12': 'monorail',
}
21 | ||
def read_csv(file):
    """
    Parse CSV text into one dict per data row, keyed by the header row.

    `file` may be any iterable of text lines, e.g. a text file object or
    `map(bytes.decode, binary_file)`. Header names have every U+FEFF byte
    order mark removed and surrounding whitespace stripped; values are
    returned untouched. Rows shorter than the header simply lack the
    trailing keys (zip truncates).

    An input with no lines at all yields nothing.
    """
    import csv
    reader = csv.reader(file)
    header = next(reader, None)
    if header is None:
        # Empty input. Without this guard the StopIteration from next() would
        # escape the generator and be turned into a RuntimeError (PEP 479).
        return
    keys = [key.replace('\ufeff', '').strip() for key in header]
    for row in reader:
        yield dict(zip(keys, row))
30 | ||
def load_gtfs_routes(gtfs_zip):
    """
    Yield `(route_id, GtfsRoute)` pairs, one per row of routes.txt.

    `gtfs_zip` must provide `.open(name)` returning a binary file; load_gtfs
    passes a zipfile.ZipFile. Lines are decoded with bytes.decode (UTF-8).

    GTFS makes route_short_name and route_long_name each only conditionally
    required, so a feed may omit either column entirely; a missing column is
    stored as an empty string instead of raising KeyError.
    """
    with gtfs_zip.open('routes.txt') as file:
        for row in read_csv(map(bytes.decode, file)):
            route = GtfsRoute(
                id = row['route_id'],
                reference = row.get('route_short_name', ''),
                description = row.get('route_long_name', ''),
                type = int(row['route_type']),
            )
            yield route.id, route
41 | ||
def load_shapes(gtfs_zip):
    """
    Build one GtfsShape per distinct shape_id found in shapes.txt.

    Each shape gets two values:
      shape_coordinates: a space-separated string of "lat lon" pairs,
        ordered by shape_pt_sequence.
      length: the largest shape_dist_traveled seen for the shape, or 0 when
        that optional column is absent or blank.

    Returns a dict values view of the shapes, in order of first appearance
    of each shape_id.
    """
    from collections import defaultdict
    # shape_id -> list of (shape_pt_sequence, "lat lon")
    points = defaultdict(list)
    # shape_id -> maximum shape_dist_traveled seen
    lengths = dict()
    with gtfs_zip.open('shapes.txt') as file:
        for row in read_csv(map(bytes.decode, file)):
            shape_id = row['shape_id']
            points[shape_id].append((
                int(row['shape_pt_sequence']),
                row['shape_pt_lat'] + ' ' + row['shape_pt_lon'],
            ))
            # shape_dist_traveled is optional in GTFS; skip blanks instead of
            # failing on float('').
            distance = row.get('shape_dist_traveled', '').strip()
            if distance:
                lengths[shape_id] = max(lengths.get(shape_id, 0), float(distance))
    shapes = dict()
    for shape_id, shape_points in points.items():
        # GTFS does not require shape rows to be stored in sequence order.
        # The sort is stable, so already-ordered files are unchanged.
        shape_points.sort(key = lambda point: point[0])
        shapes[shape_id] = GtfsShape(
            id = shape_id,
            shape_coordinates = ' '.join(coords for _, coords in shape_points),
            length = lengths.get(shape_id, 0),
        )
    return shapes.values()
63 | ||
def trip_length(trip, *, shapes, profile = None):
    """
    Return the length of `trip`'s shape, scaled by the profile's shape modifier.

    Parameters:
      trip: object with a `shape_id` attribute. A falsy shape_id (None or '')
        means the trip has no shape, and the length is 0.
      shapes: mapping shape_id -> object with a numeric `length` attribute.
      profile: mapping supporting profile['metrics']['shape-modifier'], for
        example the ConfigParser built in __main__. When omitted, the
        module-level `profile` global is used, as earlier versions did.

    Raises:
      KeyError: if trip.shape_id is not a key of `shapes`.
      ValueError: if no profile was passed and no module-level one exists.
    """
    if not trip.shape_id:
        return 0
    if profile is None:
        # Backwards compatibility: this function used to read the global
        # `profile` that only exists when the script runs as __main__.
        profile = globals().get('profile')
        if profile is None:
            raise ValueError('trip_length() requires a profile')
    modifier = float(profile['metrics']['shape-modifier'])
    return shapes[trip.shape_id].length * modifier
69 | ||
def load_trips(gtfs_zip):
    """
    Stream ORM objects built from trips.txt.

    The first time a service_id is seen, a GtfsService for it is yielded
    just before the GtfsTrip that references it. Every row then yields one
    GtfsTrip.

    The trip's shape_id is None when the feed has no shape_id column, and
    may be '' when the column exists but is blank for that row.
    """
    seen_services = set()
    with gtfs_zip.open('trips.txt') as file:
        for row in read_csv(map(bytes.decode, file)):
            service_id = row['service_id']
            if service_id not in seen_services:
                seen_services.add(service_id)
                yield GtfsService(id = service_id)
            yield GtfsTrip(
                id = row['trip_id'],
                route_id = row['route_id'],
                service = service_id,
                shape_id = dict.get(row, 'shape_id'),
            )
83 | ||
def load_stops(gtfs_zip):
    """
    Yield one GtfsStop per row of stops.txt.

    stop_lat and stop_lon are converted to float, and each is parsed exactly
    once. The previous version computed an unused `lon` local and then
    parsed stop_lon a second time.
    """
    with gtfs_zip.open('stops.txt') as file:
        for row in read_csv(map(bytes.decode, file)):
            yield GtfsStop(
                stop_id = row['stop_id'],
                stop_name = row['stop_name'],
                stop_latitude = float(row['stop_lat']),
                stop_longitude = float(row['stop_lon']),
            )
95 | ||
def parse_time(timetext):
    """
    Convert an "H:MM:SS" time string into a datetime.timedelta.

    The hour field is not limited to 0-23: "25:10:00" becomes one day plus
    1 h 10 min. Surrounding whitespace is ignored.

    Returns None for an empty or blank string, because GTFS allows
    arrival/departure times to be left out for some stop_times rows.
    Raises ValueError for any other string that does not have exactly
    three integer fields.
    """
    timetext = timetext.strip()
    if not timetext:
        return None
    hour, minute, second = map(int, timetext.split(':'))
    return datetime.timedelta(hours = hour, minutes = minute, seconds = second)
99 | ||
def load_stop_times(gtfs_zip):
    """
    Yield one GtfsStopTime per row of stop_times.txt.

    Times are converted with parse_time and stop_sequence with int().
    shape_dist_traveled is optional in GTFS: when the column is missing or
    blank, shape_distance_traveled is stored as None rather than crashing
    on float(''). This assumes that column is nullable — TODO confirm
    against datamodel.
    """
    with gtfs_zip.open('stop_times.txt') as file:
        for row in read_csv(map(bytes.decode, file)):
            distance = row.get('shape_dist_traveled', '').strip()
            yield GtfsStopTime(
                trip_id = row['trip_id'],
                stop_id = row['stop_id'],
                arrival_time = parse_time(row['arrival_time']),
                departure_time = parse_time(row['departure_time']),
                stop_sequence = int(row['stop_sequence']),
                shape_distance_traveled = float(distance) if distance else None,
            )
111 | ||
def gtfs_stop_spatial_testing(session, regions):
    """
    Tag every GtfsStop in `session` with the region containing its coordinates.

    compute_regions.RegionTester, built from `regions`, is called with each
    stop's latitude and longitude. When it returns a truthy result, the stop
    gets two values:
      stop_region: the result's `region`.
      stop_region_major: True exactly when the result's region_class is
        'major'.
    Stops with no match are left unchanged. Nothing is committed here; the
    caller commits the session.
    """
    print('Finding out in which regions bus stops are...')
    from compute_regions import RegionTester
    tester = RegionTester(regions)
    for stop in session.query(GtfsStop):
        found = tester(
            latitude = stop.stop_latitude,
            longitude = stop.stop_longitude,
        )
        if not found:
            continue
        stop.stop_region = found.region
        stop.stop_region_major = (found.region_class == 'major')
124 | ||
def load_with_loading_text(fn, what, device):
    """
    Call `fn()` while writing a one-line progress message to `device`.

    Writes "Loading <what>s... " (flushed, no newline) before the call, and
    "<n> <what>s" afterwards. n is len(result), or len(result[0]) when the
    result is exactly a tuple. Returns fn()'s result unchanged.
    """
    print('Loading {}s... '.format(what), file = device, end = '', flush = True)
    result = fn()
    counted = result[0] if type(result) is tuple else result
    print('{n} {what}s'.format(n = len(counted), what = what), file = device)
    return result
142 | ||
def load_gtfs(
    gtfs_zip_path,
    *,
    profile,
    session,
    device = sys.stderr
):
    """
    Import a GTFS zip archive into the database through `session`.

    Loads, in order: routes, stops, shapes, trips (plus their services) and
    stop times. The session is committed after each stage. Stop times are
    also committed every 65536 rows to keep RAM usage under control.

    Parameters:
      gtfs_zip_path: anything zipfile.ZipFile accepts (path or file object).
      profile: accepted for interface compatibility; not read in this
        function.
      session: SQLAlchemy session that receives the new objects.
      device: text stream for progress messages (default: sys.stderr).
        Previously this parameter was ignored and messages went to stdout.
    """
    from zipfile import ZipFile

    def progress(message):
        # All progress output goes to the caller-chosen stream.
        print(message, file = device, flush = True)

    with ZipFile(gtfs_zip_path) as gtfs_zip:
        progress('Loading routes...')
        for _route_id, route in load_gtfs_routes(gtfs_zip):
            session.add(route)
        progress('Loading stops...')
        for stop in load_stops(gtfs_zip):
            session.add(stop)
        session.commit()
        progress('Loading shapes...')
        for shape in load_shapes(gtfs_zip):
            session.add(shape)
        session.commit()
        progress('Loading trips...')
        for trip_or_service in load_trips(gtfs_zip):
            session.add(trip_or_service)
        session.commit()
        progress('Loading stop times...')
        for count, stop_time in enumerate(load_stop_times(gtfs_zip), start = 1):
            session.add(stop_time)
            if count % 0x10000 == 0:
                # Commit every 65536 rows to keep RAM usage under control.
                session.commit()
        session.commit()
1 | 174 | |
def parse_yesno(value):
    """
    Interpret a yes/no tag value (e.g. a region's 'external' key) as a bool.

    Any non-empty value other than the literal 'no' counts as True. An empty
    string or None (tag absent) counts as False. The previous version
    returned the falsy input itself ('' or None) rather than a bool.
    """
    return bool(value) and value != 'no'
177 | ||
def regions_to_db(regions):
    """
    Yield a GtfsRegion for every region dict in `regions.values()`.

    For each prefix ('', 'short_', 'internal_') and each language
    ('fi', 'sv', 'en', 'ja'), the tag '<prefix>name:<lang>' is copied into
    the column 'region_<prefix>name_<lang>', or None when the tag is absent.
    'ref' is required (KeyError otherwise). 'municipality' is optional.
    'external' is passed through parse_yesno.
    """
    from itertools import product
    prefixes = ['', 'short_', 'internal_']
    languages = ['fi', 'sv', 'en', 'ja']
    for region in regions.values():
        names = {
            'region_' + prefix + 'name_' + language:
                dict.get(region, prefix + 'name:' + language)
            for prefix, language in product(prefixes, languages)
        }
        yield GtfsRegion(
            **names,
            ref = region['ref'],
            municipality = dict.get(region, 'municipality'),
            external = parse_yesno(dict.get(region, 'external')),
        )
195 | ||
def get_args():
    """
    Parse sys.argv and return the resulting argparse.Namespace.

    Positional arguments: profile (config file path) and gtfs (GTFS zip
    path). The --process-only flag, stored as `process_only`, defaults to
    False.
    """
    import argparse
    parser = argparse.ArgumentParser()
    for positional in ('profile', 'gtfs'):
        parser.add_argument(positional)
    parser.add_argument('--process-only', action = 'store_true')
    arguments = parser.parse_args()
    return arguments
203 | ||
if __name__ == '__main__':
    # Script entry point: build gtfs.db from a profile and a GTFS zip.
    from configparser import ConfigParser
    from regions import parse_regions
    args = get_args()
    # `profile`, `session` and friends are module-level globals here;
    # trip_length() refers to a module-level `profile`, so keep this name.
    profile = ConfigParser()
    profile.read(args.profile)
    engine = sqlalchemy.create_engine('sqlite:///gtfs.db')
    GtfsBase.metadata.create_all(engine)
    session = sqlalchemy.orm.sessionmaker(bind = engine)()
    # NOTE(review): the OSM region file name is hard-coded.
    regions = parse_regions('föli.osm')
    if not args.process_only:
        # Full import: store the regions, then the GTFS feed itself.
        # load_gtfs returns nothing, so its result is no longer bound to the
        # unused name `buses`.
        for region in regions_to_db(regions):
            session.add(region)
        session.commit()
        load_gtfs(args.gtfs, profile = profile, session = session)
    # Always (re)compute which region each stop lies in.
    gtfs_stop_spatial_testing(session = session, regions = regions)
    session.commit()