diff --git a/alias.go b/alias.go index d355612..ce4c2fc 100644 --- a/alias.go +++ b/alias.go @@ -2,11 +2,12 @@ package slinguist import "strings" -// GetLanguageByAlias returns the language related to the given alias or Otherlanguage otherwise. -func GetLanguageByAlias(alias string) (lang string) { +// GetLanguageByAlias returns the language related to the given alias and ok set to true, +// or Otherlanguage and ok set to false otherwise. +func GetLanguageByAlias(alias string) (lang string, ok bool) { a := strings.Split(alias, `,`)[0] a = strings.ToLower(a) - lang, ok := languagesByAlias[a] + lang, ok = languagesByAlias[a] if !ok { lang = OtherLanguage } diff --git a/alias_test.go b/alias_test.go index 9bdbb43..3612e37 100644 --- a/alias_test.go +++ b/alias_test.go @@ -6,21 +6,23 @@ func (s *TSuite) TestGetLanguageByAlias(c *C) { tests := []struct { alias string expectedLang string + expectedOk bool }{ - {alias: "BestLanguageEver", expectedLang: OtherLanguage}, - {alias: "aspx-vb", expectedLang: "ASP"}, - {alias: "C++", expectedLang: "C++"}, - {alias: "c++", expectedLang: "C++"}, - {alias: "objc", expectedLang: "Objective-C"}, - {alias: "golang", expectedLang: "Go"}, - {alias: "GOLANG", expectedLang: "Go"}, - {alias: "bsdmake", expectedLang: "Makefile"}, - {alias: "xhTmL", expectedLang: "HTML"}, - {alias: "python", expectedLang: "Python"}, + {alias: "BestLanguageEver", expectedLang: OtherLanguage, expectedOk: false}, + {alias: "aspx-vb", expectedLang: "ASP", expectedOk: true}, + {alias: "C++", expectedLang: "C++", expectedOk: true}, + {alias: "c++", expectedLang: "C++", expectedOk: true}, + {alias: "objc", expectedLang: "Objective-C", expectedOk: true}, + {alias: "golang", expectedLang: "Go", expectedOk: true}, + {alias: "GOLANG", expectedLang: "Go", expectedOk: true}, + {alias: "bsdmake", expectedLang: "Makefile", expectedOk: true}, + {alias: "xhTmL", expectedLang: "HTML", expectedOk: true}, + {alias: "python", expectedLang: "Python", expectedOk: true}, } for _, test := range tests { - lang := GetLanguageByAlias(test.alias) + lang, ok := GetLanguageByAlias(test.alias) c.Assert(lang, Equals, test.expectedLang) + c.Assert(ok, Equals, test.expectedOk) } } diff --git a/classifier.go b/classifier.go new file mode 100644 index 0000000..db3117e --- /dev/null +++ b/classifier.go @@ -0,0 +1,100 @@ +package slinguist + +import ( + "math" + + "gopkg.in/src-d/simple-linguist.v1/internal/tokenizer" +) + +// GetLanguageByClassifier takes in a content and a list of candidates, and apply the classifier's Classify method to +// get the most probably language. If classifier is null then DefaultClassfier will be used. +func GetLanguageByClassifier(content []byte, candidates []string, classifier Classifier) string { + if classifier == nil { + classifier = DefaultClassifier + } + + scores := classifier.Classify(content, candidates) + if len(scores) == 0 { + return OtherLanguage + } + + return getLangugeHigherScore(scores) +} + +func getLangugeHigherScore(scores map[string]float64) string { + var language string + higher := -math.MaxFloat64 + for lang, score := range scores { + if higher < score { + language = lang + higher = score + } + } + + return language +} + +// Classifier is the interface that contains the method Classify which is in charge to assign scores to the possibles candidates. +// The scores must order the candidates so as the highest score be the most probably language of the content. +type Classifier interface { + Classify(content []byte, candidates []string) map[string]float64 +} + +type classifier struct { + languagesLogProbabilities map[string]float64 + tokensLogProbabilities map[string]map[string]float64 + tokensTotal float64 +} + +func (c *classifier) Classify(content []byte, candidates []string) map[string]float64 { + if len(content) == 0 { + return nil + } + + var languages []string + if len(candidates) == 0 { + languages = c.knownLangs() + } else { + languages = make([]string, 0, len(candidates)) + for _, candidate := range candidates { + if lang, ok := GetLanguageByAlias(candidate); ok { + languages = append(languages, lang) + } + } + } + + tokens := tokenizer.Tokenize(content) + scores := make(map[string]float64, len(languages)) + for _, language := range languages { + scores[language] = c.tokensLogProbability(tokens, language) + c.languagesLogProbabilities[language] + } + + return scores +} + +func (c *classifier) knownLangs() []string { + langs := make([]string, 0, len(c.languagesLogProbabilities)) + for lang := range c.languagesLogProbabilities { + langs = append(langs, lang) + } + + return langs +} + +func (c *classifier) tokensLogProbability(tokens []string, language string) float64 { + var sum float64 + for _, token := range tokens { + sum += c.tokenProbability(token, language) + } + + return sum +} + +func (c *classifier) tokenProbability(token, language string) float64 { + tokenProb, ok := c.tokensLogProbabilities[language][token] + if !ok { + tokenProb = math.Log(1.000000 / c.tokensTotal) + } + + return tokenProb +} diff --git a/classifier_test.go b/classifier_test.go new file mode 100644 index 0000000..f0787b3 --- /dev/null +++ b/classifier_test.go @@ -0,0 +1,32 @@ +package slinguist + +import ( + "io/ioutil" + "path/filepath" + + . "gopkg.in/check.v1" +) + +func (s *TSuite) TestGetLanguageByClassifier(c *C) { + const samples = `.linguist/samples/` + test := []struct { + filename string + candidates []string + expectedLang string + }{ + {filename: filepath.Join(samples, "C/blob.c"), candidates: []string{"python", "ruby", "c", "c++"}, expectedLang: "C"}, + {filename: filepath.Join(samples, "C/blob.c"), candidates: nil, expectedLang: "C"}, + {filename: filepath.Join(samples, "C/main.c"), candidates: nil, expectedLang: "C"}, + {filename: filepath.Join(samples, "C/blob.c"), candidates: []string{"python", "ruby", "c++"}, expectedLang: "C++"}, + {filename: filepath.Join(samples, "C/blob.c"), candidates: []string{"ruby"}, expectedLang: "Ruby"}, + {filename: filepath.Join(samples, "Python/django-models-base.py"), candidates: []string{"python", "ruby", "c", "c++"}, expectedLang: "Python"}, + {filename: filepath.Join(samples, "Python/django-models-base.py"), candidates: nil, expectedLang: "Python"}, + } + + for _, test := range test { + content, err := ioutil.ReadFile(test.filename) + c.Assert(err, Equals, nil) + lang := GetLanguageByClassifier(content, test.candidates, nil) + c.Assert(lang, Equals, test.expectedLang) + } +} diff --git a/common.go b/common.go index 5935ad3..2262e7d 100644 --- a/common.go +++ b/common.go @@ -52,7 +52,11 @@ func GetLanguage(filename string, content []byte) string { return lang } - lang, _ := GetLanguageByContent(filename, content) + if lang, safe := GetLanguageByContent(filename, content); safe { + return lang + } + + lang := GetLanguageByClassifier(content, nil, nil) return lang } diff --git a/internal/code-generator/generator/generator.go b/internal/code-generator/generator/generator.go index 73b40a4..c9a06fe 100644 --- a/internal/code-generator/generator/generator.go +++ b/internal/code-generator/generator/generator.go @@ -21,11 +21,7 @@ func FromFile(fileToParse, outPath, tmplPath, tmplName, commit string, generate return err } - if err := formatedWrite(outPath, source); err != nil { - return err - } - - return nil + return formatedWrite(outPath, source) } func formatedWrite(outPath string, source []byte) error { diff --git a/internal/code-generator/generator/samplesfreq.go b/internal/code-generator/generator/samplesfreq.go index 4ff2939..fa5542b 100644 --- a/internal/code-generator/generator/samplesfreq.go +++ b/internal/code-generator/generator/samplesfreq.go @@ -16,6 +16,8 @@ import ( "gopkg.in/src-d/simple-linguist.v1/internal/tokenizer" ) +const samplesSubDir = "filenames" + type samplesFrequencies struct { LanguageTotal int `json:"language_total,omitempty"` Languages map[string]int `json:"languages,omitempty"` @@ -37,11 +39,7 @@ func Frequencies(samplesDir, frequenciesTmplPath, frequenciesTmplName, commit, o return err } - if err := formatedWrite(outPath, buf.Bytes()); err != nil { - return err - } - - return nil + return formatedWrite(outPath, buf.Bytes()) } func getFrequencies(samplesDir string) (*samplesFrequencies, error) { @@ -98,8 +96,6 @@ func getFrequencies(samplesDir string) (*samplesFrequencies, error) { } func getSamples(samplesDir string, langDir os.FileInfo) ([]string, error) { - const subDir = "filenames" - samples := []string{} path := filepath.Join(samplesDir, langDir.Name()) entries, err := ioutil.ReadDir(path) @@ -112,7 +108,7 @@ func getSamples(samplesDir string, langDir os.FileInfo) ([]string, error) { samples = append(samples, filepath.Join(path, entry.Name())) } - if entry.IsDir() && entry.Name() == subDir { + if entry.IsDir() && entry.Name() == samplesSubDir { subSamples, err := getSubSamples(samplesDir, langDir.Name(), entry) if err != nil { return nil, err diff --git a/modeline.go b/modeline.go index af88822..16669cf 100644 --- a/modeline.go +++ b/modeline.go @@ -10,10 +10,9 @@ import ( func GetLanguageByModeline(content []byte) (lang string, safe bool) { headFoot := getHeaderAndFooter(content) for _, getLang := range modelinesFunc { - lang = getLang(headFoot) - safe = lang != OtherLanguage + lang, safe = getLang(headFoot) if safe { - return + break } } @@ -23,7 +22,7 @@ func GetLanguageByModeline(content []byte) (lang string, safe bool) { func getHeaderAndFooter(content []byte) []byte { const ( searchScope = 5 - eol = `\n` + eol = "\n" ) if bytes.Count(content, []byte(eol)) < 2*searchScope { @@ -37,7 +36,7 @@ func getHeaderAndFooter(content []byte) []byte { return bytes.Join(headerAndFooter, []byte(eol)) } -var modelinesFunc = []func(content []byte) string{ +var modelinesFunc = []func(content []byte) (string, bool){ GetLanguageByEmacsModeline, GetLanguageByVimModeline, } @@ -50,11 +49,11 @@ var ( ) // GetLanguageByEmacsModeline detecs if the content has a emacs modeline and try to get a -// language basing on alias. If couldn't retrieve a valid language, it returns OtherLanguage. -func GetLanguageByEmacsModeline(content []byte) (lang string) { +// language basing on alias. If couldn't retrieve a valid language, it returns OtherLanguage and false. +func GetLanguageByEmacsModeline(content []byte) (string, bool) { matched := reEmacsModeline.FindAllSubmatch(content, -1) if matched == nil { - return OtherLanguage + return OtherLanguage, false } // only take the last matched line, discard previous lines @@ -67,23 +66,22 @@ func GetLanguageByEmacsModeline(content []byte) (lang string) { alias = string(lastLineMatched) } - lang = GetLanguageByAlias(alias) - return + return GetLanguageByAlias(alias) } // GetLanguageByVimModeline detecs if the content has a vim modeline and try to get a -// language basing on alias. If couldn't retrieve a valid language, it returns OtherLanguage. -func GetLanguageByVimModeline(content []byte) (lang string) { +// language basing on alias. If couldn't retrieve a valid language, it returns OtherLanguage and false. +func GetLanguageByVimModeline(content []byte) (string, bool) { matched := reVimModeline.FindAllSubmatch(content, -1) if matched == nil { - return OtherLanguage + return OtherLanguage, false } // only take the last matched line, discard previous lines lastLineMatched := matched[len(matched)-1][1] matchedAlias := reVimLang.FindAllSubmatch(lastLineMatched, -1) if matchedAlias == nil { - return OtherLanguage + return OtherLanguage, false } alias := string(matchedAlias[0][1]) @@ -100,6 +98,5 @@ func GetLanguageByVimModeline(content []byte) (lang string) { } } - lang = GetLanguageByAlias(alias) - return + return GetLanguageByAlias(alias) } diff --git a/modeline_test.go b/modeline_test.go index d117eff..4cbda8d 100644 --- a/modeline_test.go +++ b/modeline_test.go @@ -9,6 +9,7 @@ import ( const ( modelinesDir = ".linguist/test/fixtures/Data/Modelines" + samplesDir = ".linguist/samples" ) func (s *TSuite) TestGetLanguageByModeline(c *C) { @@ -18,42 +19,43 @@ func (s *TSuite) TestGetLanguageByModeline(c *C) { expectedSafe bool }{ // Emacs - {filename: "example_smalltalk.md", expectedLang: "Smalltalk", expectedSafe: true}, - {filename: "fundamentalEmacs.c", expectedLang: "Text", expectedSafe: true}, - {filename: "iamphp.inc", expectedLang: "PHP", expectedSafe: true}, - {filename: "seeplusplusEmacs1", expectedLang: "C++", expectedSafe: true}, - {filename: "seeplusplusEmacs2", expectedLang: "C++", expectedSafe: true}, - {filename: "seeplusplusEmacs3", expectedLang: "C++", expectedSafe: true}, - {filename: "seeplusplusEmacs4", expectedLang: "C++", expectedSafe: true}, - {filename: "seeplusplusEmacs5", expectedLang: "C++", expectedSafe: true}, - {filename: "seeplusplusEmacs6", expectedLang: "C++", expectedSafe: true}, - {filename: "seeplusplusEmacs7", expectedLang: "C++", expectedSafe: true}, - {filename: "seeplusplusEmacs9", expectedLang: "C++", expectedSafe: true}, - {filename: "seeplusplusEmacs10", expectedLang: "C++", expectedSafe: true}, - {filename: "seeplusplusEmacs11", expectedLang: "C++", expectedSafe: true}, - {filename: "seeplusplusEmacs12", expectedLang: "C++", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "example_smalltalk.md"), expectedLang: "Smalltalk", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "fundamentalEmacs.c"), expectedLang: "Text", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "iamphp.inc"), expectedLang: "PHP", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "seeplusplusEmacs1"), expectedLang: "C++", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "seeplusplusEmacs2"), expectedLang: "C++", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "seeplusplusEmacs3"), expectedLang: "C++", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "seeplusplusEmacs4"), expectedLang: "C++", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "seeplusplusEmacs5"), expectedLang: "C++", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "seeplusplusEmacs6"), expectedLang: "C++", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "seeplusplusEmacs7"), expectedLang: "C++", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "seeplusplusEmacs9"), expectedLang: "C++", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "seeplusplusEmacs10"), expectedLang: "C++", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "seeplusplusEmacs11"), expectedLang: "C++", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "seeplusplusEmacs12"), expectedLang: "C++", expectedSafe: true}, // Vim - {filename: "seeplusplus", expectedLang: "C++", expectedSafe: true}, - {filename: "iamjs.pl", expectedLang: "JavaScript", expectedSafe: true}, - {filename: "iamjs2.pl", expectedLang: "JavaScript", expectedSafe: true}, - {filename: "not_perl.pl", expectedLang: "Prolog", expectedSafe: true}, - {filename: "ruby", expectedLang: "Ruby", expectedSafe: true}, - {filename: "ruby2", expectedLang: "Ruby", expectedSafe: true}, - {filename: "ruby3", expectedLang: "Ruby", expectedSafe: true}, - {filename: "ruby4", expectedLang: "Ruby", expectedSafe: true}, - {filename: "ruby5", expectedLang: "Ruby", expectedSafe: true}, - {filename: "ruby6", expectedLang: "Ruby", expectedSafe: true}, - {filename: "ruby7", expectedLang: "Ruby", expectedSafe: true}, - {filename: "ruby8", expectedLang: "Ruby", expectedSafe: true}, - {filename: "ruby9", expectedLang: "Ruby", expectedSafe: true}, - {filename: "ruby10", expectedLang: "Ruby", expectedSafe: true}, - {filename: "ruby11", expectedLang: "Ruby", expectedSafe: true}, - {filename: "ruby12", expectedLang: "Ruby", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "seeplusplus"), expectedLang: "C++", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "iamjs.pl"), expectedLang: "JavaScript", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "iamjs2.pl"), expectedLang: "JavaScript", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "not_perl.pl"), expectedLang: "Prolog", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "ruby"), expectedLang: "Ruby", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "ruby2"), expectedLang: "Ruby", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "ruby3"), expectedLang: "Ruby", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "ruby4"), expectedLang: "Ruby", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "ruby5"), expectedLang: "Ruby", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "ruby6"), expectedLang: "Ruby", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "ruby7"), expectedLang: "Ruby", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "ruby8"), expectedLang: "Ruby", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "ruby9"), expectedLang: "Ruby", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "ruby10"), expectedLang: "Ruby", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "ruby11"), expectedLang: "Ruby", expectedSafe: true}, + {filename: filepath.Join(modelinesDir, "ruby12"), expectedLang: "Ruby", expectedSafe: true}, + {filename: filepath.Join(samplesDir, "C/main.c"), expectedLang: OtherLanguage, expectedSafe: false}, } for _, test := range linguistTests { - content, err := ioutil.ReadFile(filepath.Join(modelinesDir, test.filename)) + content, err := ioutil.ReadFile(test.filename) c.Assert(err, Equals, nil) lang, safe := GetLanguageByModeline(content) @@ -62,8 +64,9 @@ func (s *TSuite) TestGetLanguageByModeline(c *C) { } const ( - wrongVim = `# vim: set syntax=ruby ft =python filetype=perl :` - rightVim = `/* vim: set syntax=python ft =python filetype=python */` + wrongVim = `# vim: set syntax=ruby ft =python filetype=perl :` + rightVim = `/* vim: set syntax=python ft =python filetype=python */` + noLangVim = `/* vim: set shiftwidth=4 softtabstop=0 cindent cinoptions={1s: */` ) tests := []struct { @@ -73,6 +76,7 @@ func (s *TSuite) TestGetLanguageByModeline(c *C) { }{ {content: []byte(wrongVim), expectedLang: OtherLanguage, expectedSafe: false}, {content: []byte(rightVim), expectedLang: "Python", expectedSafe: true}, + {content: []byte(noLangVim), expectedLang: OtherLanguage, expectedSafe: false}, } for _, test := range tests {