docs: add docstrings/return types

This commit is contained in:
2024-08-03 02:42:31 +03:00
parent 5413180b05
commit bb616258a6
2 changed files with 153 additions and 48 deletions

View File

@@ -7,15 +7,15 @@ import path from 'node:path'
import os from 'node:os'
export class MegaHAL {
learning: boolean
seed: SoothPredictor
fore: SoothPredictor
back: SoothPredictor
case: SoothPredictor
punc: SoothPredictor
brain: Record<string, number>
public learning: boolean
public seed: SoothPredictor
public fore: SoothPredictor
public back: SoothPredictor
public case: SoothPredictor
public punc: SoothPredictor
public brain: Record<string, number>
dictionary: Record<string, number>
public dictionary: Record<string, number>
constructor(personality?: string) {
this.learning = true
@@ -29,6 +29,9 @@ export class MegaHAL {
this.become(personality || 'default')
}
/**
* Clears all predictors and dictionaries.
*/
public clear(): void {
this.seed.clear()
this.fore.clear()
@@ -39,10 +42,16 @@ export class MegaHAL {
this.brain = {}
}
/**
* Returns the number of keys in a plain-object map.
*
* Counting is done with `Object.keys`, so only the object's own enumerable
* string keys are included. Other methods of this class use the result as the
* next id to assign when inserting into `brain` or `dictionary`.
*
* @param map - Object whose keys are counted.
* @returns The key count; `0` for an empty object.
*/
public mapLength(map: Record<string, unknown>): number {
return Object.keys(map).length
}
/**
* Trains the model using a file or an array of lines.
*/
// eslint-disable-next-line no-unused-vars
public async train(filename: string): Promise<void>
// eslint-disable-next-line no-unused-vars
@@ -57,7 +66,10 @@ export class MegaHAL {
}
}
private _train(data: string[]) {
/**
* Internal training method to process and learn from data lines.
*/
private _train(data: string[]): void {
data = data.map((x) => x.trim()).filter(Boolean)
for (const line of data) {
const [puncs, norms, words] = this._decompose(line)
@@ -67,6 +79,9 @@ export class MegaHAL {
private static personalities: Record<string, string[]> = {}
/**
* Adds a new personality with the given name and data.
*/
public static addPersonality(name: string, data: string[]): void {
if (this.personalities[name]) {
return
@@ -74,10 +89,16 @@ export class MegaHAL {
this.personalities[name] = data
}
/**
* Lists the names of all registered personalities.
*
* @returns The keys of the static `personalities` map, in `Object.keys` order.
*/
public static list(): string[] {
return Object.keys(this.personalities)
}
/**
* Saves the current state to a file.
*/
public async save(filename: string): Promise<boolean> {
try {
const zip = new JSZip()
@@ -105,6 +126,9 @@ export class MegaHAL {
}
}
/**
* Loads the state from a file.
*/
public async load(filename: string): Promise<void> {
const zip = new JSZip()
const data = await fs.readFile(filename)
@@ -118,6 +142,7 @@ export class MegaHAL {
.file('dictionary')
?.async('string')
.then((x) => JSON.parse(x))
this.learning = dict.learning
this.brain = dict.brain
this.dictionary = dict.dictionary
@@ -134,6 +159,9 @@ export class MegaHAL {
}
}
/**
* Changes the current personality and clears previous state.
*/
public become(name = 'default'): void {
if (!MegaHAL.personalities[name]) {
throw new Error('No such personality')
@@ -142,10 +170,16 @@ export class MegaHAL {
this._train(MegaHAL.personalities[name])
}
private _getBrain(context: number[]) {
/**
* Looks up the numeric id stored in `brain` for a context, assigning one on
* first use.
*
* The lookup key is `contextHash(context)` (helper defined outside this view).
* When the entry is null or undefined, the current number of keys in `brain`
* is stored as the new id; because `??=` is used, an existing id of `0` is
* kept rather than replaced.
*
* @param context - Array of numbers identifying the context (presumably
* symbol ids; confirm against callers).
* @returns The existing or newly assigned id.
*/
private _getBrain(context: number[]): number {
return (this.brain[contextHash(context)] ??= this.mapLength(this.brain))
}
/**
* Generates a reply for the given input string.
*/
public reply(input: string, error = '...'): string {
const [puncs, norms, words] = this._decompose(input?.trim())
@@ -158,6 +192,7 @@ export class MegaHAL {
}
utterances.push(this._generate([]))
// Filter out exact matches and null values
utterances = utterances.filter((u) => inputSymbols.join(',') !== u?.join(',')).filter(notNull)
let reply: string | null = null
@@ -180,7 +215,10 @@ export class MegaHAL {
return reply || error
}
private _learn(puncs: string[], norms: string[], words: string[]) {
/**
* Learns from punctuation, normalized, and word symbols.
*/
private _learn(puncs: string[], norms: string[], words: string[]): void {
if (!words.length) return
const puncSyms = puncs.map((p) => (this.dictionary[p] ||= this.mapLength(this.dictionary)))
@@ -242,7 +280,10 @@ export class MegaHAL {
}
}
private _selectUtterance(utterances: (number[] | null)[], kwSymbols: number[]) {
/**
* Selects the best utterance based on keyword symbols.
*/
private _selectUtterance(utterances: (number[] | null)[], kwSymbols: number[]): number[] | null {
let bestScore = -1
let bestUtterance: number[] | null = null
for (const utterance of utterances) {
@@ -257,6 +298,9 @@ export class MegaHAL {
return bestUtterance
}
/**
* Calculates the score of an utterance based on keyword symbols.
*/
private _calculateScore(utterance: number[] | null, kwSymbols: number[]): number {
let score = 0
let context = [1, 1]
@@ -298,12 +342,18 @@ export class MegaHAL {
return score
}
private _generate(kwSymbols: number[]) {
/**
* Produces one candidate utterance for the given keyword symbols.
*
* Delegates to `_getResult` and normalises a missing or empty result to
* `null`, so callers only have to handle a single "nothing generated" value.
*
* @param kwSymbols - Keyword symbols, forwarded unchanged to `_getResult`.
* @returns The non-empty result array, or `null` when nothing was produced.
*/
private _generate(kwSymbols: number[]): number[] | null {
const result = this._getResult(kwSymbols)
return !result?.length ? null : result
}
private _getResult(kwSymbols: number[]) {
/**
* Gets a result based on keyword symbols.
*/
private _getResult(kwSymbols: number[]): number[] | null {
const keyword = this._selectKeyword(kwSymbols)
if (keyword) {
const contexts = [
@@ -336,12 +386,18 @@ export class MegaHAL {
return this._randomWalk(this.fore, context, kwSymbols)
}
private _selectKeyword(kwSymbols: number[]) {
/**
* Picks one keyword at random from the keyword symbols, ignoring auxiliaries.
*
* Every entry of `Keywords.AUXILIARY` is translated through `dictionary`, and
* any symbol in `kwSymbols` equal to one of those values is excluded. One of
* the remaining symbols is then chosen with `Math.random`.
*
* @param kwSymbols - Candidate keyword symbols.
* @returns A randomly chosen non-auxiliary symbol, or `undefined` when no
* candidate remains (the index into the empty array yields `undefined`).
*/
private _selectKeyword(kwSymbols: number[]): number | undefined {
const aux = Keywords.AUXILIARY.map((a) => this.dictionary[a])
const syms = kwSymbols.filter((s) => !aux.includes(s))
return syms[Math.floor(Math.random() * syms.length)]
}
/**
* Performs a random walk to generate symbols based on the model.
*/
private _randomWalk(
model: SoothPredictor,
staticContext: number[],
@@ -376,7 +432,13 @@ export class MegaHAL {
return results
}
private _decompose(line: string | undefined | null, maxLen = 1024) {
/**
* Decomposes a line of text into punctuation, normalized, and word symbols.
*/
private _decompose(
line: string | undefined | null,
maxLen = 1024,
): [string[] | null, string[] | null, string[] | null] {
if (!line) return [null, null, null]
if (line.length > maxLen) {
@@ -393,7 +455,10 @@ export class MegaHAL {
return [puncs, norms, words]
}
private _segment(line: string) {
/**
* Segments a line of text into punctuation and words.
*/
private _segment(line: string): [string[], string[]] {
let sequence = this._characterSegmentation(line) ? line.split(/(\w)/) : line.split(/(\w+)/)
if (/\w+/.test(sequence[sequence.length - 1])) {
@@ -404,6 +469,7 @@ export class MegaHAL {
sequence.unshift('')
}
// Combine hyphenated words
while (true) {
const index = sequence.slice(1, -1).findIndex((item) => /^['-]$/.test(item))
if (index === -1) break
@@ -428,6 +494,9 @@ export class MegaHAL {
return [separators, words]
}
/**
* Rewrites a sequence of normalized symbols into a sentence.
*/
private _rewrite(normSymbols: number[]): string | null {
const decode: Record<number, string> = Object.fromEntries(
Object.entries(this.dictionary).map((x) => x.reverse()),
@@ -481,7 +550,10 @@ export class MegaHAL {
.join('')
}
private _characterSegmentation(_line: string) {
/**
* Reports whether a line should be split character by character.
*
* Currently a stub: the argument is ignored and the result is always `false`,
* so `_segment` always takes its word-level split branch.
*
* @param _line - Line of text to inspect (unused for now).
* @returns Always `false`.
*/
private _characterSegmentation(_line: string): boolean {
// TODO implement more languages
return false
}

View File

@@ -13,33 +13,26 @@ export interface SoothStatistic {
}
export class SoothPredictor {
errorEvent: number
contexts: SoothContext[] = []
contextsSize = 0
public errorEvent: number
private contexts: SoothContext[] = []
private contextsSize = 0
/**
* Creates a predictor with no stored contexts.
*
* @param errorEvent - Value stored in `errorEvent` (default `0`); `select`
* returns it when it cannot pick an event.
*/
constructor(errorEvent = 0) {
this.errorEvent = errorEvent
}
clear() {
/**
* Discards every stored context.
*
* Replaces `contexts` with a new empty array and resets `contextsSize` to 0;
* `errorEvent` is left unchanged.
*/
public clear(): void {
this.contexts = []
this.contextsSize = 0
}
/**
* Encodes a statistic as an 8-byte buffer.
*
* Layout: `event` as a little-endian signed 32-bit integer at offset 0,
* followed by `count` in the same encoding at offset 4.
*
* @param statistic - Statistic to encode; both fields must fit in a signed
* 32-bit integer, otherwise `writeInt32LE` throws.
* @returns A newly allocated 8-byte `Buffer`.
*/
serializeStatistic(statistic: SoothStatistic): Buffer {
const buffer = Buffer.alloc(8)
buffer.writeInt32LE(statistic.event, 0)
buffer.writeInt32LE(statistic.count, 4)
return buffer
}
/**
* Decodes a statistic from a buffer.
*
* Reads two little-endian signed 32-bit integers: `event` at `offset` and
* `count` at `offset + 4`.
*
* @param buffer - Buffer holding the encoded statistic.
* @param offset - Byte position of the first field.
* @returns The decoded `{ event, count }` pair.
*/
parseStatistic(buffer: Buffer, offset: number): SoothStatistic {
const event = buffer.readInt32LE(offset)
const count = buffer.readInt32LE(offset + 4)
return { event, count }
}
save(filename: string): boolean {
/**
* Saves the predictor state to a file.
*/
public save(filename: string): boolean {
try {
const file = fs.openSync(filename, 'w')
@@ -59,6 +52,7 @@ export class SoothPredictor {
contextBuffer.writeInt32LE(context.statisticsSize, 8)
fs.writeSync(file, contextBuffer)
// Write each statistic
for (let j = 0; j < context.statisticsSize; j++) {
const statisticBuffer = Buffer.alloc(8)
statisticBuffer.writeInt32LE(context.statistics[j].event, 0)
@@ -75,7 +69,10 @@ export class SoothPredictor {
}
}
load(filename: string): boolean {
/**
* Loads the predictor state from a file.
*/
public load(filename: string): boolean {
try {
const fileBuffer = fs.readFileSync(filename)
@@ -121,12 +118,16 @@ export class SoothPredictor {
}
}
findContext(id: number): SoothContext {
/**
* Finds or creates a context for the given ID.
*/
private findContext(id: number): SoothContext {
let context: SoothContext | undefined
let low = 0
let mid = 0
let high = this.contextsSize - 1
// Binary search for the context
if (this.contextsSize > 0) {
while (low <= high) {
mid = Math.floor(low + (high - low) / 2)
@@ -145,6 +146,7 @@ export class SoothPredictor {
mid = low
}
// Create a new context if not found
this.contextsSize += 1
this.contexts.push({ id: -1, count: 0, statisticsSize: 0, statistics: [] })
@@ -164,12 +166,16 @@ export class SoothPredictor {
return context
}
findStatistic(context: SoothContext, event: number): SoothStatistic {
/**
* Finds or creates a statistic for the given event within the context.
*/
private findStatistic(context: SoothContext, event: number): SoothStatistic {
let low = 0
let high = context.statisticsSize - 1
let mid = 0
let statistic: SoothStatistic | null = null
// Binary search for the statistic
if (context.statisticsSize > 0) {
while (low <= high) {
mid = low + Math.floor((high - low) / 2)
@@ -177,7 +183,7 @@ export class SoothPredictor {
if (statistic.event < event) {
low = mid + 1
} else if (statistic.event > event) {
if (mid == 0) {
if (mid === 0) {
break
}
high = mid - 1
@@ -189,6 +195,7 @@ export class SoothPredictor {
mid = low
}
// Create a new statistic if not found
context.statisticsSize += 1
const newMemory = new Array<SoothStatistic>(context.statisticsSize)
@@ -210,19 +217,29 @@ export class SoothPredictor {
return statistic
}
size(id: number) {
/**
* Returns the `statisticsSize` of the context with the given id.
*
* NOTE(review): the lookup goes through `findContext`, which is documented as
* find-or-create, so querying an unknown id may insert an empty context —
* confirm before relying on this as a read-only query.
*
* @param id - Context id.
* @returns The context's `statisticsSize`.
*/
public size(id: number): number {
const context = this.findContext(id)
return context.statisticsSize
}
count(id: number) {
/**
* Returns the `count` field of the context with the given id.
*
* NOTE(review): the lookup goes through `findContext`, which is documented as
* find-or-create, so querying an unknown id may insert an empty context —
* confirm before relying on this as a read-only query.
*
* @param id - Context id.
* @returns The context's `count`.
*/
public count(id: number): number {
const context = this.findContext(id)
return context.count
}
observe(id: number, event: number) {
/**
* Observes an event for the given context ID.
*/
public observe(id: number, event: number): number {
const context = this.findContext(id)
// Handle overflow by halving the counts
if (context.count === Number.MAX_SAFE_INTEGER) {
context.count = 0
for (let i = 0; i < context.statisticsSize; i++) {
@@ -240,7 +257,10 @@ export class SoothPredictor {
return statistic.count
}
select(id: number, limit: number) {
/**
* Selects an event based on the limit for the given context ID.
*/
public select(id: number, limit: number): number {
const context = this.findContext(id)
if (limit === 0 || limit > context.count) {
return this.errorEvent
@@ -258,9 +278,13 @@ export class SoothPredictor {
return this.errorEvent
}
distribution(id: number) {
/**
* Returns the probability distribution of events for the given context ID.
*/
public distribution(id: number): Record<number, number> | null {
const context = this.findContext(id)
if (!context.statisticsSize) return null
const total = context.count
return context.statistics.reduce(
(acc, stat) => {
@@ -271,7 +295,10 @@ export class SoothPredictor {
)
}
uncertainty(id: number) {
/**
* Returns the uncertainty (entropy) for the given context ID.
*/
public uncertainty(id: number): number | null {
const context = this.findContext(id)
if (!context.statisticsSize) return null
@@ -286,7 +313,10 @@ export class SoothPredictor {
return uncertainty
}
surprise(id: number, event: number) {
/**
* Returns the surprise for the given event and context ID.
*/
public surprise(id: number, event: number): number | null {
const context = this.findContext(id)
if (context.count === 0) {
return null
@@ -303,7 +333,10 @@ export class SoothPredictor {
return Object.is(surpriseValue, -0) ? 0 : surpriseValue
}
frequency(id: number, event: number) {
/**
* Returns the frequency of the given event for the context ID.
*/
public frequency(id: number, event: number): number {
const context = this.findContext(id)
if (context.count == 0) {
return 0