%
% Code for a Self-Organizing Map (SOM) in Matlab.
% Simulation takes noisy retinal images of letters and learns
% letter representations.
% The goal is to simulate, in a coarse manner, the first steps of the visual
% cortex. The first processing steps are hard coded, but the last one is
% based on SOM learning.
%
% As a model of the brain, the model features the use of 
% cortical columns and receptive fields
%
% (c) Frédéric Dandurand, 2011

% clean up figures and memory
clear all
close all

addpath('../library'); % include SOM code library

% use the same seed, to make results replication possible
% (legacy rand seeding syntax, kept so results remain comparable with earlier runs)
rand('seed', 1)

learningRate = 0.5;  % SOM learning rate: magnitude of adjustment of input weights towards the input values
% magnitude of the contribution of neighbours as a function of street block
% distance from the winner. NOTE(review): the vector has 4 entries, so presumably
% they correspond to distances 0 (the winner itself), 1, 2 and 3 -- confirm in computeUpdatedWeights
factorVector = [1 0.4 0.2 0.1];

useDistance = true;  % method used to select the best matching unit (SOM winner):
% true selects the unit with minimal distance to the input vector,
% false selects the most active unit.

debug = false; % false: fast vector-based processing only. true: also compute
% the slower, more readable item-based processing (for loops) and compare it
% with the faster vector-based computations. In debug mode, the program is
% also more verbose.

maxIter = 1000;   % maximum number of SOM learning iterations (epochs)
mapSize = [8 8];  % dimensions of the SOM map onto which inputs are projected
temperatureFactor = 0.0; % temperature (if > 0) determines an exponential decay of learning

receptiveFieldSize = [101 101];   % size of the receptive field that a unit on the SOM map can see on the input layer.
% Set to Inf for a standard SOM in which units on the map see all inputs

% initial connection weights are set within the following range:
% [meanInitialWeight - weightRange] and [meanInitialWeight + weightRange]
weightRange = 0.2;
meanInitialWeight = 0.2;

noiseLevel = 0.1; % add uniform noise to inputs with range [-noiseLevel, +noiseLevel].

% Accept only odd values for receptive field sizes.
% Each dimension must be checked: an "if" on a vector is only true when ALL
% its elements are nonzero, so the element-wise test is reduced with any().
% (rem(Inf, 2) is NaN, which never equals 0, so Inf sizes are accepted.)
if any(rem(receptiveFieldSize, 2) == 0)
    error('Receptive field size must be composed of odd (i.e., not even) values')
end

% directory where to store results
root = '../../results/RetinalImageToLetters/';
if (useDistance)
    root = [root, 'useMinDistance/'];
else
    root = [root, 'useMaxActivation/'];
end

% create the results directory only if needed (mkdir warns when it already exists)
if ~exist(root, 'dir')
    mkdir(root);
end
resFile = [root, 'TrainResults.txt'];

lettersData  = getLettersData('./letters.used/');  % loading letter data files from letters.used directory
lettersData = computeMaps(lettersData);  % compute other maps (angles, etc)

inputMap = size(lettersData(1).maps, 2);  % use the highest map as input (here, angles)

dim1 = size(lettersData(1).maps(inputMap).data, 1); % number of rows of lower map
dim2 = size(lettersData(1).maps(inputMap).data, 2); % number of cols of lower map
% number of cortical columns at each input location (third dimension of the
% input map); presumably the number of possible angles -- see computeMaps
CORTICAL_COLS = size(lettersData(1).maps(inputMap).data, 3);

dim3 = mapSize(1);  % number of rows in the upper (SOM) map
dim4 = mapSize(2);  % number of cols in the upper (SOM) map

% location of active connections are indicated by values of one (1), and units not connected = 0
weightsMask = getWeightsMask(dim1, dim2, CORTICAL_COLS, dim3, dim4, receptiveFieldSize);

% set random connection weights between the units connected on the input and
% the SOM maps. rand() is uniform in [0, 1], so (rand - 0.5) is uniform in
% [-0.5, 0.5] and the weights vary uniformly between
% [meanInitialWeight-weightRange] and [meanInitialWeight+weightRange]
weights = meanInitialWeight + 2 * weightRange * (rand(size(weightsMask)) - 0.5);
weights = weights .* weightsMask;  % zero the weights of units that are not connected

% storing general information about the simulation
trainData.inputMap = inputMap;
trainData.mapSize = mapSize;
trainData.useDistance = useDistance;

iterId = 1;
solutionFound = false;

