gtfsc.py

Fri, 05 Feb 2021 12:16:29 +0200

author
Teemu Piippo <teemu@hecknology.net>
date
Fri, 05 Feb 2021 12:16:29 +0200
changeset 4
ac067a42b00f
parent 2
7378b802ddf8
permissions
-rwxr-xr-x

update

#!/usr/bin/env python3
import io
import sys
import sqlalchemy
import sqlalchemy.orm
import datetime
from datamodel import *

# Maps route type codes to short names (presumably the numeric codes used in
# the route_type column of routes.txt -- confirm against the GTFS reference).
# The keys are strings. NOTE(review): load_gtfs_routes() stores route_type as
# an int, so a lookup in this table needs str(route.type).
ROUTE_TYPES = {
	'0': 'tram',
	'1': 'subway',
	'2': 'rail',
	'3': 'bus',
	'4': 'ferry',
	'5': 'cable-tram',
	'6': 'aerial-lift',
	'7': 'funicular',
	'11': 'trolleybus',
	'12': 'monorail',
}

def read_csv(file):
	"""
	Iterate over the rows of a CSV source as dictionaries.

	The first row is used as the header. Byte order marks and surrounding
	whitespace are removed from the column names. Every following row is
	yielded as a dict mapping column name to the cell text. Rows shorter
	than the header simply lack the trailing keys (zip truncates).

	file: any iterable of text lines accepted by csv.reader.

	If the source is completely empty (no header row) nothing is yielded.
	"""
	import csv
	reader = csv.reader(file)
	try:
		header = next(reader)
	except StopIteration:
		# No header row at all. Without this guard the StopIteration would
		# escape the generator body and be turned into a RuntimeError.
		return
	keys = [key.replace('\ufeff', '').strip() for key in header]
	for row in reader:
		yield dict(zip(keys, row))

def load_gtfs_routes(gtfs_zip):
	"""
	Read routes.txt from an opened GTFS zip archive.

	Yields (route id, GtfsRoute) pairs, one per row. The route_type column
	is converted to an int; the other columns are stored as text.
	"""
	with gtfs_zip.open('routes.txt') as routes_file:
		# the archive member yields bytes lines; the CSV reader wants text
		text_lines = map(bytes.decode, routes_file)
		for entry in read_csv(text_lines):
			fields = {
				'id': entry['route_id'],
				'reference': entry['route_short_name'],
				'description': entry['route_long_name'],
				'type': int(entry['route_type']),
			}
			new_route = GtfsRoute(**fields)
			yield new_route.id, new_route

def load_shapes(gtfs_zip):
	"""
	Read shapes.txt from an opened GTFS zip archive.

	All points sharing a shape_id are combined into one GtfsShape whose
	shape_coordinates is a space separated 'lat lon lat lon ...' string and
	whose length is the largest shape_dist_traveled value seen for that
	shape (0 if the feed gives no distances).

	Points are appended in the order the rows appear in the file; they are
	not sorted here.

	Returns a view of the GtfsShape objects, in order of first appearance.
	"""
	# Collect the pieces per shape first and join once at the end, instead of
	# growing one long string by repeated concatenation.
	points_by_shape = {}
	length_by_shape = {}
	with gtfs_zip.open('shapes.txt') as file:
		for row in read_csv(map(bytes.decode, file)):
			shape_id = row['shape_id']
			points_by_shape.setdefault(shape_id, []).append(
				row['shape_pt_lat'] + ' ' + row['shape_pt_lon']
			)
			length_by_shape.setdefault(shape_id, 0)
			# The distance column may be absent or blank in a feed; such
			# rows then do not contribute to the length.
			distance = row.get('shape_dist_traveled')
			if distance:
				length_by_shape[shape_id] = max(
					length_by_shape[shape_id],
					float(distance),
				)
	shapes = {
		shape_id: GtfsShape(
			id = shape_id,
			shape_coordinates = ' '.join(points),
			length = length_by_shape[shape_id],
		)
		for shape_id, points in points_by_shape.items()
	}
	return shapes.values()

def trip_length(trip, *, shapes):
	"""
	Return the length of the trip's shape multiplied by the profile's
	[metrics] shape-modifier value.

	trip: object with a shape_id attribute
	shapes: mapping from shape id to an object with a numeric length

	Returns 0 when the trip has no shape id, and also when its shape id is
	not present in `shapes` (this used to fail with AttributeError on None).

	NOTE: this reads the module-level name `profile`, which is only assigned
	when the file is run as a script.
	"""
	if not trip.shape_id:
		return 0
	shape = shapes.get(trip.shape_id)
	if shape is None:
		# trip refers to a shape that was not loaded
		return 0
	return shape.length * float(profile['metrics']['shape-modifier'])

