Fri, 05 Feb 2021 12:16:29 +0200
update
| 1 | 1 | #!/usr/bin/env python3 |
| 2 | import io | |
| 3 | import sys | |
| 4 | import sqlalchemy | |
| 5 | import sqlalchemy.orm | |
| 2 | 6 | import datetime |
| 1 | 7 | from datamodel import * |
| 8 | ||
# Mapping of GTFS `route_type` codes to short transport-mode names.
# Keys are the raw string codes (not ints); insertion order is kept as listed.
ROUTE_TYPES = dict([
    ('0', 'tram'),
    ('1', 'subway'),
    ('2', 'rail'),
    ('3', 'bus'),
    ('4', 'ferry'),
    ('5', 'cable-tram'),
    ('6', 'aerial-lift'),
    ('7', 'funicular'),
    ('11', 'trolleybus'),
    ('12', 'monorail'),
])
| 21 | ||
def read_csv(file):
    """Yield each data row of a CSV stream as a dict keyed by the header row.

    *file* is any iterable of text lines accepted by ``csv.reader``.
    Header names are cleaned of byte-order marks (U+FEFF) and surrounding
    whitespace; values are left as-is. An empty stream yields nothing.
    """
    import csv
    reader = csv.reader(file)
    header = next(reader, None)
    if header is None:
        # Empty input: no header means no rows. A bare next() here would let
        # StopIteration escape the generator, which Python turns into a
        # RuntimeError (PEP 479).
        return
    keys = [key.replace('\ufeff', '').strip() for key in header]
    for row in reader:
        yield dict(zip(keys, row))
| 30 | ||
def load_gtfs_routes(gtfs_zip):
    """Yield ``(route_id, GtfsRoute)`` pairs, one per row of routes.txt.

    *gtfs_zip* is an open zip archive; routes.txt is decoded line by line
    with ``bytes.decode`` and parsed by ``read_csv``. ``route_type`` is
    stored as an int.
    """
    with gtfs_zip.open('routes.txt') as handle:
        lines = map(bytes.decode, handle)
        for record in read_csv(lines):
            gtfs_route = GtfsRoute(
                id=record['route_id'],
                reference=record['route_short_name'],
                description=record['route_long_name'],
                type=int(record['route_type']),
            )
            yield gtfs_route.id, gtfs_route
| 41 | ||
def load_shapes(gtfs_zip):
    """Build one GtfsShape per ``shape_id`` found in shapes.txt.

    ``shape_coordinates`` is a space-separated string of ``"<lat> <lon>"``
    pairs (raw text from the file). Points are ordered by
    ``shape_pt_sequence``, because shapes.txt rows are not required to be
    stored in order; rows with equal sequence keep their file order.
    ``length`` is the largest ``shape_dist_traveled`` seen for the shape,
    or 0 when that field is absent or empty.

    Returns the GtfsShape objects as a ``dict.values()`` view, in order of
    each shape's first appearance in the file.
    """
    points = {}   # shape_id -> list of (sequence, "lat lon")
    lengths = {}  # shape_id -> max shape_dist_traveled seen so far
    with gtfs_zip.open('shapes.txt') as file:
        for row in read_csv(map(bytes.decode, file)):
            shape_id = row['shape_id']
            points.setdefault(shape_id, []).append((
                int(row['shape_pt_sequence']),
                str.format('{shape_pt_lat} {shape_pt_lon}', **row),
            ))
            lengths.setdefault(shape_id, 0)
            # shape_dist_traveled is optional; skip missing/blank values
            # instead of crashing in float().
            distance = (row.get('shape_dist_traveled') or '').strip()
            if distance:
                lengths[shape_id] = max(lengths[shape_id], float(distance))
    shapes = {}
    for shape_id, shape_points in points.items():
        # list.sort is stable, so equal sequence numbers keep file order.
        shape_points.sort(key=lambda point: point[0])
        shapes[shape_id] = GtfsShape(
            id=shape_id,
            # Join once instead of repeated += on an ORM attribute (quadratic).
            shape_coordinates=' '.join(coords for _, coords in shape_points),
            length=lengths[shape_id],
        )
    return shapes.values()
