Source code for drepr.planning.drepr_model_alignments

from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass, field
from typing import Optional

from graph.retworkx import RetworkXStrDiGraph

from drepr.models.prelude import (
    AlignedStep,
    Alignment,
    Attr,
    AttrId,
    AutoAlignment,
    Cardinality,
    DRepr,
    IdenticalAlign,
    IndexExpr,
    PreprocessResourceOutput,
    RangeAlignment,
    RangeExpr,
    SetIndexExpr,
    ValueAlignment,
)


[docs]@dataclass class DReprModelAlignments: desc: DRepr aligns: dict[tuple[AttrId, AttrId], list[Alignment]]
[docs] @staticmethod def create(desc: DRepr) -> DReprModelAlignments: """Infer all possible alignments from the model.""" aligns = {(src.id, tgt.id): [] for src in desc.attrs for tgt in desc.attrs} for a in desc.aligns: if isinstance(a, RangeAlignment): aligns[(a.source, a.target)] = [a] aligns[(a.target, a.source)] = [a.swap()] elif isinstance(a, ValueAlignment): aligns[(a.source, a.target)] = [a] aligns[(a.target, a.source)] = [a.swap()] elif isinstance(a, AutoAlignment): if a.attrs is None: attrs = desc.attrs else: attrs = [desc.get_attr_by_id(attr) for attr in a.attrs] # now for each pair of attributes, we look at there steps and align them one by one for source in attrs: for target in attrs: if ( source.id == target.id or len(aligns[source.id, target.id]) > 0 ): continue if source.resource_id != target.resource_id: # check if source & target is derived from the same resource source_resource = desc.get_resource_by_id( source.resource_id ) assert source_resource is not None if not ( isinstance(source_resource, PreprocessResourceOutput) and source_resource.get_preprocessing_original_resource_id() == target.resource_id ): continue new_align = DReprModelAlignments.auto_align(source, target) if new_align is not None: aligns[(source.id, target.id)] = [new_align] aligns[(target.id, source.id)] = [new_align.swap()] elif isinstance(a, IdenticalAlign): raise Exception( "Unreachable! Users should not provide identical alignment in the model" ) for a in desc.attrs: aligns[(a.id, a.id)] = [IdenticalAlign(a.id, a.id)] inf = DReprModelAlignments(desc, aligns) inf.inference() return inf
[docs] def get_alignments(self, source: AttrId, target: AttrId) -> list[Alignment]: return self.aligns[(source, target)]
[docs] def inference(self): ig: dict[str, dict[str, int]] = { attr.id: {attr.id: 0} for attr in self.desc.attrs } for (source, target), aligns in self.aligns.items(): assert len(aligns) <= 1, aligns if source == target or len(aligns) == 0: continue if isinstance(aligns[0], (RangeAlignment, ValueAlignment, IdenticalAlign)): ig[source][target] = sum( self.score_alignment(align) for align in aligns ) elif not isinstance(aligns[0], IdenticalAlign): print(aligns) raise Exception("Unreachable") ig_2nd = ig while True: ig_3rd = defaultdict(dict) for u0 in ig_2nd: # for updated neighbors of u0 for u1 in ig_2nd[u0]: for u2 in ig[u1]: afuncs = self.infer_func(u0, u1, u2) if afuncs is not None: score = sum(self.score_alignment(align) for align in afuncs) if u2 not in ig[u0] or ig[u0][u2] > score: # we found a better alignment self.aligns[u0, u2] = afuncs ig_3rd[u0][u2] = score for u0 in ig_3rd: for u2, score in ig_3rd[u0].items(): ig[u0][u2] = score ig_2nd = ig_3rd if len(ig_2nd) == 0: break
[docs] def infer_subject(self, attrs: list[AttrId]) -> list[AttrId]: """ Find all attributes that can be the subject of these attributes The subject must has one to one mapping to all attrs and: (1) have no missing path (2) have missing path, but the missing path is the same for all attrs. (1) has higher priority than (2). """ subjs = [] # if one or these attributes has *-to-one, then that is the subjects for a in attrs: is_subj = True for ai in attrs: if len(self.aligns[(a, ai)]) == 0: # no alignment is_subj = False break cardin = self.estimate_cardinality(self.aligns[(a, ai)]) if cardin == Cardinality.ManyToMany: is_subj = False break elif cardin == Cardinality.OneToMany: is_subj = False break if is_subj: subjs.append(a) if len(subjs) == 0: attr_ids = set(attrs) # we have to try the external attributes for attr in self.desc.attrs: if attr.id in attr_ids: continue is_candidate_subj = True covered_dims = set() for ai in attrs: # we can only infer if there are any duplications if the alignments are dimensional # if they are dimension alignment, then the optimization engine must compress them into # just one alignment if len(self.aligns[attr.id, ai]) != -1 or not isinstance( self.aligns[attr.id, ai][0], RangeAlignment ): is_candidate_subj = False break align = self.aligns[attr.id, ai][0] assert isinstance(align, RangeAlignment) align_cardin = align.compute_cardinality(self.desc) if align_cardin == Cardinality.ManyToMany: is_candidate_subj = False break elif align_cardin == Cardinality.OneToMany: is_candidate_subj = False break else: for ad in align.aligned_steps: covered_dims.add(ad.source_idx) if is_candidate_subj and covered_dims == set( attr.path.get_nary_steps() ): # the second condition detect if there is duplication subjs.append(attr.id) if len(subjs) > 0: # now, we filter the missing path condition, if there is subject that have no missing path, we use them. if any( not self.desc.get_attr_by_id(subj).path.has_optional_steps() for subj in subjs ): return [ subj for subj in subjs if not self.desc.get_attr_by_id(subj).path.has_optional_steps() ] # check if we have subject with missing path, but the missing path is the same or less restricted for all attrs filtered_subjs = [] for subj_id in subjs: subj = self.desc.get_attr_by_id(subj_id) # now we need to check if this subject has the same or less restricted missing path to all attrs. if all( subj.path.has_same_or_less_optional_steps( self.desc.get_attr_by_id(ai).path ) for ai in attrs ): filtered_subjs.append(subj_id) subjs = filtered_subjs return subjs
[docs] def infer_func( self, xid: AttrId, yid: AttrId, zid: AttrId ) -> Optional[list[Alignment]]: """Infer an alignment function of xid and zid given alignments between (xid, yid) and (yid, zid) If there is only one way to join values of xid and zid, then the chain join will be the correct one """ f = self.aligns[xid, yid] g = self.aligns[yid, zid] f_cardin = self.estimate_cardinality(f) g_cardin = self.estimate_cardinality(g) # filter the case where we cannot chain these alignments can_chain_alignments = ( f_cardin == Cardinality.OneToOne or f_cardin == Cardinality.OneToMany ) or (g_cardin == Cardinality.OneToOne or g_cardin == Cardinality.ManyToOne) if not can_chain_alignments: return None # chain them together return self.optimize_chained_alignments(f + g)
[docs] def estimate_cardinality(self, aligns: list[Alignment]) -> Cardinality: cardin = aligns[0].compute_cardinality(self.desc) # must always be > 0 if len(aligns) <= 1 or cardin == Cardinality.ManyToMany: return cardin for i in range(1, len(aligns)): cardin_i = aligns[i].compute_cardinality(self.desc) if cardin_i == Cardinality.OneToOne: # do nothing, as this does not change the cardin pass elif cardin_i == Cardinality.OneToMany: if cardin == Cardinality.OneToOne: cardin = Cardinality.OneToMany elif cardin == Cardinality.OneToMany: cardin = Cardinality.OneToMany elif cardin == Cardinality.ManyToOne: # we don't know whether it is going to be M2M, O2O, O2M, M2O so we do a conservative prediction return Cardinality.ManyToMany else: raise Exception("Unreachable") elif cardin_i == Cardinality.ManyToOne: if cardin == Cardinality.OneToOne: cardin = Cardinality.ManyToOne elif cardin == Cardinality.OneToMany: return Cardinality.ManyToMany elif cardin == Cardinality.ManyToOne: cardin = Cardinality.ManyToOne else: raise Exception("Unreachable") elif cardin_i == Cardinality.ManyToMany: return cardin_i return cardin
[docs] def optimize_chained_alignments(self, aligns: list[Alignment]) -> list[Alignment]: if len(aligns) == 0: return [] joins = [] # rule 1: consecutive dimension alignments are combined together joins.append(aligns[0]) for align in aligns[1:]: lastjoin = joins[-1] if isinstance(lastjoin, RangeAlignment) and isinstance( align, RangeAlignment ): # we merge them together a0 = lastjoin a1 = align a1map = {ad.source_idx: ad.target_idx for ad in a1.aligned_steps} joins[-1] = RangeAlignment( source=a0.source, target=a1.target, aligned_steps=[ AlignedStep( source_idx=ad.source_idx, target_idx=a1map[ad.target_idx] ) for ad in a0.aligned_steps if ad.target_idx in a1map ], ) else: joins.append(align) return joins
[docs] @staticmethod def auto_align( source: Attr, target: Attr ) -> Optional[RangeAlignment | IdenticalAlign]: # TODO: fix auto-alignment source_range_steps = [ i for i, step in enumerate(source.path.steps) if isinstance(step, RangeExpr) ] target_range_steps = [ i for i, step in enumerate(target.path.steps) if isinstance(step, RangeExpr) ] if len(source_range_steps) == 0 or len(target_range_steps) == 0: # mapping between two single values is identical mapping is_source_single_val = all( isinstance(step, IndexExpr) for step in source.path.steps ) is_target_single_val = all( isinstance(step, IndexExpr) for step in target.path.steps ) if is_source_single_val and is_target_single_val: return IdenticalAlign(source.id, target.id) if is_source_single_val or is_target_single_val: return RangeAlignment(source.id, target.id, []) return None end = min(source_range_steps[-1], target_range_steps[-1]) + 1 assert end > 0 if source.path.steps[:end] == target.path.steps[:end]: # we can align all of them as long as all the steps are index if all( isinstance(step, (RangeExpr, IndexExpr, SetIndexExpr)) for step in source.path.steps ) and all( isinstance(step, (RangeExpr, IndexExpr, SetIndexExpr)) for step in target.path.steps ): return RangeAlignment( source=source.id, target=target.id, aligned_steps=[ AlignedStep(source_idx=dim, target_idx=dim) for dim in source_range_steps[ : min(len(target_range_steps), len(source_range_steps)) ] ], ) return None
[docs] def score_alignment(self, align: Alignment) -> int: """Score an alignment function so we can justify if an alignment is new/better or not. Lower score is better.""" if isinstance(align, RangeAlignment): card = align.compute_cardinality(self.desc) if card == Cardinality.OneToOne: return 1 elif card == Cardinality.OneToMany or card == Cardinality.ManyToOne: return 2 else: return 3 if isinstance(align, ValueAlignment): card = align.compute_cardinality(self.desc) if card == Cardinality.OneToOne: return 4 elif card == Cardinality.OneToMany or card == Cardinality.ManyToOne: return 5 else: return 6 assert isinstance(align, IdenticalAlign) return 0
[docs]@dataclass class CustomedDfs: """ THE FOLLOWING CODE IS ADAPTED FROM: https://docs.rs/petgraph/0.4.13/src/petgraph/visit/traversal.rs.html#38-43 so that DFS also yield the parent node Visit nodes of a graph in a depth-first-search (DFS) emitting nodes in preorder (when they are first discovered). The traversal starts at a given node and only traverses nodes reachable from it. `Dfs` is not recursive. `Dfs` does not itself borrow the graph, and because of this you can run a traversal over a graph while still retaining mutable access to it, if you use it like the following example: ``` use petgraph::Graph; use petgraph::visit::Dfs; let mut graph = Graph::<_,()>::new(); let a = graph.add_node(0); let mut dfs = Dfs::new(&graph, a); while let Some(nx) = dfs.next(&graph) { // we can access `graph` mutably here still graph[nx] += 1; } assert_eq!(graph[a], 1); ``` **Note:** The algorithm may not behave correctly if nodes are removed during iteration. It may not necessarily visit added nodes or edges. """ stack: list[tuple[str, str]] discovered: VisitMap
[docs] @staticmethod def empty() -> CustomedDfs: """Create a new **Dfs** using the graph's visitor map, and no stack.""" return CustomedDfs([], VisitMap())
[docs] @staticmethod def from_graph(graph: RetworkXStrDiGraph, start: str) -> CustomedDfs: """ Create a new **Dfs**, using the graph's visitor map, and put **start** in the stack of nodes to visit. """ dfs = CustomedDfs.empty() dfs.move_to(start) return dfs
[docs] def move_to(self, start: str): """ Keep the discovered map, but clear the visit stack and restart the dfs from a particular node. """ self.discovered.visit(start) self.stack.clear() self.stack.append((start, start))
[docs] def next( self, graph: RetworkXStrDiGraph, revisit: set[str] ) -> Optional[tuple[str, str]]: if len(self.stack) == 0: return None (parent_node, node) = self.stack.pop() for e in graph.out_edges(node): succ = e.target if self.discovered.visit(succ) or succ in revisit: self.stack.append((node, succ)) return (parent_node, node)
[docs]@dataclass class VisitMap: """A mapping for storing the visited status for NodeId `N`.""" map: set[str] = field(default_factory=set)
[docs] def visit(self, a: str) -> bool: """ Mark `a` as visited. Return **true** if this is the first visit, false otherwise. """ if a in self.map: return False self.map.add(a) return True
[docs] def is_visited(self, a: str) -> bool: """Return whether `a` has been visited before.""" return a in self.map