mirror of
https://github.com/chenasraf/megahal.js.git
synced 2026-05-17 17:48:02 +00:00
docs: add docstrings/return types
This commit is contained in:
108
src/megahal.ts
108
src/megahal.ts
@@ -7,15 +7,15 @@ import path from 'node:path'
|
||||
import os from 'node:os'
|
||||
|
||||
export class MegaHAL {
|
||||
learning: boolean
|
||||
seed: SoothPredictor
|
||||
fore: SoothPredictor
|
||||
back: SoothPredictor
|
||||
case: SoothPredictor
|
||||
punc: SoothPredictor
|
||||
brain: Record<string, number>
|
||||
public learning: boolean
|
||||
public seed: SoothPredictor
|
||||
public fore: SoothPredictor
|
||||
public back: SoothPredictor
|
||||
public case: SoothPredictor
|
||||
public punc: SoothPredictor
|
||||
public brain: Record<string, number>
|
||||
|
||||
dictionary: Record<string, number>
|
||||
public dictionary: Record<string, number>
|
||||
|
||||
constructor(personality?: string) {
|
||||
this.learning = true
|
||||
@@ -29,6 +29,9 @@ export class MegaHAL {
|
||||
this.become(personality || 'default')
|
||||
}
|
||||
|
||||
/**
|
||||
* Clears all predictors and dictionaries.
|
||||
*/
|
||||
public clear(): void {
|
||||
this.seed.clear()
|
||||
this.fore.clear()
|
||||
@@ -39,10 +42,16 @@ export class MegaHAL {
|
||||
this.brain = {}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the length of a map object.
|
||||
*/
|
||||
public mapLength(map: Record<string, unknown>): number {
|
||||
return Object.keys(map).length
|
||||
}
|
||||
|
||||
/**
|
||||
* Trains the model using a file or an array of lines.
|
||||
*/
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
public async train(filename: string): Promise<void>
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
@@ -57,7 +66,10 @@ export class MegaHAL {
|
||||
}
|
||||
}
|
||||
|
||||
private _train(data: string[]) {
|
||||
/**
|
||||
* Internal training method to process and learn from data lines.
|
||||
*/
|
||||
private _train(data: string[]): void {
|
||||
data = data.map((x) => x.trim()).filter(Boolean)
|
||||
for (const line of data) {
|
||||
const [puncs, norms, words] = this._decompose(line)
|
||||
@@ -67,6 +79,9 @@ export class MegaHAL {
|
||||
|
||||
private static personalities: Record<string, string[]> = {}
|
||||
|
||||
/**
|
||||
* Adds a new personality with the given name and data.
|
||||
*/
|
||||
public static addPersonality(name: string, data: string[]): void {
|
||||
if (this.personalities[name]) {
|
||||
return
|
||||
@@ -74,10 +89,16 @@ export class MegaHAL {
|
||||
this.personalities[name] = data
|
||||
}
|
||||
|
||||
/**
|
||||
* Lists all available personalities.
|
||||
*/
|
||||
public static list(): string[] {
|
||||
return Object.keys(this.personalities)
|
||||
}
|
||||
|
||||
/**
|
||||
* Saves the current state to a file.
|
||||
*/
|
||||
public async save(filename: string): Promise<boolean> {
|
||||
try {
|
||||
const zip = new JSZip()
|
||||
@@ -105,6 +126,9 @@ export class MegaHAL {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads the state from a file.
|
||||
*/
|
||||
public async load(filename: string): Promise<void> {
|
||||
const zip = new JSZip()
|
||||
const data = await fs.readFile(filename)
|
||||
@@ -118,6 +142,7 @@ export class MegaHAL {
|
||||
.file('dictionary')
|
||||
?.async('string')
|
||||
.then((x) => JSON.parse(x))
|
||||
|
||||
this.learning = dict.learning
|
||||
this.brain = dict.brain
|
||||
this.dictionary = dict.dictionary
|
||||
@@ -134,6 +159,9 @@ export class MegaHAL {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Changes the current personality and clears previous state.
|
||||
*/
|
||||
public become(name = 'default'): void {
|
||||
if (!MegaHAL.personalities[name]) {
|
||||
throw new Error('No such personality')
|
||||
@@ -142,10 +170,16 @@ export class MegaHAL {
|
||||
this._train(MegaHAL.personalities[name])
|
||||
}
|
||||
|
||||
private _getBrain(context: number[]) {
|
||||
/**
|
||||
* Retrieves or creates a brain context.
|
||||
*/
|
||||
private _getBrain(context: number[]): number {
|
||||
return (this.brain[contextHash(context)] ??= this.mapLength(this.brain))
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a reply for the given input string.
|
||||
*/
|
||||
public reply(input: string, error = '...'): string {
|
||||
const [puncs, norms, words] = this._decompose(input?.trim())
|
||||
|
||||
@@ -158,6 +192,7 @@ export class MegaHAL {
|
||||
}
|
||||
utterances.push(this._generate([]))
|
||||
|
||||
// Filter out exact matches and null values
|
||||
utterances = utterances.filter((u) => inputSymbols.join(',') !== u?.join(',')).filter(notNull)
|
||||
|
||||
let reply: string | null = null
|
||||
@@ -180,7 +215,10 @@ export class MegaHAL {
|
||||
return reply || error
|
||||
}
|
||||
|
||||
private _learn(puncs: string[], norms: string[], words: string[]) {
|
||||
/**
|
||||
* Learns from punctuation, normalized, and word symbols.
|
||||
*/
|
||||
private _learn(puncs: string[], norms: string[], words: string[]): void {
|
||||
if (!words.length) return
|
||||
|
||||
const puncSyms = puncs.map((p) => (this.dictionary[p] ||= this.mapLength(this.dictionary)))
|
||||
@@ -242,7 +280,10 @@ export class MegaHAL {
|
||||
}
|
||||
}
|
||||
|
||||
private _selectUtterance(utterances: (number[] | null)[], kwSymbols: number[]) {
|
||||
/**
|
||||
* Selects the best utterance based on keyword symbols.
|
||||
*/
|
||||
private _selectUtterance(utterances: (number[] | null)[], kwSymbols: number[]): number[] | null {
|
||||
let bestScore = -1
|
||||
let bestUtterance: number[] | null = null
|
||||
for (const utterance of utterances) {
|
||||
@@ -257,6 +298,9 @@ export class MegaHAL {
|
||||
return bestUtterance
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates the score of an utterance based on keyword symbols.
|
||||
*/
|
||||
private _calculateScore(utterance: number[] | null, kwSymbols: number[]): number {
|
||||
let score = 0
|
||||
let context = [1, 1]
|
||||
@@ -298,12 +342,18 @@ export class MegaHAL {
|
||||
return score
|
||||
}
|
||||
|
||||
private _generate(kwSymbols: number[]) {
|
||||
/**
|
||||
* Generates an utterance based on keyword symbols.
|
||||
*/
|
||||
private _generate(kwSymbols: number[]): number[] | null {
|
||||
const result = this._getResult(kwSymbols)
|
||||
return !result?.length ? null : result
|
||||
}
|
||||
|
||||
private _getResult(kwSymbols: number[]) {
|
||||
/**
|
||||
* Gets a result based on keyword symbols.
|
||||
*/
|
||||
private _getResult(kwSymbols: number[]): number[] | null {
|
||||
const keyword = this._selectKeyword(kwSymbols)
|
||||
if (keyword) {
|
||||
const contexts = [
|
||||
@@ -336,12 +386,18 @@ export class MegaHAL {
|
||||
return this._randomWalk(this.fore, context, kwSymbols)
|
||||
}
|
||||
|
||||
private _selectKeyword(kwSymbols: number[]) {
|
||||
/**
|
||||
* Selects a keyword from the keyword symbols.
|
||||
*/
|
||||
private _selectKeyword(kwSymbols: number[]): number | undefined {
|
||||
const aux = Keywords.AUXILIARY.map((a) => this.dictionary[a])
|
||||
const syms = kwSymbols.filter((s) => !aux.includes(s))
|
||||
return syms[Math.floor(Math.random() * syms.length)]
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs a random walk to generate symbols based on the model.
|
||||
*/
|
||||
private _randomWalk(
|
||||
model: SoothPredictor,
|
||||
staticContext: number[],
|
||||
@@ -376,7 +432,13 @@ export class MegaHAL {
|
||||
return results
|
||||
}
|
||||
|
||||
private _decompose(line: string | undefined | null, maxLen = 1024) {
|
||||
/**
|
||||
* Decomposes a line of text into punctuation, normalized, and word symbols.
|
||||
*/
|
||||
private _decompose(
|
||||
line: string | undefined | null,
|
||||
maxLen = 1024,
|
||||
): [string[] | null, string[] | null, string[] | null] {
|
||||
if (!line) return [null, null, null]
|
||||
|
||||
if (line.length > maxLen) {
|
||||
@@ -393,7 +455,10 @@ export class MegaHAL {
|
||||
return [puncs, norms, words]
|
||||
}
|
||||
|
||||
private _segment(line: string) {
|
||||
/**
|
||||
* Segments a line of text into punctuation and words.
|
||||
*/
|
||||
private _segment(line: string): [string[], string[]] {
|
||||
let sequence = this._characterSegmentation(line) ? line.split(/(\w)/) : line.split(/(\w+)/)
|
||||
|
||||
if (/\w+/.test(sequence[sequence.length - 1])) {
|
||||
@@ -404,6 +469,7 @@ export class MegaHAL {
|
||||
sequence.unshift('')
|
||||
}
|
||||
|
||||
// Combine hyphenated words
|
||||
while (true) {
|
||||
const index = sequence.slice(1, -1).findIndex((item) => /^['-]$/.test(item))
|
||||
if (index === -1) break
|
||||
@@ -428,6 +494,9 @@ export class MegaHAL {
|
||||
return [separators, words]
|
||||
}
|
||||
|
||||
/**
|
||||
* Rewrites a sequence of normalized symbols into a sentence.
|
||||
*/
|
||||
private _rewrite(normSymbols: number[]): string | null {
|
||||
const decode: Record<number, string> = Object.fromEntries(
|
||||
Object.entries(this.dictionary).map((x) => x.reverse()),
|
||||
@@ -481,7 +550,10 @@ export class MegaHAL {
|
||||
.join('')
|
||||
}
|
||||
|
||||
private _characterSegmentation(_line: string) {
|
||||
/**
|
||||
* Checks if a line requires character segmentation.
|
||||
*/
|
||||
private _characterSegmentation(_line: string): boolean {
|
||||
// TODO implement more languages
|
||||
return false
|
||||
}
|
||||
|
||||
93
src/sooth.ts
93
src/sooth.ts
@@ -13,33 +13,26 @@ export interface SoothStatistic {
|
||||
}
|
||||
|
||||
export class SoothPredictor {
|
||||
errorEvent: number
|
||||
contexts: SoothContext[] = []
|
||||
contextsSize = 0
|
||||
public errorEvent: number
|
||||
private contexts: SoothContext[] = []
|
||||
private contextsSize = 0
|
||||
|
||||
constructor(errorEvent = 0) {
|
||||
this.errorEvent = errorEvent
|
||||
}
|
||||
|
||||
clear() {
|
||||
/**
|
||||
* Clears all contexts.
|
||||
*/
|
||||
public clear(): void {
|
||||
this.contexts = []
|
||||
this.contextsSize = 0
|
||||
}
|
||||
|
||||
serializeStatistic(statistic: SoothStatistic): Buffer {
|
||||
const buffer = Buffer.alloc(8)
|
||||
buffer.writeInt32LE(statistic.event, 0)
|
||||
buffer.writeInt32LE(statistic.count, 4)
|
||||
return buffer
|
||||
}
|
||||
|
||||
parseStatistic(buffer: Buffer, offset: number): SoothStatistic {
|
||||
const event = buffer.readInt32LE(offset)
|
||||
const count = buffer.readInt32LE(offset + 4)
|
||||
return { event, count }
|
||||
}
|
||||
|
||||
save(filename: string): boolean {
|
||||
/**
|
||||
* Saves the predictor state to a file.
|
||||
*/
|
||||
public save(filename: string): boolean {
|
||||
try {
|
||||
const file = fs.openSync(filename, 'w')
|
||||
|
||||
@@ -59,6 +52,7 @@ export class SoothPredictor {
|
||||
contextBuffer.writeInt32LE(context.statisticsSize, 8)
|
||||
fs.writeSync(file, contextBuffer)
|
||||
|
||||
// Write each statistic
|
||||
for (let j = 0; j < context.statisticsSize; j++) {
|
||||
const statisticBuffer = Buffer.alloc(8)
|
||||
statisticBuffer.writeInt32LE(context.statistics[j].event, 0)
|
||||
@@ -75,7 +69,10 @@ export class SoothPredictor {
|
||||
}
|
||||
}
|
||||
|
||||
load(filename: string): boolean {
|
||||
/**
|
||||
* Loads the predictor state from a file.
|
||||
*/
|
||||
public load(filename: string): boolean {
|
||||
try {
|
||||
const fileBuffer = fs.readFileSync(filename)
|
||||
|
||||
@@ -121,12 +118,16 @@ export class SoothPredictor {
|
||||
}
|
||||
}
|
||||
|
||||
findContext(id: number): SoothContext {
|
||||
/**
|
||||
* Finds or creates a context for the given ID.
|
||||
*/
|
||||
private findContext(id: number): SoothContext {
|
||||
let context: SoothContext | undefined
|
||||
let low = 0
|
||||
let mid = 0
|
||||
let high = this.contextsSize - 1
|
||||
|
||||
// Binary search for the context
|
||||
if (this.contextsSize > 0) {
|
||||
while (low <= high) {
|
||||
mid = Math.floor(low + (high - low) / 2)
|
||||
@@ -145,6 +146,7 @@ export class SoothPredictor {
|
||||
mid = low
|
||||
}
|
||||
|
||||
// Create a new context if not found
|
||||
this.contextsSize += 1
|
||||
this.contexts.push({ id: -1, count: 0, statisticsSize: 0, statistics: [] })
|
||||
|
||||
@@ -164,12 +166,16 @@ export class SoothPredictor {
|
||||
return context
|
||||
}
|
||||
|
||||
findStatistic(context: SoothContext, event: number): SoothStatistic {
|
||||
/**
|
||||
* Finds or creates a statistic for the given event within the context.
|
||||
*/
|
||||
private findStatistic(context: SoothContext, event: number): SoothStatistic {
|
||||
let low = 0
|
||||
let high = context.statisticsSize - 1
|
||||
let mid = 0
|
||||
let statistic: SoothStatistic | null = null
|
||||
|
||||
// Binary search for the statistic
|
||||
if (context.statisticsSize > 0) {
|
||||
while (low <= high) {
|
||||
mid = low + Math.floor((high - low) / 2)
|
||||
@@ -177,7 +183,7 @@ export class SoothPredictor {
|
||||
if (statistic.event < event) {
|
||||
low = mid + 1
|
||||
} else if (statistic.event > event) {
|
||||
if (mid == 0) {
|
||||
if (mid === 0) {
|
||||
break
|
||||
}
|
||||
high = mid - 1
|
||||
@@ -189,6 +195,7 @@ export class SoothPredictor {
|
||||
mid = low
|
||||
}
|
||||
|
||||
// Create a new statistic if not found
|
||||
context.statisticsSize += 1
|
||||
const newMemory = new Array<SoothStatistic>(context.statisticsSize)
|
||||
|
||||
@@ -210,19 +217,29 @@ export class SoothPredictor {
|
||||
return statistic
|
||||
}
|
||||
|
||||
size(id: number) {
|
||||
/**
|
||||
* Returns the size of the statistics for the given context ID.
|
||||
*/
|
||||
public size(id: number): number {
|
||||
const context = this.findContext(id)
|
||||
return context.statisticsSize
|
||||
}
|
||||
|
||||
count(id: number) {
|
||||
/**
|
||||
* Returns the count of observations for the given context ID.
|
||||
*/
|
||||
public count(id: number): number {
|
||||
const context = this.findContext(id)
|
||||
return context.count
|
||||
}
|
||||
|
||||
observe(id: number, event: number) {
|
||||
/**
|
||||
* Observes an event for the given context ID.
|
||||
*/
|
||||
public observe(id: number, event: number): number {
|
||||
const context = this.findContext(id)
|
||||
|
||||
// Handle overflow by halving the counts
|
||||
if (context.count === Number.MAX_SAFE_INTEGER) {
|
||||
context.count = 0
|
||||
for (let i = 0; i < context.statisticsSize; i++) {
|
||||
@@ -240,7 +257,10 @@ export class SoothPredictor {
|
||||
return statistic.count
|
||||
}
|
||||
|
||||
select(id: number, limit: number) {
|
||||
/**
|
||||
* Selects an event based on the limit for the given context ID.
|
||||
*/
|
||||
public select(id: number, limit: number): number {
|
||||
const context = this.findContext(id)
|
||||
if (limit === 0 || limit > context.count) {
|
||||
return this.errorEvent
|
||||
@@ -258,9 +278,13 @@ export class SoothPredictor {
|
||||
return this.errorEvent
|
||||
}
|
||||
|
||||
distribution(id: number) {
|
||||
/**
|
||||
* Returns the probability distribution of events for the given context ID.
|
||||
*/
|
||||
public distribution(id: number): Record<number, number> | null {
|
||||
const context = this.findContext(id)
|
||||
if (!context.statisticsSize) return null
|
||||
|
||||
const total = context.count
|
||||
return context.statistics.reduce(
|
||||
(acc, stat) => {
|
||||
@@ -271,7 +295,10 @@ export class SoothPredictor {
|
||||
)
|
||||
}
|
||||
|
||||
uncertainty(id: number) {
|
||||
/**
|
||||
* Returns the uncertainty (entropy) for the given context ID.
|
||||
*/
|
||||
public uncertainty(id: number): number | null {
|
||||
const context = this.findContext(id)
|
||||
if (!context.statisticsSize) return null
|
||||
|
||||
@@ -286,7 +313,10 @@ export class SoothPredictor {
|
||||
return uncertainty
|
||||
}
|
||||
|
||||
surprise(id: number, event: number) {
|
||||
/**
|
||||
* Returns the surprise for the given event and context ID.
|
||||
*/
|
||||
public surprise(id: number, event: number): number | null {
|
||||
const context = this.findContext(id)
|
||||
if (context.count === 0) {
|
||||
return null
|
||||
@@ -303,7 +333,10 @@ export class SoothPredictor {
|
||||
return Object.is(surpriseValue, -0) ? 0 : surpriseValue
|
||||
}
|
||||
|
||||
frequency(id: number, event: number) {
|
||||
/**
|
||||
* Returns the frequency of the given event for the context ID.
|
||||
*/
|
||||
public frequency(id: number, event: number): number {
|
||||
const context = this.findContext(id)
|
||||
if (context.count == 0) {
|
||||
return 0
|
||||
|
||||
Reference in New Issue
Block a user