| |
| |
|
|
| import os |
| from typing import Dict, List, Optional |
|
|
| import pytorch_lightning as pl |
| import torch |
| from torch.utils.data import Dataset |
| from torch_geometric.data import Data, DataLoader |
|
|
| |
| from mecari.featurizers.lexical import ( |
| LexicalNGramFeaturizer as LexFeaturizer, |
| Morpheme as LexMorpheme, |
| ) |
|
|
|
|
# NOTE(review): this string appears after the imports, so Python does NOT
# treat it as the module docstring (only the first statement of a module is);
# consider moving it to the top of the file.
"""Data module for lexical-graph training using prebuilt .pt graphs only."""
|
|
|
|
| |
class _PtGraphDataset(Dataset):
    """Dataset over per-sentence PyG graphs stored one ``.pt`` file each.

    Every file is expected to contain either a bare
    :class:`torch_geometric.data.Data` object or a dict with:
      - ``'graph'``: the ``Data`` object
      - ``'source_id'``: str (used for the train/val/test split)
      - optional ``'text'``
    """

    def __init__(self, files: List[str]) -> None:
        self.files = files

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, idx: int) -> Data:
        file_path = self.files[idx]
        loaded = torch.load(file_path, map_location="cpu")
        # Accept both the dict wrapper and a bare Data object.
        graph = loaded.get("graph", loaded) if isinstance(loaded, dict) else loaded
        if not isinstance(graph, Data):
            raise RuntimeError(f"Invalid graph object in: {file_path}")
        # Remember which file the graph came from so batches can be traced back.
        graph.data_index = idx
        return graph
|
|
|
|
| |
# Allowlist the PyG classes embedded in the saved graph files so that
# torch.load can deserialize them — presumably needed because newer torch
# versions restrict unpickling via a safe-globals allowlist (weights_only);
# TODO confirm against the torch version pinned by this project.
# Best-effort on purpose: older torch may lack add_safe_globals
# (AttributeError) and older PyG may lack DataEdgeAttr (ImportError).
try:
    import torch.serialization
    from torch_geometric.data.data import DataEdgeAttr

    torch.serialization.add_safe_globals([DataEdgeAttr, Data])
except (ImportError, AttributeError):
    pass