| 63 | ||
def trip_length(trip, *, shapes, shape_modifier = None):
    """Return a trip's shape length multiplied by a shape modifier.

    trip: object with a ``shape_id`` attribute (e.g. a GtfsTrip).
    shapes: mapping from shape id to an object with a ``length`` attribute.
    shape_modifier: multiplier (number or numeric string) applied to the
        length. When omitted, it is read from the module-level ``profile``
        config (``[metrics] shape-modifier``), which is only defined when this
        file is run as a script.

    Returns 0 for trips with no (or an empty) shape_id.
    Raises KeyError if ``trip.shape_id`` is not present in *shapes*.
    """
    if not trip.shape_id:
        return 0
    if shape_modifier is None:
        # Backward-compatible fallback to the script-level config global.
        shape_modifier = profile['metrics']['shape-modifier']
    return shapes[trip.shape_id].length * float(shape_modifier)
| 69 | ||
def load_trips(gtfs_zip):
    """Yield GtfsService and GtfsTrip objects built from trips.txt.

    The first time a ``service_id`` appears, a GtfsService for it is yielded
    immediately before the trip that references it. Every row then yields
    one GtfsTrip. ``shape_id`` is None when the column is absent.
    """
    seen_services = set()
    with gtfs_zip.open('trips.txt') as file:
        for row in read_csv(map(bytes.decode, file)):
            service_id = row['service_id']
            if service_id not in seen_services:
                seen_services.add(service_id)
                yield GtfsService(id=service_id)
            yield GtfsTrip(
                id=row['trip_id'],
                route_id=row['route_id'],
                service=service_id,
                shape_id=row.get('shape_id'),
            )
| 83 | ||
def load_stops(gtfs_zip):
    """Yield a GtfsStop for each row of stops.txt.

    ``stop_lat`` / ``stop_lon`` are converted with ``float()``; a missing or
    non-numeric value raises (KeyError / ValueError).
    """
    with gtfs_zip.open('stops.txt') as file:
        for row in read_csv(map(bytes.decode, file)):
            lat = float(row['stop_lat'])
            # Previously computed but unused; stop_lon was parsed a second time.
            lon = float(row['stop_lon'])
            yield GtfsStop(
                stop_id = row['stop_id'],
                stop_name = row['stop_name'],
                stop_latitude = lat,
                stop_longitude = lon,
            )
| 95 | ||
def parse_time(timetext):
    """Convert an ``H:MM:SS`` string into a ``datetime.timedelta``.

    Hours are not capped at 23, so values such as ``'25:10:00'`` yield a
    timedelta longer than one day. Raises ValueError unless the text has
    exactly three colon-separated integer fields.
    """
    hours, minutes, seconds = (int(field) for field in timetext.split(':'))
    total_seconds = hours * 3600 + minutes * 60 + seconds
    return datetime.timedelta(seconds = total_seconds)
| 99 | ||
def load_stop_times(gtfs_zip):
    """Yield a GtfsStopTime for each row of stop_times.txt.

    Arrival/departure times are parsed with ``parse_time`` (to timedelta);
    ``stop_sequence`` becomes an int and ``shape_dist_traveled`` a float.
    The archive member stays open until the generator is exhausted or closed.
    """
    with gtfs_zip.open('stop_times.txt') as file:
        rows = read_csv(map(bytes.decode, file))
        yield from (
            GtfsStopTime(
                trip_id=row['trip_id'],
                stop_id=row['stop_id'],
                arrival_time=parse_time(row['arrival_time']),
                departure_time=parse_time(row['departure_time']),
                stop_sequence=int(row['stop_sequence']),
                shape_distance_traveled=float(row['shape_dist_traveled']),
            )
            for row in rows
        )
| 111 | ||
def gtfs_stop_spatial_testing(session, regions):
    """Tag every GtfsStop in *session* with the region it lies in.

    A ``compute_regions.RegionTester`` built from *regions* is called with
    each stop's coordinates. When it returns a truthy result, the stop's
    ``stop_region`` and ``stop_region_major`` (``region_class == 'major'``)
    are set; otherwise the stop is left untouched. Nothing is committed here.
    """
    print('Finding out in which regions bus stops are...')
    from compute_regions import RegionTester
    tester = RegionTester(regions)
    for stop in session.query(GtfsStop):
        found = tester(latitude = stop.stop_latitude, longitude = stop.stop_longitude)
        if not found:
            continue
        stop.stop_region = found.region
        stop.stop_region_major = (found.region_class == 'major')
| 124 | ||
def load_with_loading_text(fn, what, device):
    """Call *fn* while printing a "Loading <what>s... N <what>s" line.

    The prefix is printed (and flushed) to *device* before *fn* runs; the
    count is appended afterwards. If *fn* returns exactly a ``tuple``, the
    count is ``len`` of its first element; otherwise ``len`` of the result.
    Returns whatever *fn* returned.
    """
    print(f'Loading {what}s... ', file=device, end='', flush=True)
    result = fn()
    counted = result[0] if type(result) is tuple else result
    print(f'{len(counted)} {what}s', file=device)
    return result
