-
-
Notifications
You must be signed in to change notification settings - Fork 46
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implementation of SVM, Decision Tree, and Random Forest algorithms #211
base: main
Are you sure you want to change the base?
Conversation
WalkthroughThe recent changes introduce a comprehensive implementation of Support Vector Machines (SVMs), Decision Trees, and Random Forests for classification tasks within the Changes
Sequence Diagram(s)sequenceDiagram
participant User
participant SVM
participant Data
User->>Data: Provide training data
Data->>SVM: Send data points
SVM->>SVM: Train model using training algorithm
SVM->>User: Return trained model
User->>Data: Provide test data
Data->>SVM: Send test points
SVM->>User: Return predictions
sequenceDiagram
participant User
participant DecisionTree
participant Dataset
User->>Dataset: Create a new dataset
Dataset->>User: Return dataset instance
User->>Dataset: Add samples
Dataset->>DecisionTree: Train on dataset
DecisionTree->>User: Return trained tree
User->>Dataset: Provide test features
DecisionTree->>User: Return predictions
sequenceDiagram
participant User
participant RandomForest
participant Dataset
User->>Dataset: Create a new dataset
Dataset->>User: Return dataset instance
User->>RandomForest: Train with dataset
RandomForest->>User: Return trained forest
User->>Dataset: Provide test features
RandomForest->>User: Return predictions
Poem
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (invoked as PR comments)
Additionally, you can add CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (2)
- ml/svm.v (1 hunks)
- ml/svm_test.v (1 hunks)
Additional comments not posted (16)
ml/svm_test.v (7)
5-12
: Test for polynomial kernel looks good!The test correctly verifies the polynomial kernel function.
14-22
: Test for RBF kernel looks good!The test correctly verifies the RBF kernel function with a tolerance check.
24-38
: Test for SVM training looks good!The test correctly verifies the SVM training process and checks predictions against expected labels.
40-55
: Test for SVM prediction looks good!The test correctly verifies the SVM prediction process and checks predictions against expected labels.
57-72
: Test for multiclass SVM training looks good!The test correctly verifies the multiclass SVM training process and checks predictions against expected labels.
74-91
: Test for multiclass SVM prediction looks good!The test correctly verifies the multiclass SVM prediction process and checks predictions against expected labels.
93-105
: Test for multiple kernels looks good!The test correctly verifies multiple kernel functions by running the training and prediction tests for each kernel.
ml/svm.v (9)
6-12
: SVMConfig structure looks good!The structure correctly defines configuration parameters for the SVM training process with appropriate default values.
14-18
: DataPoint structure looks good!The structure correctly represents individual data points with features and class labels.
20-27
: SVMModel structure looks good!The structure correctly contains the results of the SVM training process, including support vectors, alphas, bias term, kernel function, and configuration parameters.
31-46
: Kernel functions look good!The linear, polynomial, and RBF kernel functions are correctly implemented, providing necessary kernel computations for the SVM.
48-62
: Utility functions look good!The dot product and vector subtraction utility functions are correctly implemented, providing necessary utility operations for the SVM.
64-154
: SVM training function looks good!The
train_svm
function is correctly implemented, providing necessary training operations for binary SVM.
174-205
: Multiclass SVM training function looks good!The
train_multiclass_svm
function is correctly implemented, providing necessary training operations for multiclass SVM.
156-166
: SVM prediction functions look good!The
predict_raw
andpredict
functions are correctly implemented, providing necessary prediction operations for binary SVM.
207-229
: Multiclass SVM prediction function looks good!The
predict_multiclass
function is correctly implemented, providing necessary prediction operations for multiclass SVM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (2)
- ml/svm.v (1 hunks)
- ml/svm_test.v (1 hunks)
Files skipped from review due to trivial changes (1)
- ml/svm.v
Files skipped from review as they are similar to previous changes (1)
- ml/svm_test.v
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (2)
- ml/svm.v (1 hunks)
- ml/svm_test.v (1 hunks)
Additional comments not posted (26)
ml/svm_test.v (11)
5-10
: LGTM!The test function correctly validates the
vector_dot
function.
12-17
: LGTM!The test function correctly validates the
vector_subtract
function.
19-24
: LGTM!The test function correctly validates the
linear_kernel
function.
26-33
: LGTM!The test function correctly validates the
polynomial_kernel
function.
35-43
: LGTM!The test function correctly validates the
rbf_kernel
function.
45-50
: LGTM!The test function correctly validates the
SVM.new
constructor.
52-66
: LGTM!The test function correctly validates the
train
andpredict
methods of theSVM
struct.
68-82
: LGTM!The test function correctly validates the
train_svm
function.
84-97
: LGTM!The test function correctly validates the
predict_raw
function.
99-113
: LGTM!The test function correctly validates the
predict
function.
115-127
: LGTM!The main function correctly runs all the test functions and prints a success message.
ml/svm.v (15)
6-12
: LGTM!The
SVMConfig
struct correctly defines the necessary parameters for SVM training.
14-18
: LGTM!The
DataPoint
struct correctly defines the necessary fields for a data point.
20-27
: LGTM!The
SVMModel
struct correctly defines the necessary fields for an SVM model.
29-34
: LGTM!The
SVM
struct correctly defines the necessary fields for an SVM instance.
36-43
: LGTM!The
KernelFunction
type alias andvector_dot
function correctly define and implement the necessary operations.
46-52
: LGTM!The
vector_subtract
function correctly implements vector subtraction.
54-56
: LGTM!The
linear_kernel
function correctly implements the linear kernel.
58-62
: LGTM!The
polynomial_kernel
function correctly implements the polynomial kernel.
64-68
: LGTM!The
rbf_kernel
function correctly implements the RBF kernel.
71-76
: LGTM!The
SVM.new
constructor correctly initializes an SVM instance.
78-80
: LGTM!The
train
method correctly trains the SVM instance.
82-84
: LGTM!The
predict
method correctly predicts the class label for a given input.
86-175
: LGTM!The
train_svm
function correctly implements the SVM training logic.
178-184
: LGTM!The
predict_raw
function correctly implements the raw prediction logic for an SVM model.
186-188
: LGTM!The
predict
function correctly implements the prediction logic for an SVM model.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (2)
- ml/svm.v (1 hunks)
- ml/svm_test.v (1 hunks)
Files skipped from review as they are similar to previous changes (2)
- ml/svm.v
- ml/svm_test.v
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
Outside diff range, codebase verification and nitpick comments (9)
ml/svm_test.v (9)
6-13
: Remove print statement from unit test.The print statement is not necessary in unit tests and can be removed to keep the test output clean.
- println('Linear kernel test passed.')
15-24
: Remove print statement from unit test.The print statement is not necessary in unit tests and can be removed to keep the test output clean.
- println('Polynomial kernel test passed.')
26-36
: Remove print statement from unit test.The print statement is not necessary in unit tests and can be removed to keep the test output clean.
- println('RBF kernel test passed.')
38-45
: Remove print statement from unit test.The print statement is not necessary in unit tests and can be removed to keep the test output clean.
- println('Quadratic kernel test passed.')
47-56
: Remove print statement from unit test.The print statement is not necessary in unit tests and can be removed to keep the test output clean.
- println('Custom kernel test passed.')
58-65
: Remove print statement from unit test.The print statement is not necessary in unit tests and can be removed to keep the test output clean.
- println('Dot product test passed.')
67-74
: Remove print statement from unit test.The print statement is not necessary in unit tests and can be removed to keep the test output clean.
- println('Vector subtract test passed.')
76-117
: Remove print statement and add more assertions in unit test.The print statement is not necessary in unit tests and can be removed to keep the test output clean. Additionally, consider adding more assertions to validate the model parameters.
- println('SVM model training and prediction test passed.') + assert model.support_vectors.len > 0 + assert model.alphas.len == data.len + assert model.b != 0
119-168
: Remove print statement and add more assertions in unit test.The print statement is not necessary in unit tests and can be removed to keep the test output clean. Additionally, consider adding more assertions to validate the model parameters.
- println('Multiclass SVM model training and prediction test passed.') + assert multiclass_model.models.len > 0 + for model_row in multiclass_model.models { + assert model_row.len > 0 + }
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (2)
- ml/svm.v (1 hunks)
- ml/svm_test.v (1 hunks)
Additional comments not posted (12)
ml/svm_test.v (1)
170-180
: LGTM!The function correctly calls all the test functions.
ml/svm.v (11)
16-22
: LGTM!The
KernelType
enum correctly defines the types of kernel functions.
24-24
: LGTM!The
KernelFunction
type alias correctly defines the type for kernel functions.
26-30
: LGTM!The
DataPoint
structure correctly defines a data point with features and a class label.
32-39
: LGTM!The
SVMModel
structure correctly defines the SVM model with support vectors, alphas, bias, kernel function, and configuration.
41-44
: LGTM!The
MulticlassSVM
structure correctly defines the multiclass SVM model with a matrix of SVM models.
46-72
: LGTM!The kernel functions correctly implement the linear, polynomial, RBF, quadratic, and custom kernels.
74-88
: LGTM!The utility functions correctly implement the dot product and vector subtraction.
180-185
: LGTM!The
compute_l_h
function correctly computes the L and H values for the SMO algorithm.
188-194
: LGTM!The
predict_raw
function correctly computes the raw prediction value for a given input.
196-198
: LGTM!The
predict
function correctly computes the predicted class label for a given input.
231-253
: LGTM!The
predict_multiclass
function correctly computes the predicted class label for a given input using the multiclass SVM model.
pub fn train_svm(data []DataPoint, config SVMConfig) &SVMModel { | ||
kernel := match config.kernel_type { | ||
.linear { linear_kernel } | ||
.polynomial { polynomial_kernel(config.kernel_param) } | ||
.rbf { rbf_kernel(config.kernel_param) } | ||
.quadratic { quadratic_kernel } | ||
.custom { custom_kernel } | ||
} | ||
|
||
mut model := &SVMModel{ | ||
config: config | ||
kernel: kernel | ||
} | ||
|
||
mut alphas := []f64{len: data.len, init: 0.0} | ||
mut b := 0.0 | ||
|
||
for _ in 0 .. config.max_iterations { | ||
mut alpha_pairs_changed := 0 | ||
|
||
for i := 0; i < data.len; i++ { | ||
ei := predict_raw(model, data[i].x) - f64(data[i].y) | ||
if (data[i].y * ei < -config.tolerance && alphas[i] < config.c) | ||
|| (data[i].y * ei > config.tolerance && alphas[i] > 0) { | ||
mut j := rand.intn(data.len - 1) or { 0 } | ||
if j >= i { | ||
j += 1 | ||
} | ||
|
||
ej := predict_raw(model, data[j].x) - f64(data[j].y) | ||
|
||
old_alpha_i, old_alpha_j := alphas[i], alphas[j] | ||
l, h := compute_l_h(config.c, alphas[i], alphas[j], data[i].y, data[j].y) | ||
|
||
if l == h { | ||
continue | ||
} | ||
|
||
eta := 2 * kernel(data[i].x, data[j].x) - kernel(data[i].x, data[i].x) - kernel(data[j].x, | ||
data[j].x) | ||
if eta >= 0 { | ||
continue | ||
} | ||
|
||
alphas[j] -= f64(data[j].y) * (ei - ej) / eta | ||
alphas[j] = math.max(l, math.min(h, alphas[j])) | ||
|
||
if math.abs(alphas[j] - old_alpha_j) < 1e-5 { | ||
continue | ||
} | ||
|
||
alphas[i] += f64(data[i].y * data[j].y) * (old_alpha_j - alphas[j]) | ||
|
||
b1 := b - ei - data[i].y * (alphas[i] - old_alpha_i) * kernel(data[i].x, | ||
data[i].x) - data[j].y * (alphas[j] - old_alpha_j) * kernel(data[i].x, | ||
data[j].x) | ||
b2 := b - ej - data[i].y * (alphas[i] - old_alpha_i) * kernel(data[i].x, | ||
data[j].x) - data[j].y * (alphas[j] - old_alpha_j) * kernel(data[j].x, | ||
data[j].x) | ||
|
||
if 0 < alphas[i] && alphas[i] < config.c { | ||
b = b1 | ||
} else if 0 < alphas[j] && alphas[j] < config.c { | ||
b = b2 | ||
} else { | ||
b = (b1 + b2) / 2 | ||
} | ||
|
||
alpha_pairs_changed += 1 | ||
} | ||
} | ||
|
||
if alpha_pairs_changed == 0 { | ||
break | ||
} | ||
} | ||
|
||
model.b = b | ||
model.alphas = alphas | ||
mut support_vectors := []DataPoint{} | ||
for i, d in data { | ||
if alphas[i] > 0 { | ||
support_vectors << d | ||
} | ||
} | ||
model.support_vectors = support_vectors | ||
|
||
return model |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove or use the kernel_param
field.
The kernel_param
field in the SVMConfig
structure is not used in the current implementation. Consider removing it or using it in the kernel functions.
- kernel_param f64 = 1.0
Committable suggestion was skipped due to low confidence.
pub fn train_multiclass_svm(data []DataPoint, config SVMConfig) &MulticlassSVM { | ||
mut classes := []int{} | ||
for point in data { | ||
if point.y !in classes { | ||
classes << point.y | ||
} | ||
} | ||
|
||
mut models := [][]&SVMModel{len: classes.len, init: []&SVMModel{}} | ||
|
||
for i := 0; i < classes.len; i++ { | ||
models[i] = []&SVMModel{len: classes.len, init: 0} | ||
for j := i + 1; j < classes.len; j++ { | ||
mut binary_data := []DataPoint{} | ||
for point in data { | ||
if point.y == classes[i] || point.y == classes[j] { | ||
binary_data << DataPoint{ | ||
x: point.x | ||
y: if point.y == classes[i] { 1 } else { -1 } | ||
} | ||
} | ||
} | ||
models[i][j] = train_svm(binary_data, config) | ||
} | ||
} | ||
|
||
return &MulticlassSVM{ | ||
models: models | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove or use the kernel_param
field.
The kernel_param
field in the SVMConfig
structure is not used in the current implementation. Consider removing it or using it in the kernel functions.
- kernel_param f64 = 1.0
Committable suggestion was skipped due to low confidence.
pub struct SVMConfig { | ||
pub mut: | ||
max_iterations int = 1000 | ||
learning_rate f64 = 0.01 | ||
tolerance f64 = 1e-6 | ||
c f64 = 1.0 | ||
kernel_type KernelType = .linear | ||
kernel_param f64 = 1.0 | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove or use the kernel_param
field.
The kernel_param
field is defined but not used in the current implementation. Consider removing it or using it in the kernel functions.
- kernel_param f64 = 1.0
Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
pub struct SVMConfig { | |
pub mut: | |
max_iterations int = 1000 | |
learning_rate f64 = 0.01 | |
tolerance f64 = 1e-6 | |
c f64 = 1.0 | |
kernel_type KernelType = .linear | |
kernel_param f64 = 1.0 | |
} | |
pub struct SVMConfig { | |
pub mut: | |
max_iterations int = 1000 | |
learning_rate f64 = 0.01 | |
tolerance f64 = 1e-6 | |
c f64 = 1.0 | |
kernel_type KernelType = .linear | |
} |
Kernel Functions:a) Linear Kernel:[ K(x,y) = x^T y ] b) Polynomial Kernel:[ K(x,y) = (x^T y + 1)^d ] c) RBF (Radial Basis Function) Kernel:[ K(x,y) = \exp(-\gamma ||x-y||^2) ] d) Quadratic Kernel:[ K(x,y) = (x^T y)^2 ] SVM Optimization Problem:The dual formulation of SVM is as follows: Constraints:[ 0 \leq \alpha_i \leq C, \quad \forall i ] Decision Function:[ f(x) = \sum_{i=1}^n \alpha_i y_i K(x_i, x) + b ] SMO (Sequential Minimal Optimization) Algorithm:The SMO algorithm implemented in the code updates two Lagrange multipliers in each iteration: a)[ \alpha_j^{new} = \alpha_j^{old} - \frac{y_j(E_i - E_j)}{\eta} ] b)[ \alpha_i^{new} = \alpha_i^{old} + y_i y_j (\alpha_j^{old} - \alpha_j^{new}) ] Multi-Class SVM:The multi-class SVM approach implemented in the code uses the "one-vs-one" strategy. This involves training one SVM for each pair of classes. The total number of SVMs is: These mathematical expressions represent the core components of the SVM algorithm implemented in the code. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (2)
- ml/decision_tree.v (1 hunks)
- ml/decision_tree_test.v (1 hunks)
Additional comments not posted (21)
ml/decision_tree_test.v (6)
5-11
: LGTM!The test correctly verifies the creation of a
DecisionTree
instance with the specified parameters.
13-20
: LGTM!The test correctly verifies the creation of a
Dataset
instance with the specified parameters and ensures the samples list is initialized correctly.
22-34
: LGTM!The test correctly verifies the addition of samples to the
Dataset
and handles invalid labels appropriately.
36-47
: LGTM!The test correctly verifies the entropy calculation of the
Dataset
by comparing it with a manually calculated value.
49-62
: LGTM!The test correctly verifies the training and prediction functionalities of the
DecisionTree
.
64-84
: LGTM!The test correctly verifies the information gain calculation by comparing it with a manually calculated value.
ml/decision_tree.v (15)
5-10
: LGTM!The
Sample
struct correctly defines the features and label.
12-18
: LGTM!The
Dataset
struct correctly defines the samples, number of features, and number of classes.
20-27
: LGTM!The
Node
struct correctly defines the feature, threshold, label, and child nodes.
29-34
: LGTM!The
DecisionTree
struct correctly defines the root node, maximum depth, and minimum samples split.
36-42
: LGTM!The function correctly initializes a new instance of
DecisionTree
with the specified parameters.
44-52
: LGTM!The function correctly finds the index of the maximum value in an array.
54-60
: LGTM!The function correctly initializes a new instance of
Dataset
with the specified parameters.
62-71
: LGTM!The method correctly adds a sample to the
Dataset
and handles invalid labels.
73-87
: LGTM!The method correctly calculates the entropy of the
Dataset
using the class counts.
89-124
: LGTM!The function correctly finds the best split for the
Dataset
by iterating through features and samples, and calculating information gain.
126-180
: LGTM!The function correctly builds the decision tree recursively by finding the best split and creating child nodes.
182-184
: LGTM!The method correctly trains the decision tree by building the tree using the dataset.
186-188
: LGTM!The method correctly predicts the class label for the given features by traversing the tree.
190-200
: LGTM!The function correctly predicts the class label recursively by traversing the tree.
202-207
: LGTM!The function correctly calculates the information gain for a split using the entropy values of the parent, left, and right datasets.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (2)
- ml/random_forest.v (1 hunks)
- ml/random_forest_test.v (1 hunks)
Additional comments not posted (13)
ml/random_forest_test.v (5)
5-12
: LGTM!The function correctly tests the initialization of the RandomForest instance.
14-24
: LGTM!The function correctly tests the bootstrap sampling method.
26-32
: LGTM!The function correctly tests the feature subset selection method.
34-54
: LGTM!The function correctly tests the training and prediction methods of the RandomForest class.
56-62
: LGTM!The function correctly runs all the test functions and prints a success message.
ml/random_forest.v (8)
14-22
: LGTM!The method correctly initializes a RandomForest instance with the given parameters.
24-32
: LGTM!The method correctly creates a bootstrap sample from the given dataset.
34-38
: LGTM!The method correctly selects a subset of features from the given number of features.
40-50
: LGTM!The method correctly trains the RandomForest instance using the given dataset.
52-65
: LGTM!The method correctly predicts the class label for the given features using the trained RandomForest instance.
67-70
: LGTM!The method correctly trains the DecisionTree instance using a subset of features.
72-127
: LGTM!The method correctly builds the decision tree using a subset of features.
130-165
: LGTM!The method correctly finds the best split for a dataset using a subset of features.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@suleyman-kaya thanks a lot for creating this PR! I just sent a message to you on our discord challenge 😊
SVM Module
This PR introduces a complete implementation of a Support Vector Machine (SVM) module built from scratch. The module includes several key components and functions to support SVM training and prediction, as well as multiclass classification.
Core Structures
SVMConfig
Structuremax_iterations
: Specifies the maximum number of iterations for training. Default: 1000.learning_rate
: Sets the learning rate for the training process. Default: 0.01.tolerance
: Defines the tolerance for stopping criteria. Default: 1e-6.c
: The regularization parameter for the SVM. Default: 1.0.kernel_type
: The type of kernel to be used (KernelType enum).kernel_param
: Parameter for the kernel function.DataPoint
Structurex
: A slice ([]f64
) representing feature values of the data point.y
: An integer representing the class label of the data point.SVMModel
Structuresupport_vectors
: A slice ofDataPoint
structures representing the support vectors identified during training.alphas
: A slice ([]f64
) containing the alpha (α) values for each support vector.b
: The bias term in the decision function.kernel
: A function of typeKernelFunction
that specifies the kernel function used for transforming the input data.config
: An instance ofSVMConfig
that holds configuration parameters for training.Key Functions
linear_kernel(x []f64, y []f64) f64
: Computes the linear kernel between two vectors.polynomial_kernel(degree f64) KernelFunction
: Returns a polynomial kernel function with a specified degree.rbf_kernel(gamma f64) KernelFunction
: Returns a radial basis function (RBF) kernel with a specified gamma parameter.quadratic_kernel(x []f64, y []f64) f64
: Computes the quadratic kernel between two vectors.custom_kernel(x []f64, y []f64) f64
: A custom kernel function for specific use cases.train_svm(data []DataPoint, config SVMConfig) &SVMModel
: Trains an SVM model using the provided data and configuration.predict_raw(model &SVMModel, x []f64) f64
: Computes the raw prediction score for a data point.predict(model &SVMModel, x []f64) int
: Predicts the class label for a data point.train_multiclass_svm(data []DataPoint, config SVMConfig) &MulticlassSVM
: Trains a multiclass SVM model using a one-vs-one approach.predict_multiclass(model &MulticlassSVM, x []f64) int
: Predicts the class label for a data point using the multiclass SVM model.Tests Added
test_linear_kernel()
: Validates the linear kernel function.test_polynomial_kernel()
: Validates the polynomial kernel function.test_rbf_kernel()
: Validates the radial basis function (RBF) kernel.test_quadratic_kernel()
: Verifies the quadratic kernel function.test_custom_kernel()
: Tests the custom kernel function.test_dot_product()
: Checks the dot product calculation.test_vector_subtract()
: Ensures correct vector subtraction.test_svm()
: Tests the SVM training and prediction for binary classification.test_multiclass_svm()
: Validates the training and prediction for multiclass classification.These comprehensive tests cover various aspects of the SVM implementation:
All tests use assert statements to verify the correctness of the results. The main function runs all these tests sequentially, providing a comprehensive validation of the SVM module's functionality.
This test suite ensures the reliability and accuracy of the SVM implementation across different kernel types and classification scenarios.
Summary by CodeRabbit
New Features
Bug Fixes
Documentation