|
|
|
|
class DataModule(pl.LightningDataModule):
    """Loads prebuilt .pt graphs and builds lexical graph features for training.

    Graph files are discovered under ``annotations_dir`` (one ``.pt`` per
    sentence), split into train/val/test via the KWDLC dev/test id lists,
    and served through PyG ``DataLoader`` instances.
    """

    def __init__(
        self,
        annotations_dir: str = "annotations",
        batch_size: int = 32,
        num_workers: int = 0,
        max_files: Optional[int] = None,
        use_bidirectional_edges: bool = True,
        annotations_override_dir: Optional[str] = None,
        silent: bool = False,
        lexical_feature_dim: int = 100000,
        lexical_max_features: int = 20,
    ) -> None:
        """Configure the data module.

        Args:
            annotations_dir: Directory holding per-sentence ``.pt`` graph files.
            batch_size: Mini-batch size used by every dataloader.
            num_workers: DataLoader worker processes (0 = main process only).
            max_files: Optional cap on the number of discovered files (debugging).
            use_bidirectional_edges: Mirror each edge when building edge_index.
            annotations_override_dir: Stored for callers; not read in this class.
            silent: Stored for callers; not read in this class.
            lexical_feature_dim: Size of the hashed lexical feature space.
            lexical_max_features: Fixed per-node feature budget (pad/truncate).
        """
        super().__init__()
        self.annotations_dir = annotations_dir
        self.annotations_override_dir = annotations_override_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.max_files = max_files
        # Fix: the original assigned this attribute twice (a hard-coded True,
        # later overwritten by the argument); keep only the argument's value.
        self.use_bidirectional_edges = bool(use_bidirectional_edges)
        self.silent = silent
        self.lexical_feature_dim = lexical_feature_dim
        self.lexical_max_features = int(lexical_max_features)

        # Populated by setup(); empty lists until then so len() stays safe.
        self.train_dataset = []
        self.val_dataset = []
        self.test_dataset = []

        self._lex_featurizer = LexFeaturizer(dim=int(self.lexical_feature_dim), add_bias=True)

        # Japanese POS tag -> small integer id; 0 is implicitly "unknown".
        self.pos_to_id = {
            "名詞": 1,
            "動詞": 2,
            "形容詞": 3,
            "副詞": 4,
            "助詞": 5,
            "助動詞": 6,
            "接続詞": 7,
            "連体詞": 8,
            "感動詞": 9,
            "形状詞": 10,
            "補助記号": 11,
            "接頭辞": 12,
            "接尾辞": 13,
            "特殊": 14,
        }
        self.id_to_pos = {v: k for k, v in self.pos_to_id.items()}

    def create_graph_from_morphemes_data(self, *args, **kwargs) -> Optional[Data]:
        """Create a lexical graph from morpheme data (or candidates).

        When ``candidates`` is supplied as a keyword, it is first converted to
        ``morphemes``/``edges`` via ``_build_graph_from_candidates``.

        NOTE(review): ``_build_graph_from_candidates`` is not defined in this
        class as shown — presumably provided by a subclass or elsewhere in the
        project; confirm before relying on the candidates path.
        """
        if "candidates" in kwargs:
            candidates = kwargs.pop("candidates")
            text = kwargs.get("text", "")
            morphemes_edges = self._build_graph_from_candidates(candidates, text)
            if not morphemes_edges:
                return None
            kwargs["morphemes"] = morphemes_edges["morphemes"]
            kwargs["edges"] = morphemes_edges["edges"]
        return self._create_lexical_graph(*args, **kwargs)

    def compute_lexical_features(self, morphemes: List[Dict], text: str) -> List[Dict]:
        """Attach ``lexical_features`` to each morpheme dict in place.

        Uses Mecari's lexical featurizer. Extraction is deliberately
        best-effort: a morpheme that fails to featurize is left without the
        key instead of aborting the whole sentence.

        Args:
            morphemes: Morpheme dicts (``surface``/``base_form``/``pos``/... keys).
            text: Full sentence text, used for prev/next character context.

        Returns:
            The same list, mutated with a ``lexical_features`` entry per item.
        """
        if not morphemes:
            return morphemes

        for m in morphemes:
            try:
                morph_obj = LexMorpheme(
                    surf=m.get("surface", ""),
                    lemma=m.get("base_form", ""),
                    pos=m.get("pos", "*"),
                    pos1=m.get("pos_detail1", "*"),
                    ctype=m.get("inflection_type", "*"),
                    cform=m.get("inflection_form", "*"),
                    reading=m.get("reading", "*"),
                )
                st = m.get("start_pos", 0)
                ed = m.get("end_pos", st + len(m.get("surface", "")))
                # Character context just outside the morpheme span, or None
                # at the sentence boundaries.
                prev_char = text[st - 1] if st > 0 else None
                next_char = text[ed] if ed < len(text) else None
                feats = self._lex_featurizer.unigram_feats(morph_obj, prev_char, next_char)
                m["lexical_features"] = feats
            except Exception:
                # Deliberate best-effort: skip unfeaturizable morphemes.
                pass
        return morphemes

    def _create_lexical_graph(
        self, morphemes: List[Dict], edges: List[Dict], text: str, for_training: bool = True
    ) -> Optional[Data]:
        """Build a PyG graph whose node features are sparse lexical features.

        Args:
            morphemes: Morpheme dicts; each may carry ``lexical_features`` as
                a list of ``(index, value)`` pairs and, for training, an
                ``annotation`` of ``"+"``, ``"-"`` or ``"?"``.
            edges: Edge dicts with ``source_idx``/``target_idx``.
            text: Sentence text (unused here; kept for interface parity).
            for_training: When True, attach labels ``y`` and ``valid_mask``.

        Returns:
            A ``Data`` object, or ``None`` when ``morphemes`` is empty.
        """
        if not morphemes:
            return None

        all_indices = []
        all_values = []
        all_lengths = []
        annotations = []
        valid_mask = []

        # Fix: removed the dead ``max_features`` running maximum the original
        # computed here but never used (padding uses the fixed budget below).
        for morpheme in morphemes:
            lexical_feats = morpheme.get("lexical_features", [])
            indices = []
            values = []
            for idx, val in lexical_feats:
                # Drop out-of-range feature hashes defensively.
                if 0 <= idx < self.lexical_feature_dim:
                    indices.append(idx)
                    values.append(val)
            # NOTE: untruncated count — may exceed the fixed padding width.
            all_lengths.append(len(indices))
            all_indices.append(indices)
            all_values.append(values)

            if for_training:
                # "+" -> positive, "-" -> negative; anything else ("?") is
                # stored as 0 but masked out via valid_mask.
                annotation = morpheme.get("annotation", "?")
                if annotation == "+":
                    annotations.append(1)
                    valid_mask.append(True)
                elif annotation == "-":
                    annotations.append(0)
                    valid_mask.append(True)
                else:
                    annotations.append(0)
                    valid_mask.append(False)

        # Pad or truncate every node to a fixed feature budget so the
        # per-node tensors are rectangular.
        fixed_max_features = int(getattr(self, "lexical_max_features", 20))
        padded_indices = []
        padded_values = []
        for indices, values in zip(all_indices, all_values):
            if len(indices) > fixed_max_features:
                padded_indices.append(indices[:fixed_max_features])
                padded_values.append(values[:fixed_max_features])
            else:
                pad_length = fixed_max_features - len(indices)
                padded_indices.append(indices + [0] * pad_length)
                padded_values.append(values + [0.0] * pad_length)

        edge_index = self._build_edge_index(edges, len(morphemes))

        # POS ids as auxiliary node features (0 = unknown tag).
        pos_ids = [self.pos_to_id.get(m.get("pos", "*"), 0) for m in morphemes]

        graph_data = Data(
            lexical_indices=torch.tensor(padded_indices, dtype=torch.long),
            lexical_values=torch.tensor(padded_values, dtype=torch.float32),
            lexical_lengths=torch.tensor(all_lengths, dtype=torch.long),
            edge_index=edge_index,
            num_nodes=len(morphemes),
        )
        graph_data.pos_ids = torch.tensor(pos_ids, dtype=torch.long)
        if for_training:
            graph_data.y = torch.tensor(annotations, dtype=torch.float32)
            graph_data.valid_mask = torch.tensor(valid_mask, dtype=torch.bool)

        return graph_data

    def _build_edge_index(self, edges: List[Dict], num_nodes: int) -> torch.Tensor:
        """Build a ``(2, E)`` PyG ``edge_index`` tensor from edge dicts.

        Edges with out-of-range endpoints are dropped. When
        ``use_bidirectional_edges`` is set, each kept edge is also added in
        the reverse direction.
        """
        if not edges:
            return torch.tensor([[], []], dtype=torch.long)

        source_indices = []
        target_indices = []

        for edge in edges:
            source = edge.get("source_idx", 0)
            target = edge.get("target_idx", 0)

            if 0 <= source < num_nodes and 0 <= target < num_nodes:
                source_indices.append(source)
                target_indices.append(target)
                if self.use_bidirectional_edges:
                    source_indices.append(target)
                    target_indices.append(source)

        if not source_indices:
            return torch.tensor([[], []], dtype=torch.long)

        return torch.tensor([source_indices, target_indices], dtype=torch.long)

    def _load_kwdlc_ids(self, ids_file: str) -> set:
        """Load a KWDLC id list (one id per line); missing file -> empty set."""
        ids = set()
        if ids_file and os.path.exists(ids_file):
            # Fix: pin the encoding instead of relying on the platform default,
            # and skip blank lines instead of adding "" to the set.
            with open(ids_file, "r", encoding="utf-8") as f:
                for line in f:
                    stripped = line.strip()
                    if stripped:
                        ids.add(stripped)
        return ids

    def load_annotation_data(self, max_files: Optional[int] = None) -> List[Dict]:
        """Discover available ``.pt`` annotation graph files.

        Args:
            max_files: Optional cap applied after sorting the file list.

        Returns:
            A single-element list ``[{"_mode": "pt", "_pt_files": [...]}]``.

        Raises:
            FileNotFoundError: If the directory is missing or holds no .pt files.
        """
        if os.path.isdir(self.annotations_dir):
            pt_files = [
                os.path.join(self.annotations_dir, fn)
                for fn in sorted(os.listdir(self.annotations_dir))
                if fn.endswith(".pt")
            ]
            if pt_files:
                if max_files is not None:
                    pt_files = pt_files[:max_files]
                return [{"_mode": "pt", "_pt_files": pt_files}]
        raise FileNotFoundError(f"No annotation graphs found under: {self.annotations_dir}")

    def setup(self, stage: Optional[str] = None) -> None:
        """Build train/val/test datasets from discovered .pt files.

        Files whose ``source_id`` appears in the KWDLC dev/test id lists go
        to val/test respectively; everything else (including files with no
        readable ``source_id``) goes to train.

        Raises:
            RuntimeError: If the split leaves val or test empty, or the
                annotation mode is unsupported.
        """
        annotation_data = self.load_annotation_data(max_files=self.max_files)

        if not annotation_data:
            self.train_dataset = []
            self.val_dataset = []
            self.test_dataset = []
            return

        # Fix: the original loaded these id lists twice (once before the mode
        # check and again inside the "pt" branch); load them exactly once.
        dev_ids = self._load_kwdlc_ids(os.path.join("KWDLC", "id", "split_for_pas", "dev.id"))
        test_ids = self._load_kwdlc_ids(os.path.join("KWDLC", "id", "split_for_pas", "test.id"))

        mode = annotation_data[0].get("_mode")
        if mode != "pt":
            raise RuntimeError("Unsupported annotation mode; expected pt")

        files: List[str] = annotation_data[0]["_pt_files"]
        train_files: List[str] = []
        val_files: List[str] = []
        test_files: List[str] = []

        for fp in files:
            sid = None
            try:
                obj = torch.load(fp, map_location="cpu")
                if isinstance(obj, dict):
                    sid = obj.get("source_id")
            except Exception:
                # Unreadable file: fall through with sid=None -> train split.
                pass
            if sid and (dev_ids or test_ids):
                if sid in test_ids:
                    test_files.append(fp)
                elif sid in dev_ids:
                    val_files.append(fp)
                else:
                    train_files.append(fp)
            else:
                train_files.append(fp)

        self.train_dataset = _PtGraphDataset(train_files)
        self.val_dataset = _PtGraphDataset(val_files)
        self.test_dataset = _PtGraphDataset(test_files)

        if len(self.val_dataset) == 0 or len(self.test_dataset) == 0:
            raise RuntimeError(
                "KWDLC dev/test split produced empty val/test datasets. Ensure KWDLC id files exist and source_id is set in .pt files."
            )

        print(
            f"Data split: train={len(self.train_dataset)}, val={len(self.val_dataset)}, test={len(self.test_dataset)}"
        )

    def _create_dataloader(self, dataset: Dataset, batch_size: int, shuffle: bool = False) -> DataLoader:
        """Create a PyG DataLoader; worker reuse/prefetch only with workers > 0.

        Fix: annotation corrected from ``List[Data]`` to ``Dataset`` — the
        callers pass ``_PtGraphDataset`` instances.
        """
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=False,
            persistent_workers=self.num_workers > 0,
            prefetch_factor=2 if self.num_workers > 0 else None,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled train DataLoader."""
        return self._create_dataloader(self.train_dataset, self.batch_size, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return the val DataLoader."""
        return self._create_dataloader(self.val_dataset, self.batch_size, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """Return the test DataLoader."""
        return self._create_dataloader(self.test_dataset, self.batch_size, shuffle=False)
|
|