""" Peewee helper for loading CSV data into a database.
Load the users CSV file into the database and return a Model for accessing the data:
from playhouse.csv_loader import load_csv db = SqliteDatabase(':memory:') User = load_csv(db, 'users.csv')
Provide explicit field types and/or field names:
fields = [IntegerField(), IntegerField(), DateTimeField(), DecimalField()] field_names = ['from_acct', 'to_acct', 'timestamp', 'amount'] Payments = load_csv(db, 'payments.csv', fields, field_names) """ import csv import datetime import os import re from contextlib import contextmanager try: from StringIO import StringIO except ImportError: from io import StringIO
from peewee import *
from peewee import Database
from peewee import Func
from peewee import PY3
if PY3:
    basestring = str
    decode_value = False
else:
    # On Python 2 the csv module yields bytestrings, which are decoded to
    # unicode before being inserted.
    decode_value = True
class _CSVReader(object):
    @contextmanager
    def get_reader(self, file_or_name, **reader_kwargs):
        # Only close handles we opened ourselves; caller-supplied handles
        # (including StringIO) are rewound, used, and left open.
        is_file = False
        if isinstance(file_or_name, basestring):
            fh = open(file_or_name, 'r')
            is_file = True
        elif isinstance(file_or_name, StringIO):
            fh = file_or_name
            fh.seek(0)
        else:
            fh = file_or_name
            fh.seek(0)
        reader = csv.reader(fh, **reader_kwargs)
        yield reader
        if is_file:
            fh.close()
def convert_field(field_class, **field_kwargs):
    # Decorator factory: attach a zero-argument Field factory to the check
    # function, so a check that matches a column can produce its field type.
    def decorator(fn):
        fn.field = lambda: field_class(**field_kwargs)
        return fn
    return decorator
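
# Illustration (not part of the module's API): a decorated check gains a
# `.field` attribute that instantiates the column type on demand. The
# standalone `looks_like_int` function below is hypothetical.
#
#     @convert_field(IntegerField, default=0)
#     def looks_like_int(value):
#         return value.isdigit()
#
#     looks_like_int('123')    # -> True
#     looks_like_int.field()   # -> IntegerField(default=0)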
class RowConverter(_CSVReader):
    """
    Simple introspection utility to convert a CSV file into a list of headers
    and column types.

    :param database: a peewee Database object.
    :param bool has_header: whether the first row of the CSV is a header row.
    :param int sample_size: number of rows to introspect.
    """
    date_formats = [
        '%Y-%m-%d',
        '%m/%d/%Y']

    datetime_formats = [
        '%Y-%m-%d %H:%M:%S',
        '%Y-%m-%d %H:%M:%S.%f']

    def __init__(self, database, has_header=True, sample_size=10):
        self.database = database
        self.has_header = has_header
        self.sample_size = sample_size
    def matches_date(self, value, formats):
        for fmt in formats:
            try:
                datetime.datetime.strptime(value, fmt)
            except ValueError:
                pass
            else:
                return True
    @convert_field(IntegerField, default=0)
    def is_integer(self, value):
        return value.isdigit()

    @convert_field(FloatField, default=0)
    def is_float(self, value):
        try:
            float(value)
        except (ValueError, TypeError):
            pass
        else:
            return True

    @convert_field(DateTimeField, null=True)
    def is_datetime(self, value):
        return self.matches_date(value, self.datetime_formats)

    @convert_field(DateField, null=True)
    def is_date(self, value):
        return self.matches_date(value, self.date_formats)

    @convert_field(BareField, default='')
    def default(self, value):
        return True
    def extract_rows(self, file_or_name, **reader_kwargs):
        """
        Extract `self.sample_size` rows from the CSV file and analyze their
        data-types.

        :param file_or_name: A string filename or a file handle.
        :param reader_kwargs: Arbitrary parameters to pass to the CSV reader.
        :returns: A 2-tuple containing a list of headers and a list of rows
            read from the CSV file.
        """
        rows = []
        rows_to_read = self.sample_size
        with self.get_reader(file_or_name, **reader_kwargs) as reader:
            if self.has_header:
                # Read one extra row so the header does not count against
                # the sample size.
                rows_to_read += 1
            for i, row in enumerate(reader):
                rows.append(row)
                if i == rows_to_read:
                    break
        if self.has_header:
            header, rows = rows[0], rows[1:]
        else:
            header = ['field_%d' % i for i in range(len(rows[0]))]
        return header, rows
    def get_checks(self):
        """Return a list of functions to use when testing values."""
        return [
            self.is_date,
            self.is_datetime,
            self.is_integer,
            self.is_float,
            self.default]
    def analyze(self, rows):
        """
        Analyze the given rows and try to determine the type of value stored.

        :param list rows: A list-of-lists containing one or more rows from a
            CSV file.
        :returns: A list of peewee Field objects, one for each column in the
            CSV.
        """
        transposed = zip(*rows)
        checks = self.get_checks()
        column_types = []
        for i, column in enumerate(transposed):
            # Remove any empty values before testing.
            col_vals = [val for val in column if val != '']
            for check in checks:
                results = set(check(val) for val in col_vals)
                if all(results):
                    # The first check matching every sampled value wins;
                    # `default` always matches, so every column gets a field.
                    column_types.append(check.field())
                    break

        return column_types
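
# Usage sketch for RowConverter (illustrative; 'payments.csv' is the
# hypothetical file from the module docstring): sample some rows, then infer
# a Field for each column.
#
#     db = SqliteDatabase(':memory:')
#     converter = RowConverter(db, has_header=True, sample_size=10)
#     header, rows = converter.extract_rows('payments.csv')
#     fields = converter.analyze(rows)
#     # e.g. header -> ['from_acct', 'to_acct', 'timestamp', 'amount']
#     #      fields -> [IntegerField, IntegerField, DateTimeField, FloatField]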
class Loader(_CSVReader):
    """
    Load the contents of a CSV file into a database and return a model class
    suitable for working with the CSV data.

    :param db_or_model: a peewee Database instance or a Model class.
    :param file_or_name: the filename of the CSV file *or* a file handle.
    :param list fields: A list of peewee Field() instances appropriate to
        the values in the CSV file.
    :param list field_names: A list of names to use for the fields.
    :param bool has_header: Whether the first row of the CSV file is a header.
    :param int sample_size: Number of rows to introspect if fields are not
        defined.
    :param converter: A RowConverter instance to use.
    :param str db_table: Name of the table to store data in (if not
        specified, the table name will be derived from the CSV filename).
    :param bool pk_in_csv: Whether the CSV file contains the primary key
        column when loading into an existing model.
    :param reader_kwargs: Arbitrary arguments to pass to the CSV reader.
    """
    def __init__(self, db_or_model, file_or_name, fields=None,
                 field_names=None, has_header=True, sample_size=10,
                 converter=None, db_table=None, pk_in_csv=False,
                 **reader_kwargs):
        self.file_or_name = file_or_name
        self.fields = fields
        self.field_names = field_names
        self.has_header = has_header
        self.sample_size = sample_size
        self.converter = converter
        self.reader_kwargs = reader_kwargs

        if isinstance(file_or_name, basestring):
            self.filename = file_or_name
        elif isinstance(file_or_name, StringIO):
            self.filename = 'data.csv'
        else:
            self.filename = file_or_name.name

        if isinstance(db_or_model, Database):
            self.database = db_or_model
            self.model = None
            self.db_table = (
                db_table or
                os.path.splitext(os.path.basename(self.filename))[0])
        else:
            self.model = db_or_model
            self.database = self.model._meta.database
            self.db_table = self.model._meta.db_table
            self.fields = self.model._meta.sorted_fields
            self.field_names = self.model._meta.sorted_field_names
            # If using an auto-incrementing primary key, ignore it unless we
            # are told the primary key is included in the CSV.
            if self.model._meta.auto_increment and not pk_in_csv:
                self.fields = self.fields[1:]
                self.field_names = self.field_names[1:]
    def clean_field_name(self, s):
        return re.sub('[^a-z0-9]+', '_', s.lower())
    def get_converter(self):
        return self.converter or RowConverter(
            self.database,
            has_header=self.has_header,
            sample_size=self.sample_size)
    def analyze_csv(self):
        converter = self.get_converter()
        header, rows = converter.extract_rows(
            self.file_or_name,
            **self.reader_kwargs)
        if rows:
            self.fields = converter.analyze(rows)
        else:
            self.fields = [converter.default.field() for _ in header]
        if not self.field_names:
            self.field_names = [self.clean_field_name(col) for col in header]
    def get_model_class(self, field_names, fields):
        if self.model:
            return self.model
        attrs = dict(zip(field_names, fields))
        if 'id' not in attrs:
            # No "id" column in the CSV: add an auto-incrementing pk.
            attrs['_auto_pk'] = PrimaryKeyField()
        elif isinstance(attrs['id'], IntegerField):
            # Promote an integer "id" column to the primary key.
            attrs['id'] = PrimaryKeyField()
        klass = type(self.db_table.title(), (Model,), attrs)
        klass._meta.database = self.database
        klass._meta.db_table = self.db_table
        return klass
    def load(self):
        if not self.fields:
            self.analyze_csv()
        if not self.field_names and not self.has_header:
            self.field_names = [
                'field_%d' % i for i in range(len(self.fields))]

        reader_obj = self.get_reader(self.file_or_name, **self.reader_kwargs)
        with reader_obj as reader:
            if not self.field_names:
                row = next(reader)
                self.field_names = [self.clean_field_name(col) for col in row]
            elif self.has_header:
                # Field names were supplied explicitly; skip the header row.
                next(reader)

            ModelClass = self.get_model_class(self.field_names, self.fields)

            with self.database.transaction():
                ModelClass.create_table(True)  # Fail silently if it exists.
                for row in reader:
                    insert = {}
                    for field_name, value in zip(self.field_names, row):
                        if value:
                            if decode_value:
                                value = value.decode('utf-8')
                            insert[field_name] = value
                    if insert:
                        ModelClass.insert(**insert).execute()

        return ModelClass
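
# Usage sketch for Loader (equivalent to the load_csv() helper below;
# 'events.csv' is a hypothetical file):
#
#     loader = Loader(SqliteDatabase(':memory:'), 'events.csv')
#     Event = loader.load()
#     print(Event.select().count())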
def load_csv(db_or_model, file_or_name, fields=None, field_names=None,
             has_header=True, sample_size=10, converter=None,
             db_table=None, pk_in_csv=False, **reader_kwargs):
    loader = Loader(
        db_or_model=db_or_model,
        file_or_name=file_or_name,
        fields=fields,
        field_names=field_names,
        has_header=has_header,
        sample_size=sample_size,
        converter=converter,
        db_table=db_table,
        pk_in_csv=pk_in_csv,
        **reader_kwargs)
    return loader.load()
load_csv.__doc__ = Loader.__doc__
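
# Because `db_or_model` may also be an existing Model class, rows can be
# appended to a model's own table. A minimal sketch (the db, User model, and
# 'users.csv' columns are assumed):
#
#     db = SqliteDatabase('app.db')
#     class User(Model):
#         username = CharField()
#         class Meta:
#             database = db
#
#     load_csv(User, 'users.csv')  # appends rows; the auto-increment pk is
#                                  # skipped unless pk_in_csv=True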
def dump_csv(query, file_or_name, include_header=True, close_file=True,
             append=True, csv_writer=None):
    """
    Create a CSV dump of a query.
    """
    if isinstance(file_or_name, basestring):
        fh = open(file_or_name, append and 'a' or 'w')
    else:
        fh = file_or_name
        if append:
            fh.seek(0, 2)
    writer = csv_writer or csv.writer(
        fh,
        delimiter=',',
        quotechar='"',
        quoting=csv.QUOTE_MINIMAL)

    if include_header:
        header = []
        for idx, node in enumerate(query._select):
            if node._alias:
                header.append(node._alias)
            elif isinstance(node, (Field, Func)):
                header.append(node.name)
            else:
                header.append('col_%s' % idx)
        writer.writerow(header)

    for row in query.tuples().iterator():
        writer.writerow(row)

    if close_file:
        fh.close()

    return fh
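
# Round-trip sketch (hedged; the User model, db, and filename are
# illustrative): dump a query to CSV, then load it back into a fresh table.
#
#     query = User.select().order_by(User.id)
#     dump_csv(query, 'users_backup.csv', append=False)
#     Backup = load_csv(db, 'users_backup.csv', db_table='users_backup')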