import datetime
import os
import yaml
import uuid
import json
import hashlib
import logging
from typing import TypeVar, Generic, List, Optional

from pydantic import BaseModel, StrictStr, Field
from sqlalchemy import (
    create_engine,
    Column,
    Integer,
    String,
    DateTime,
    TIMESTAMP,
    func,
    Boolean,
    Date,
    inspect,
    text,
    bindparam,
)
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB, JSON
from termcolor import colored

from osis.datatools import normalize_email, normalize_phone
from osis.db import DB, DBType  # type: ignore

# Set up logging
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)


def calculate_months(
    investment_date: datetime.date, conversion_date: datetime.date
) -> float:
    """Approximate number of months between two dates (average month = 30.44 days)."""
    delta = conversion_date - investment_date
    days_in_month = 30.44
    return delta.days / days_in_month


def indexed_field(cls):
    """Class decorator that collects the fields flagged for indexing.

    Fields are flagged through extra kwargs on pydantic's Field() (they
    end up in json_schema_extra): index, indexft, indexphone, indexemail
    and human. Phone and email fields are always full-text indexed as well.
    """
    cls.__index_fields__ = dict()
    for name, field in cls.__fields__.items():
        if field.json_schema_extra is not None:
            for cat in ["index", "indexft", "indexphone", "indexemail", "human"]:
                if field.json_schema_extra.get(cat, False):
                    if name not in cls.__index_fields__:
                        cls.__index_fields__[name] = dict()
                    cls.__index_fields__[name][cat] = field.annotation
                    if cat in ["indexphone", "indexemail"]:
                        cls.__index_fields__[name]["indexft"] = field.annotation
    return cls


@indexed_field
class MyBaseModel(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    name: StrictStr = Field(default="", index=True, human=True)
    description: StrictStr = Field(default="")
    lasthash: StrictStr = Field(default="")
    creation_date: int = Field(
        default_factory=lambda: int(datetime.datetime.now().timestamp())
    )
    mod_date: int = Field(
        default_factory=lambda: int(datetime.datetime.now().timestamp())
    )

    def pre_save(self):
        """Hook called before persisting; refreshes the modification timestamp."""
        self.mod_date = int(datetime.datetime.now().timestamp())
        # TODO: normalize indexphone/indexemail fields here with
        # normalize_phone / normalize_email before they are indexed.

    def yaml_get(self) -> str:
        return yaml.dump(self.model_dump(), sort_keys=True, default_flow_style=False)

    def json_get(self) -> str:
        return json.dumps(self.model_dump(), sort_keys=True, indent=2)

    def hash(self) -> str:
        """Content hash over all fields except id and the bookkeeping fields."""
        data = self.model_dump()
        for key in ("lasthash", "mod_date", "creation_date", "id"):
            data.pop(key)
        yaml_string = yaml.dump(data, sort_keys=True, default_flow_style=False)
        # Encode the YAML string to bytes using UTF-8 before hashing
        self.lasthash = hashlib.md5(yaml_string.encode("utf-8")).hexdigest()
        return self.lasthash

    def doc_id(self, partition: str) -> str:
        return f"{partition}:{self.id}"

    def __str__(self):
        return self.json_get()


T = TypeVar("T", bound=MyBaseModel)
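
# Illustrative sketch (commented out, not part of this module): a concrete
# model re-applies @indexed_field so its own fields get scanned too.
# "Contact", "email", "phone" and "city" are hypothetical names; the extra
# Field() kwargs land in json_schema_extra, which the decorator reads.
#
#   @indexed_field
#   class Contact(MyBaseModel):
#       email: StrictStr = Field(default="", indexemail=True)
#       phone: StrictStr = Field(default="", indexphone=True)
#       city: StrictStr = Field(default="", index=True)
#
#   # email/phone are also full-text indexed, because indexphone and
#   # indexemail imply indexft:
#   assert "indexft" in Contact.__index_fields__["email"]
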
class MyBaseFactory(Generic[T]):
    def __init__(
        self,
        model_cls: type[T],
        db: DB,
        use_fs: bool = True,
        keep_history: bool = False,
        reset: bool = False,
        load: bool = False,
        human_readable: bool = True,
    ):
        self.mycat = model_cls.__name__.lower()
        self.description = ""
        self.model_cls = model_cls
        self.engine = create_engine(db.cfg.url())
        self.Session = sessionmaker(bind=self.engine)
        self.use_fs = use_fs
        self.human_readable = human_readable
        self.keep_history = keep_history
        self.db = db
        self.db_cat = db.dbcat_new(cat=self.mycat, reset=reset)
        self.ft_table_name = f"{self.mycat}_ft"
        self._init_db_schema(reset=reset)
        if self.use_fs:
            self._check_db_schema()
        else:
            if not self._check_db_schema_ok():
                raise RuntimeError(
                    "DB schema no longer matches the model; a migration path is needed"
                )
        if reset:
            self.db_cat.reset()
            self._reset_db()
        if load:
            self.load()

    def _reset_db(self):
        logger.info(colored("Resetting database...", "red"))
        with self.engine.connect() as connection:
            cascade = ""
            if self.db.cfg.db_type == DBType.POSTGRESQL:
                cascade = " CASCADE"
            connection.execute(text(f'DROP TABLE IF EXISTS "{self.mycat}"{cascade}'))
            if self.keep_history:
                connection.execute(
                    text(f'DROP TABLE IF EXISTS "{self.mycat}_history"{cascade}')
                )
            connection.commit()
        self._init_db_schema()

    def _init_db_schema(self, reset: bool = False):
        # Make sure the tables exist; the mapped classes are rebuilt in all
        # cases because the factory needs them for querying.
        inspector = inspect(self.engine)
        if inspector.has_table(self.mycat):
            if reset:
                # _reset_db() drops the tables and calls back into this
                # method, so stop here.
                self._reset_db()
                return
            logger.debug(f"Table {self.mycat} already exists.")

        Base = declarative_base()

        def create_model(tablename):
            class MyModel(Base):
                __tablename__ = tablename
                id = Column(String, primary_key=True)
                name = Column(String, index=True)
                creation_date = Column(Integer, index=True)
                mod_date = Column(Integer, index=True)
                hash = Column(String, index=True)
                data = Column(JSON)
                version = Column(Integer)

            # Add a real, indexed column for every flagged field so it can
            # be filtered on directly.
            for field, index_types in self.model_cls.__index_fields__.items():
                if "index" in index_types:
                    field_type = index_types["index"]
                    if field not in ["id", "name", "creation_date", "mod_date"]:
                        if field_type == int:
                            setattr(MyModel, field, Column(Integer, index=True))
                        elif field_type == datetime.date:
                            setattr(MyModel, field, Column(Date, index=True))
                        elif field_type == bool:
                            setattr(MyModel, field, Column(Boolean, index=True))
                        else:
                            setattr(MyModel, field, Column(String, index=True))
            create_model_ft()
            return MyModel

        def create_model_ft():
            # Full-text search is backed by SQLite FTS5; collect every field
            # flagged for full-text indexing.
            index_fields = self.model_cls.__index_fields__
            toindex: List[str] = [
                fieldname
                for fieldname, index_types in index_fields.items()
                if "indexft" in index_types
            ]
            if len(toindex) > 0:
                with self.engine.connect() as connection:
                    result = connection.execute(
                        text(
                            "SELECT name FROM sqlite_master "
                            "WHERE type='table' AND name=:table_name"
                        ),
                        {"table_name": self.ft_table_name},
                    )
                    if result.fetchone() is None:
                        # Table does not exist yet. Identifiers cannot be
                        # bound parameters in DDL, so the statement is
                        # composed as a string. The id column is included so
                        # rows can be updated and deleted by id later on.
                        fields = ", ".join(["id"] + toindex)
                        connection.execute(
                            text(
                                f"CREATE VIRTUAL TABLE {self.ft_table_name} "
                                f"USING fts5({fields})"
                            )
                        )
                        connection.commit()

        self.table_model = create_model(self.mycat)
        if self.keep_history:
            self.history_table_model = create_model(f"{self.mycat}_history")
        Base.metadata.create_all(self.engine)
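
    # Background sketch (commented out, standalone, not executed here): this
    # is how SQLite FTS5 virtual tables behave. "contact_ft" is an
    # illustrative name. Identifiers cannot be bound parameters, which is
    # why the DDL above is string-composed; only values can be bound.
    #
    #   import sqlite3
    #   con = sqlite3.connect(":memory:")
    #   con.execute("CREATE VIRTUAL TABLE contact_ft USING fts5(id, email)")
    #   con.execute(
    #       "INSERT INTO contact_ft (id, email) VALUES (?, ?)",
    #       ("uuid-1", "alice@example.com"),
    #   )
    #   rows = con.execute(
    #       "SELECT id FROM contact_ft WHERE email MATCH ?", ("example",)
    #   ).fetchall()  # the tokenizer splits on '@' and '.', so this matches
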
    def _check_db_schema_ok(self) -> bool:
        inspector = inspect(self.engine)
        table_name = self.table_model.__tablename__

        # Columns as they exist in the database
        db_columns = {col["name"]: col for col in inspector.get_columns(table_name)}
        # Columns as the model defines them
        model_columns = {c.name: c for c in self.table_model.__table__.columns}

        # Check for columns in the model but not in the database,
        # or with a different type
        for col_name, col in model_columns.items():
            if col_name not in db_columns:
                logger.info(
                    colored(
                        f"Column '{col_name}' exists in model but not in database",
                        "red",
                    )
                )
                return False
            db_col = db_columns[col_name]
            if str(col.type) != str(db_col["type"]):
                logger.info(
                    colored(
                        f"Column '{col_name}' type mismatch: Model {col.type}, DB {db_col['type']}",
                        "red",
                    )
                )
                return False

        # Check for columns in the database but not in the model
        for col_name in db_columns:
            if col_name not in model_columns:
                logger.info(
                    colored(
                        f"Column '{col_name}' exists in database but not in model",
                        "red",
                    )
                )
                return False

        return True

    def _check_db_schema(self):
        # If the schema no longer matches the model, reload from the filesystem.
        if self._check_db_schema_ok():
            return
        self.load()

    def new(self, name: str = "", **kwargs) -> T:
        return self.model_cls(name=name, **kwargs)

    def _encode(self, item: T) -> dict:
        return item.model_dump()

    def _decode(self, data) -> T:
        # data can be a dict (JSON column, parsed YAML) or a serialized string.
        if isinstance(data, dict):
            return self.model_cls(**data)
        if self.use_fs:
            return self.model_cls(**yaml.safe_load(data))
        return self.model_cls(**json.loads(data))

    def get(self, id: str = "") -> T:
        if not isinstance(id, str):
            raise ValueError(f"id needs to be str. Now: {id}")
        session = self.Session()
        result = session.query(self.table_model).filter_by(id=id).first()
        session.close()
        if result:
            if self.use_fs:
                data = self.db_cat.get(id=id)
            else:
                data = result.data
            return self._decode(data)
        raise ValueError(f"can't find {self.mycat}:{id}")

    def exists(self, id: str = "") -> bool:
        if not isinstance(id, str):
            raise ValueError(f"id needs to be str. Now: {id}")
        session = self.Session()
        result = session.query(self.table_model).filter_by(id=id).first()
        session.close()
        return result is not None

    def get_by_name(self, name: str) -> Optional[T]:
        r = self.list(name=name)
        if len(r) > 1:
            raise ValueError(f"found more than 1 object with name {name}")
        if len(r) < 1:
            raise ValueError(f"object not found with name {name}")
        return r[0]

    def set(self, item: T, ignorefs: bool = False):
        item.pre_save()
        new_hash = item.hash()
        session = self.Session()
        db_item = session.query(self.table_model).filter_by(id=item.id).first()
        data = item.model_dump()

        # Collect the full-text indexed fields and their current values.
        index_fields = self.model_cls.__index_fields__
        to_ft_index: List[str] = []
        ft_values = {"id": item.id}
        for field_name, index_types in index_fields.items():
            if "indexft" in index_types:
                to_ft_index.append(field_name)
                ft_values[field_name] = getattr(item, field_name)

        if db_item:
            if db_item.hash != new_hash:
                db_item.name = item.name
                db_item.mod_date = item.mod_date
                db_item.creation_date = item.creation_date
                db_item.hash = new_hash
                if not self.use_fs:
                    db_item.data = data
                # Update the indexed fields, including the full-text rows
                for field, index_types in index_fields.items():
                    if field not in ["id", "name", "creation_date", "mod_date"]:
                        if "indexft" in index_types:
                            session.execute(
                                text(
                                    f"UPDATE {self.ft_table_name} "
                                    f"SET {field} = :value WHERE id = :id"
                                ),
                                {"value": getattr(item, field), "id": item.id},
                            )
                        setattr(db_item, field, getattr(item, field))
                if self.keep_history and not self.use_fs:
                    # History rows store the id as "<id>_<version>".
                    version = (
                        session.query(func.max(self.history_table_model.version))
                        .filter(self.history_table_model.id.like(f"{item.id}_%"))
                        .scalar()
                        or 0
                    )
                    history_item = self.history_table_model(
                        id=f"{item.id}_{version + 1}",
                        name=item.name,
                        creation_date=item.creation_date,
                        mod_date=item.mod_date,
                        hash=new_hash,
                        data=data,
                        version=version + 1,
                    )
                    session.add(history_item)
            if not ignorefs and self.use_fs:
                self.db_cat.set(data=item.yaml_get(), id=item.id)
        else:
            db_item = self.table_model(
                id=item.id,
                name=item.name,
                creation_date=item.creation_date,
                mod_date=item.mod_date,
                hash=new_hash,
            )
            if not self.use_fs:
                db_item.data = data
            # Set the indexed fields before the insert
            for field in index_fields:
                if field not in ["id", "name", "creation_date", "mod_date"]:
                    setattr(db_item, field, getattr(item, field))
            session.add(db_item)
            if to_ft_index:
                columns = ", ".join(["id"] + to_ft_index)
                placeholders = ", ".join(f":{f}" for f in ["id"] + to_ft_index)
                session.execute(
                    text(
                        f"INSERT INTO {self.ft_table_name} "
                        f"({columns}) VALUES ({placeholders})"
                    ),
                    ft_values,
                )
            if not ignorefs and self.use_fs:
                self.db_cat.set(
                    data=item.yaml_get(),
                    id=item.id,
                    humanid=self._human_name_get(item),
                )
        session.commit()
        session.close()
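
    # Sketch of the change-detection contract set() relies on (commented
    # out; values are illustrative): hash() excludes id, lasthash,
    # creation_date and mod_date, so only content changes produce a new hash
    # and trigger an update.
    #
    #   a = MyBaseModel(name="x")
    #   h1 = a.hash()
    #   a.mod_date += 1            # bookkeeping only: hash unchanged
    #   assert a.hash() == h1
    #   a.description = "edited"   # content change: hash differs
    #   assert a.hash() != h1
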
    # Used to create a symlink so a human can find and edit the file easily.
    def _human_name_get(self, item: T) -> str:
        humanname = ""
        if self.human_readable:
            for fieldhuman, index_types in self.model_cls.__index_fields__.items():
                if "human" not in index_types:
                    continue
                if fieldhuman not in ["id", "creation_date", "mod_date"]:
                    humanname += f"{getattr(item, fieldhuman)}_"
            humanname = humanname.rstrip("_")
            if humanname == "":
                raise Exception(f"humanname should not be empty for {item}")
        return humanname

    def delete(self, id: str):
        if not isinstance(id, str):
            raise ValueError(f"id needs to be str. Now: {id}")
        if self.use_fs:
            # Resolve the human-readable name before the record disappears,
            # so the symlink can be removed as well.
            humanid = ""
            if self.exists(id=id):
                item = self.get(id)
                humanid = self._human_name_get(item)
            self.db_cat.delete(id=id, humanid=humanid)
        session = self.Session()
        result = session.query(self.table_model).filter_by(id=id).delete()
        # Only touch the full-text table when the model has FT-indexed fields.
        if any("indexft" in t for t in self.model_cls.__index_fields__.values()):
            session.execute(
                text(f"DELETE FROM {self.ft_table_name} WHERE id = :id"), {"id": id}
            )
        session.commit()
        session.close()
        if result > 1:
            raise ValueError(f"multiple values deleted with id {id}")
        elif result == 0:
            raise ValueError(f"no record found with id {id}")

    def list(
        self, id: Optional[str] = None, name: Optional[str] = None, **kwargs
    ) -> List[T]:
        session = self.Session()
        query = session.query(self.table_model)
        if id:
            query = query.filter(self.table_model.id == id)
        if name:
            query = query.filter(self.table_model.name.ilike(f"%{name}%"))
        index_fields = self.model_cls.__index_fields__
        for key, value in kwargs.items():
            if value is None:
                continue
            if self.use_fs:
                query = query.filter(getattr(self.table_model, key) == value)
            elif key in index_fields and "indexft" in index_fields[key]:
                # Resolve matching ids through the full-text table first.
                result = session.execute(
                    text(
                        f"SELECT id FROM {self.ft_table_name} WHERE {key} MATCH :value"
                    ),
                    {"value": value},
                )
                ids = [row[0] for row in result]
                query = query.filter(self.table_model.id.in_(ids))
            else:
                query = query.filter(
                    self.table_model.data[key].astext.ilike(f"%{value}%")
                )
        results = query.all()
        session.close()
        return [self.get(id=result.id) for result in results]

    def load(self, reset: bool = False):
        if not self.use_fs:
            return
        logger.info(colored("Reload DB.", "green"))
        if reset:
            self._reset_db()

        # Get all IDs and hashes currently in the database
        session = self.Session()
        db_items = {
            item.id: item.hash
            for item in session.query(self.table_model.id, self.table_model.hash).all()
        }
        session.close()

        done = []
        for root, _, files in os.walk(self.db.path):
            for file in files:
                if not file.endswith(".yaml"):
                    continue
                file_path = os.path.join(root, file)
                with open(file_path, "r") as f:
                    data = yaml.safe_load(f)
                obj = self._decode(data)
                myhash = obj.hash()
                if reset:
                    self.set(obj, ignorefs=True)
                elif obj.id in db_items:
                    if db_items[obj.id] != myhash:
                        # Hash mismatch, update the database record
                        self.set(obj, ignorefs=True)
                else:
                    # New item, add it to the database
                    self.set(obj, ignorefs=True)
                done.append(obj.id)
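
# Minimal usage sketch (commented out; `db` must be an osis.db.DB instance,
# whose construction is not shown in this module, and the field values are
# illustrative):
#
#   factory = MyBaseFactory(MyBaseModel, db=db, use_fs=False, reset=True)
#   item = factory.new(name="example", description="hello")
#   factory.set(item)                 # insert; hash is computed in set()
#   same = factory.get(item.id)       # round-trip through the JSON column
#   hits = factory.list(name="exam")  # ilike match on the name column
#   factory.delete(item.id)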