% run until the maximum number of epochs is reached, or until every pattern
% is coded on a different unit
while iterId <= maxIter && ~solutionFound
    % store some information about the current iteration
    trainData.trainingResults(iterId).winnerVal = 0;  % running sum, over all letters, of the winner's value (distance or activation)
    trainData.trainingResults(iterId).worstLetter = '';  % label of the letter that is least well-coded by the SOM map
    if (useDistance)
        trainData.trainingResults(iterId).worstValue = -inf;  % initialize worst distance as -inf to make sure all letter distances are larger
    else
        trainData.trainingResults(iterId).worstValue = +inf;  % initialize worst activation as +inf to make sure all letter activations are smaller
    end

    patternsToProcess = randperm(length(lettersData));  % randomize order of patterns

    winners = zeros(mapSize);  % set to 1 at positions of units that were best matching units (BMU) in this epoch

    for patNo = 1:length(lettersData)
        patId = patternsToProcess(patNo);  % retrieve next randomly ordered pattern id
        pat = lettersData(patId).maps(inputMap).data;  % use pattern of activations of the input map

        % add uniform noise in [-noiseLevel, +noiseLevel] to input values, only
        % where letters are located (positive input values).
        % 2*rand - 1 maps rand's [0, 1] range onto [-1, 1].
        nonEmptyIndexes = find(pat > 0);
        pat(nonEmptyIndexes) = pat(nonEmptyIndexes) + noiseLevel * (2 * rand(size(nonEmptyIndexes)) - 1);

        % Find the winner (BMU) for this pattern
        if (useDistance)
            % compute distances matrix, minimal distance, and coordinates
            % of unit with minimal distance to the input vector
            [distances, minDistVal, winnerRow, winnerCol] = computeSOMDistances(pat, weights, weightsMask, debug);

            % if the current letter is worse than the previously found one
            % i.e., with a larger distance, then replace the worst value
            % and corresponding label
            if (minDistVal > trainData.trainingResults(iterId).worstValue)
                trainData.trainingResults(iterId).worstValue = minDistVal;
                trainData.trainingResults(iterId).worstLetter = lettersData(patId).label;
            end
            val = minDistVal;
        else
            % select the most active unit as the winner
            [activations, maxActVal, winnerRow, winnerCol] = computeSOMActivations(pat, weights, weightsMask, debug);

            % if the current letter is worse than the previously found one
            % i.e., with a smaller activation, then replace the worst value
            % and corresponding label
            if (maxActVal < trainData.trainingResults(iterId).worstValue)
                trainData.trainingResults(iterId).worstValue = maxActVal;
                trainData.trainingResults(iterId).worstLetter = lettersData(patId).label;
            end
            val = maxActVal;
        end

        winners(winnerRow, winnerCol) = 1;  % indicate that the winner unit is a BMU

        % proceed to a SOM update of the weights based on the winning unit
        updatedWeights = computeUpdatedWeights(weights, weightsMask, pat, ...
            dim1, dim2, CORTICAL_COLS, dim3, dim4, ...
            winnerRow, winnerCol, iterId, learningRate, temperatureFactor, factorVector, debug);

        % store information about the BMU, and update value (activation or
        % distance) of the winner
        trainData.trainingResults(iterId).winnerLocation(patId).id = lettersData(patId).id;
        trainData.trainingResults(iterId).winnerLocation(patId).row = winnerRow;
        trainData.trainingResults(iterId).winnerLocation(patId).col = winnerCol;
        trainData.trainingResults(iterId).winnerVal = trainData.trainingResults(iterId).winnerVal + val;

        % use the updated weights in the next iteration
        weights = updatedWeights;

        % store the SOM activations for the noise-free input pattern, computed
        % with the updated weights (only the activations are needed here)
        [activations, ~, ~, ~] = computeSOMActivations(lettersData(patId).maps(inputMap).data, weights, weightsMask, debug);
        lettersData(patId).maps(inputMap+1).data = activations;
        lettersData(patId).maps(inputMap+1).label = 'SOM map';
    end

    % computing number of different SOM units that were winners (BMU)
    trainData.trainingResults(iterId).bestMatchingUnitsCount = sum(sum(winners));

    % saving results
    save([root, 'MapData.mat'], 'lettersData', 'trainData', 'weights');

    % generating information string, display it and append it to the results file
    statsStr = dispStats(trainData, iterId);
    disp(statsStr)
    fid = fopen(resFile, 'at');  % open as text in append mode: results accumulate across iterations
    if fid == -1
        warning('Could not open results file %s for writing', resFile);
    else
        fprintf(fid, '%s', statsStr);   % print to file as string
        fclose(fid);  % close the file
    end

    % Success if each pattern is best coded onto a single distinct unit
    if (trainData.trainingResults(iterId).bestMatchingUnitsCount == length(lettersData))
        solutionFound = true;
        disp('Success!')
    end
    iterId = iterId + 1;
end


% finally, output to file maps for all letters
disp('Outputting map plots to image files...')
for patId = 1:length(lettersData)
    % Build the PNG file name from the letter label by replacing its file
    % extension (e.g. '.txt') with '.png'. fileparts also copes with labels
    % that have no extension or an extension that is not 3 characters long,
    % which blindly overwriting the last three characters did not.
    [labelDir, labelName] = fileparts(lettersData(patId).label);
    plotFileName = [root, fullfile(labelDir, labelName), '.png'];
    disp(plotFileName)
    viewmap(lettersData(patId).maps, ones(1, inputMap+1));
    print('-f1', '-dpng', plotFileName);  % save figure 1 as a PNG image
end
close all