def load_trips(gtfs_zip):
	"""
	Read trips.txt from an opened GTFS zip archive.

	Yields a GtfsTrip for every row. The first time a service_id is seen, a
	GtfsService with that id is yielded just before the trip that uses it.
	shape_id is None when the file has no such column.
	"""
	seen_service_ids = set()
	with gtfs_zip.open('trips.txt') as trips_file:
		for entry in read_csv(map(bytes.decode, trips_file)):
			service_id = entry['service_id']
			if service_id not in seen_service_ids:
				seen_service_ids.add(service_id)
				yield GtfsService(id = service_id)
			yield GtfsTrip(
				id = entry['trip_id'],
				route_id = entry['route_id'],
				service = service_id,
				shape_id = entry.get('shape_id'),
			)

def load_stops(gtfs_zip):
	"""
	Read stops.txt from an opened GTFS zip archive.

	Yields one GtfsStop per row, with stop_lat and stop_lon converted to
	floats. Raises ValueError if a coordinate is not a valid number.
	"""
	with gtfs_zip.open('stops.txt') as file:
		for row in read_csv(map(bytes.decode, file)):
			# parse each coordinate exactly once
			lat = float(row['stop_lat'])
			lon = float(row['stop_lon'])
			yield GtfsStop(
				stop_id = row['stop_id'],
				stop_name = row['stop_name'],
				stop_latitude = lat,
				stop_longitude = lon,
			)

def parse_time(timetext):
	"""
	Convert a colon separated 'hours:minutes:seconds' text into a
	datetime.timedelta.

	The hour part is not limited to 0-23, so times past midnight such as
	'25:10:00' are accepted. Raises ValueError if the text does not consist
	of exactly three integer fields.
	"""
	fields = timetext.split(':')
	hours, minutes, seconds = (int(field) for field in fields)
	return datetime.timedelta(
		hours = hours,
		minutes = minutes,
		seconds = seconds,
	)

def load_stop_times(gtfs_zip):
	"""
	Read stop_times.txt from an opened GTFS zip archive.

	Yields one GtfsStopTime per row. Arrival and departure times go through
	parse_time(), stop_sequence is converted to int and shape_dist_traveled
	to float, so every one of those columns must be present and filled in.
	"""
	with gtfs_zip.open('stop_times.txt') as stop_times_file:
		text_lines = map(bytes.decode, stop_times_file)
		for entry in read_csv(text_lines):
			fields = {
				'trip_id': entry['trip_id'],
				'stop_id': entry['stop_id'],
				'arrival_time': parse_time(entry['arrival_time']),
				'departure_time': parse_time(entry['departure_time']),
				'stop_sequence': int(entry['stop_sequence']),
				'shape_distance_traveled': float(entry['shape_dist_traveled']),
			}
			yield GtfsStopTime(**fields)

def gtfs_stop_spatial_testing(session, regions):
	"""
	Assign a region to every GtfsStop found through the session.

	Each stop's coordinates are passed to a RegionTester built from
	`regions`. When the tester returns a truthy result, the stop's
	stop_region and stop_region_major attributes are updated from it;
	otherwise the stop is left untouched. Nothing is committed here.
	"""
	print('Finding out in which regions bus stops are...')
	from compute_regions import RegionTester
	locate = RegionTester(regions)
	for stop in session.query(GtfsStop):
		match = locate(
			latitude = stop.stop_latitude,
			longitude = stop.stop_longitude,
		)
		if not match:
			continue
		stop.stop_region = match.region
		stop.stop_region_major = match.region_class == 'major'

def load_with_loading_text(fn, what, device):
	"""
	Call fn() while reporting progress on `device`.

	Writes 'Loading <what>s... ' before the call and '<n> <what>s' after
	it, where n is the length of the result, or of its first element if the
	result is exactly a tuple. The result of fn() is returned unchanged.
	"""
	print('Loading {}s... '.format(what), file = device, end = '', flush = True)
	result = fn()
	# for a tuple result only the first member is counted
	counted = result[0] if type(result) is tuple else result
	summary = '{n} {what}s'.format(n = len(counted), what = what)
	print(summary, file = device)
	return result

