Fri, 05 Feb 2021 12:16:29 +0200
update
#!/usr/bin/env python3
import sys
import sqlalchemy
import sqlalchemy.orm
import datetime

from datamodel import *

# GTFS route_type codes and their human-readable names.
ROUTE_TYPES = {
    '0': 'tram',
    '1': 'subway',
    '2': 'rail',
    '3': 'bus',
    '4': 'ferry',
    '5': 'cable-tram',
    '6': 'aerial-lift',
    '7': 'funicular',
    '11': 'trolleybus',
    '12': 'monorail',
}

def read_csv(file):
    # Yield each row of a GTFS CSV file as a dict keyed by the header row,
    # with any UTF-8 BOM and surrounding whitespace stripped from the keys.
    import csv
    reader = csv.reader(file)
    keys = next(reader)
    for i in range(len(keys)):
        keys[i] = keys[i].replace('\ufeff', '').strip()
    for row in reader:
        yield dict(zip(keys, row))

def load_gtfs_routes(gtfs_zip):
    with gtfs_zip.open('routes.txt') as file:
        for row in read_csv(map(bytes.decode, file)):
            route = GtfsRoute(
                id = row['route_id'],
                reference = row['route_short_name'],
                description = row['route_long_name'],
                type = int(row['route_type']),
            )
            yield route.id, route

def load_shapes(gtfs_zip):
    # Accumulate each shape's coordinate list (as a space-separated string)
    # and its total length from shapes.txt.
    shapes = dict()
    with gtfs_zip.open('shapes.txt') as file:
        for row in read_csv(map(bytes.decode, file)):
            shape_id = row['shape_id']
            if shape_id not in shapes:
                shapes[shape_id] = GtfsShape(
                    id = shape_id,
                    shape_coordinates = '',
                    length = 0,
                )
            shape = shapes[shape_id]
            if len(shape.shape_coordinates) > 0:
                shape.shape_coordinates += ' '
            shape.shape_coordinates += '{shape_pt_lat} {shape_pt_lon}'.format(**row)
            shape.length = max(shape.length, float(row['shape_dist_traveled']))
    return shapes.values()

def trip_length(trip, *, shapes):
    # Note: relies on the module-level `profile` config set up in __main__.
    if trip.shape_id:
        return shapes[trip.shape_id].length * float(profile['metrics']['shape-modifier'])
    else:
        return 0

def load_trips(gtfs_zip):
    # Yield a GtfsService the first time each service_id is seen,
    # followed by a GtfsTrip for every row in trips.txt.
    services = set()
    with gtfs_zip.open('trips.txt') as file:
        for row in read_csv(map(bytes.decode, file)):
            if row['service_id'] not in services:
                services.add(row['service_id'])
                yield GtfsService(id = row['service_id'])
            yield GtfsTrip(
                id = row['trip_id'],
                route_id = row['route_id'],
                service = row['service_id'],
                shape_id = row.get('shape_id'),
            )

def load_stops(gtfs_zip):
    with gtfs_zip.open('stops.txt') as file:
        for row in read_csv(map(bytes.decode, file)):
            lat = float(row['stop_lat'])
            lon = float(row['stop_lon'])
            yield GtfsStop(
                stop_id = row['stop_id'],
                stop_name = row['stop_name'],
                stop_latitude = lat,
                stop_longitude = lon,
            )

def parse_time(timetext):
    # GTFS times may exceed 23:59:59, so use a timedelta instead of a time.
    hour, minute, second = map(int, timetext.split(':'))
    return datetime.timedelta(hours = hour, minutes = minute, seconds = second)

def load_stop_times(gtfs_zip):
    with gtfs_zip.open('stop_times.txt') as file:
        for row in read_csv(map(bytes.decode, file)):
            yield GtfsStopTime(
                trip_id = row['trip_id'],
                stop_id = row['stop_id'],
                arrival_time = parse_time(row['arrival_time']),
                departure_time = parse_time(row['departure_time']),
                stop_sequence = int(row['stop_sequence']),
                shape_distance_traveled = float(row['shape_dist_traveled']),
            )

def gtfs_stop_spatial_testing(session, regions):
    print('Finding out in which regions bus stops are...')
    from compute_regions import RegionTester
    regiontester = RegionTester(regions)
    for bus_stop in session.query(GtfsStop):
        classification = regiontester(
            latitude = bus_stop.stop_latitude,
            longitude = bus_stop.stop_longitude,
        )
        if classification:
            bus_stop.stop_region = classification.region
            bus_stop.stop_region_major = classification.region_class == 'major'

def load_with_loading_text(fn, what, device):
    print(
        'Loading {}s... '.format(what),
        file = device,
        end = '',
        flush = True,
    )
    result = fn()
    print(
        '{n} {what}s'.format(
            n = len(result if not isinstance(result, tuple) else result[0]),
            what = what,
        ),
        file = device,
    )
    return result

def load_gtfs(
    gtfs_zip_path,
    *,
    profile,
    session,
    device = sys.stderr,
):
    from zipfile import ZipFile
    with ZipFile(gtfs_zip_path) as gtfs_zip:
        print('Loading routes...')
        for route_id, route in load_gtfs_routes(gtfs_zip):
            session.add(route)

        print('Loading stops...')
        for stop in load_stops(gtfs_zip):
            session.add(stop)
        session.commit()

        print('Loading shapes...')
        for shape in load_shapes(gtfs_zip):
            session.add(shape)
        session.commit()

        print('Loading trips...')
        for trip_or_service in load_trips(gtfs_zip):
            session.add(trip_or_service)
        session.commit()

        print('Loading stop times...')
        for i, stop_time in enumerate(load_stop_times(gtfs_zip)):
            if (i & 0xffff) == 0:
                # commit every now and then to keep RAM usage under control
                session.commit()
            session.add(stop_time)
        session.commit()

def parse_yesno(value):
    return value and value != 'no'

def regions_to_db(regions):
    # Map the OSM-style region tags (name:fi, short_name:sv, ...) onto the
    # GtfsRegion columns (region_name_fi, region_short_name_sv, ...).
    from itertools import product
    for region in regions.values():
        names = dict()
        for prefix, language in product(
            ['', 'short_', 'internal_'],
            ['fi', 'sv', 'en', 'ja'],
        ):
            key = 'region_' + prefix + 'name_' + language
            value = region.get(prefix + 'name:' + language)
            names[key] = value
        yield GtfsRegion(
            **names,
            ref = region['ref'],
            municipality = region.get('municipality'),
            external = parse_yesno(region.get('external')),
        )

def get_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('profile')
    parser.add_argument('gtfs')
    parser.add_argument('--process-only', action = 'store_true')
    return parser.parse_args()

if __name__ == '__main__':
    from configparser import ConfigParser
    from regions import parse_regions

    args = get_args()
    profile = ConfigParser()
    profile.read(args.profile)

    engine = sqlalchemy.create_engine('sqlite:///gtfs.db')
    GtfsBase.metadata.create_all(engine)
    session = sqlalchemy.orm.sessionmaker(bind = engine)()

    regions = parse_regions('föli.osm')

    if not args.process_only:
        for region in regions_to_db(regions):
            session.add(region)
        session.commit()
        load_gtfs(args.gtfs, profile = profile, session = session)

    gtfs_stop_spatial_testing(session = session, regions = regions)
    session.commit()
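
A quick sanity check one might run after the import. This is a minimal sketch, not part of the loader: it assumes the script above has already populated gtfs.db (e.g. via something like `python3 load_gtfs.py profile.ini gtfs.zip`, where the script and file names are placeholders) and that datamodel exposes GtfsStop as an SQLAlchemy model, as the star import above suggests.

#!/usr/bin/env python3
# Hedged example: count stops per region in the gtfs.db produced above
# to verify that the spatial classification step assigned regions.
import sqlalchemy
import sqlalchemy.orm

from datamodel import GtfsStop  # assumption: datamodel exports GtfsStop

engine = sqlalchemy.create_engine('sqlite:///gtfs.db')
session = sqlalchemy.orm.sessionmaker(bind = engine)()

for region, count in (
        session.query(GtfsStop.stop_region, sqlalchemy.func.count())
        .group_by(GtfsStop.stop_region)
):
    print(region, count)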