Latent Semantic Indexing

March 7, 2011

Here is a toy program that demonstrates latent semantic indexing (LSI) using the singular value decomposition (SVD) provided by the COLT linear algebra library. The method is taken from here.

import cern.colt.matrix.linalg.*;
import cern.colt.matrix.*;
import cern.colt.matrix.impl.*;
import java.util.*;
import java.io.*;

/**
 * Toy demonstration of latent semantic indexing (LSI) built on the singular
 * value decomposition (SVD) from the COLT linear algebra library.
 *
 * The program reads a term-by-document matrix from "input.txt" (TERMS lines of
 * DOCS numbers each) and a query term vector from "query.txt" (TERMS numbers).
 * It keeps the first RANK singular values/vectors, folds the query into the
 * reduced space as q = S^-1 * U^T * query, and prints the cosine similarity
 * between the folded query and each document vector (a row of the reduced V).
 */
public class SVDText
{
 /** Number of terms: rows of the term-by-document matrix and length of the query. */
 private static final int TERMS = 11;

 /** Number of documents: columns of the term-by-document matrix. */
 private static final int DOCS = 3;

 /** Number of singular values/vectors kept by the rank reduction. */
 private static final int RANK = 2;

 /**
  * Runs the demonstration and prints the intermediate vectors and the
  * per-document similarity measures to standard output. Any failure (missing
  * or short input file, singular reduced S, ...) is reported as a stack trace.
  *
  * @param args ignored
  */
 public static void main (String args[])
 {
    try
    {
       // One row per term, one column per document. The decomposition needs
       // rows >= columns, and the rank reduction below takes TERMS columns
       // from U^T, so the matrix has to be TERMS x DOCS (not DOCS x TERMS).
       DoubleMatrix2D source = readMatrix("input.txt",TERMS,DOCS);

       // The query is kept as a 1 x TERMS row vector.
       DoubleMatrix2D query = readMatrix("query.txt",1,TERMS);

       Algebra alg = new Algebra();

       SingularValueDecomposition svd = new SingularValueDecomposition(source);

       // reduce rank
       DoubleMatrix2D reducedUt = alg.subMatrix(alg.transpose(svd.getU()),0,RANK-1,0,TERMS-1);
       DoubleMatrix2D reducedS = alg.subMatrix(svd.getS(),0,RANK-1,0,RANK-1);
       // First RANK columns of V: row i is the reduced vector of document i+1.
       DoubleMatrix2D docVectors = alg.subMatrix(svd.getV(),0,DOCS-1,0,RANK-1);

       DoubleMatrix2D inverseS = alg.pow(reducedS,-1);

       DoubleMatrix2D q1 = alg.mult(inverseS,reducedUt);
       System.out.println("q1 = " + q1);

       // q1 * query^T is RANK x 1; its single column is the folded query.
       DoubleMatrix1D queryVector = alg.mult(q1,alg.transpose(query)).viewColumn(0);
       System.out.println("query vector " + queryVector);

       for (int doc=0;doc<DOCS;doc++)
       {
          System.out.println("d" + (doc+1) + " = " + docVectors.viewRow(doc));
       }

       for (int doc=0;doc<DOCS;doc++)
       {
          System.out.println("Doc " + (doc+1) + " measure = " + cosine(queryVector,docVectors.viewRow(doc)));
       }
    }
    catch (Exception e)
    {
       e.printStackTrace();
    }
 }

 /**
  * Reads rows*cols whitespace-separated numbers from a text file into a dense
  * matrix, filling it row by row.
  *
  * @param fileName path of the file to read
  * @param rows number of matrix rows
  * @param cols number of matrix columns
  * @return the populated rows x cols matrix
  * @throws FileNotFoundException if the file cannot be opened
  * @throws NoSuchElementException if the file holds too few values or a
  *         token that is not a number
  */
 private static DoubleMatrix2D readMatrix (String fileName, int rows, int cols) throws FileNotFoundException
 {
    DoubleMatrix2D matrix = new DenseDoubleMatrix2D(rows,cols);
    Scanner sc = new Scanner(new File(fileName));
    try
    {
       for (int row=0;row<rows;row++)
       {
          for (int col=0;col<cols;col++)
          {
             // nextDouble also accepts integer tokens and keeps full precision.
             matrix.setQuick(row,col,sc.nextDouble());
          }
       }
    }
    finally
    {
       // Release the file handle even when parsing fails.
       sc.close();
    }
    return matrix;
 }

 /**
  * Cosine similarity of two vectors of equal size: their dot product divided
  * by the product of their Euclidean lengths.
  *
  * @param a first vector
  * @param b second vector
  * @return the cosine of the angle between a and b, or 0.0 when either
  *         vector has zero length (the angle is undefined in that case)
  */
 private static double cosine (DoubleMatrix1D a, DoubleMatrix1D b)
 {
    // Euclidean length computed directly from the dot product.
    double lengthA = Math.sqrt(a.zDotProduct(a));
    double lengthB = Math.sqrt(b.zDotProduct(b));
    if (lengthA == 0.0 || lengthB == 0.0)
    {
       return 0.0;
    }
    return a.zDotProduct(b) / (lengthA*lengthB);
 }
}

Tags: ,

Comments are closed.