def load_gtfs(
	gtfs_zip_path,
	*,
	profile,
	session,
	device = sys.stderr
):
	"""
	Load a GTFS zip archive into the database session.

	Routes, stops, shapes, trips (with their services) and stop times are
	read in that order and added to `session`, committing between stages.

	gtfs_zip_path: path of the GTFS zip file
	profile: accepted for the caller's convenience; not used by this function
	session: database session that receives the objects
	device: text stream for the progress messages

	Returns nothing.
	"""
	from zipfile import ZipFile
	with ZipFile(gtfs_zip_path) as gtfs_zip:
		# Progress messages go to `device`, which the signature already
		# offered but the function used to ignore.
		print('Loading routes...', file = device)
		for _, route in load_gtfs_routes(gtfs_zip):
			session.add(route)
		print('Loading stops...', file = device)
		for stop in load_stops(gtfs_zip):
			session.add(stop)
		session.commit()
		print('Loading shapes...', file = device)
		for shape in load_shapes(gtfs_zip):
			session.add(shape)
		session.commit()
		print('Loading trips...', file = device)
		for trip_or_service in load_trips(gtfs_zip):
			session.add(trip_or_service)
		session.commit()
		print('Loading stop times...', file = device)
		for i, stop_time in enumerate(load_stop_times(gtfs_zip)):
			if i & 0xffff == 0:
				# commit every now and then to keep RAM usage under control
				session.commit()
			session.add(stop_time)
		session.commit()

def parse_yesno(value):
	"""
	Interpret an optional yes/no text as a boolean.

	Returns False for None, for an empty value and for the text 'no';
	True for anything else. The result is always a real bool (the plain
	`value and ...` form handed back None or '' for falsy input).
	"""
	return bool(value) and value != 'no'

def regions_to_db(regions):
	"""
	Convert parsed regions into GtfsRegion objects.

	regions: mapping whose values are dicts of region tags

	For every combination of name kind ('', 'short_', 'internal_') and
	language (fi, sv, en, ja) the tag '<kind>name:<lang>' is copied to the
	column 'region_<kind>name_<lang>'; missing tags become None. The 'ref'
	tag is required, 'municipality' and 'external' are optional.
	"""
	from itertools import product
	prefixes = ['', 'short_', 'internal_']
	languages = ['fi', 'sv', 'en', 'ja']
	for region in regions.values():
		name_columns = {
			'region_' + prefix + 'name_' + language: region.get(prefix + 'name:' + language)
			for prefix, language in product(prefixes, languages)
		}
		yield GtfsRegion(
			ref = region['ref'],
			municipality = region.get('municipality'),
			external = parse_yesno(region.get('external')),
			**name_columns,
		)

def get_args():
	"""
	Parse the command line.

	Returns a namespace with the positional arguments `profile` and `gtfs`
	and the boolean flag `process_only` (set by --process-only).
	"""
	from argparse import ArgumentParser
	argument_parser = ArgumentParser()
	for positional in ('profile', 'gtfs'):
		argument_parser.add_argument(positional)
	argument_parser.add_argument('--process-only', action = 'store_true')
	return argument_parser.parse_args()

# Script entry point: builds (or reuses) gtfs.db from a profile file and a
# GTFS zip archive, then assigns regions to the stops.
if __name__ == '__main__':
	from configparser import ConfigParser
	from regions import parse_regions
	args = get_args()
	# `profile` becomes a module-level name here; trip_length() reads it
	# as a global.
	profile = ConfigParser()
	profile.read(args.profile)
	# The database location is fixed: gtfs.db in the current directory.
	engine = sqlalchemy.create_engine('sqlite:///gtfs.db')
	GtfsBase.metadata.create_all(engine)
	session = sqlalchemy.orm.sessionmaker(bind = engine)()
	# The region definitions are always read from this hard-coded file.
	regions = parse_regions('föli.osm')
	# With --process-only the loading step is skipped and only the region
	# assignment below is run against what the session already finds.
	if not args.process_only:
		for region in regions_to_db(regions):
			session.add(region)
		session.commit()
		# load_gtfs() has no return statement, so `buses` is always None.
		buses = load_gtfs(args.gtfs, profile = profile, session = session)
	gtfs_stop_spatial_testing(session = session, regions = regions)
	session.commit()

mercurial