| 142 | ||
def load_gtfs(
    gtfs_zip_path,
    *,
    profile,
    session,
    device = sys.stderr
):
    """Import a GTFS zip archive into *session*.

    Stages, in order: routes and stops (committed together), shapes
    (committed), trips with their services (committed), then stop times,
    which are also committed every 65536 rows to keep RAM usage bounded.

    gtfs_zip_path: anything ``zipfile.ZipFile`` accepts (path or file object).
    profile: accepted for interface compatibility; not used here.
    session: SQLAlchemy session that receives the new objects.
    device: stream for progress messages (default: stderr at import time).
    """
    from zipfile import ZipFile

    def progress(message):
        # Honour the device parameter; these messages previously went to stdout.
        print(message, file = device, flush = True)

    with ZipFile(gtfs_zip_path) as gtfs_zip:
        progress('Loading routes...')
        for _route_id, route in load_gtfs_routes(gtfs_zip):
            session.add(route)
        progress('Loading stops...')
        for stop in load_stops(gtfs_zip):
            session.add(stop)
        session.commit()
        progress('Loading shapes...')
        for shape in load_shapes(gtfs_zip):
            session.add(shape)
        session.commit()
        progress('Loading trips...')
        for trip_or_service in load_trips(gtfs_zip):
            session.add(trip_or_service)
        session.commit()
        progress('Loading stop times...')
        for i, stop_time in enumerate(load_stop_times(gtfs_zip)):
            if (i & 0xffff) == 0:
                # commit every now and then to keep RAM usage under control
                session.commit()
            session.add(stop_time)
        session.commit()
| 1 | 174 | |
def parse_yesno(value):
    """Interpret a yes/no style tag value as a bool.

    Any non-empty value other than ``'no'`` counts as true. ``None`` or an
    empty string counts as false. The original returned those falsy inputs
    unchanged (``None`` / ``''``) rather than a bool.
    """
    return bool(value) and value != 'no'
| 177 | ||
def regions_to_db(regions):
    """Yield one GtfsRegion per value of the *regions* mapping.

    For every combination of name prefix ('', 'short_', 'internal_') and
    language ('fi', 'sv', 'en', 'ja'), the region's ``<prefix>name:<lang>``
    entry is copied into the ``region_<prefix>name_<lang>`` field (None if
    absent). ``ref`` is required (KeyError otherwise); ``municipality`` is
    optional and ``external`` is passed through ``parse_yesno``.
    """
    from itertools import product
    prefixes = ('', 'short_', 'internal_')
    languages = ('fi', 'sv', 'en', 'ja')
    for region in regions.values():
        names = {
            f'region_{prefix}name_{language}': dict.get(region, f'{prefix}name:{language}')
            for prefix, language in product(prefixes, languages)
        }
        yield GtfsRegion(
            **names,
            ref=region['ref'],
            municipality=dict.get(region, 'municipality'),
            external=parse_yesno(dict.get(region, 'external')),
        )
| 195 | ||
def get_args():
    """Parse the command line from ``sys.argv``.

    Positional arguments: ``profile`` (config file path) and ``gtfs``
    (GTFS zip path). Flag ``--process-only`` sets ``process_only`` to True.
    """
    from argparse import ArgumentParser
    cli = ArgumentParser()
    for positional in ('profile', 'gtfs'):
        cli.add_argument(positional)
    cli.add_argument('--process-only', action='store_true')
    return cli.parse_args()
| 203 | ||
if __name__ == '__main__':
    # Script entry point: (re)build gtfs.db from a GTFS zip and the region file.
    from configparser import ConfigParser
    from regions import parse_regions

    args = get_args()
    # Kept as the module-level name `profile`: trip_length() reads it.
    profile = ConfigParser()
    profile.read(args.profile)

    # Relative SQLite URL, so gtfs.db lives in the current working directory;
    # create_all only creates tables that are missing.
    engine = sqlalchemy.create_engine('sqlite:///gtfs.db')
    GtfsBase.metadata.create_all(engine)
    make_session = sqlalchemy.orm.sessionmaker(bind = engine)
    session = make_session()
    regions = parse_regions('föli.osm')

    if not args.process_only:
        # Import stage (skipped with --process-only): regions, then the feed.
        for region_row in regions_to_db(regions):
            session.add(region_row)
        session.commit()
        buses = load_gtfs(args.gtfs, profile = profile, session = session)

    # Post-processing always runs: assign stops to regions and persist.
    gtfs_stop_spatial_testing(session = session, regions = regions)
    session.commit()