// File: oop_and_ml/src/eval_flatten.cpp
// (Source-viewer metadata captured with this copy: 27 lines, no EOL at end of
// file, 848 B, C++, executable file.)

#include "eval_flatten.h"
// Constructs the Flatten evaluator for one expression node.
// The node description `expr` is forwarded unchanged to the eval_op base
// constructor; this class performs no additional initialization of its own.
eval_flatten::eval_flatten(const expression &expr)
    : eval_op(expr) {}
// Evaluates a Flatten node.
//
// Looks up the tensor produced by this node's first input expression
// (inputs_[0]) in `variables`, applies Flatten() to it, and stores the result
// in `variables` under this node's own id (expr_id_), overwriting any value
// already stored at that key.
//
// @param variables  Id -> tensor table shared by the evaluator. It must already
//                   hold an entry for inputs_[0]; variables.at() reports a
//                   missing entry (std::out_of_range, assuming vars_type is a
//                   std map-like container — confirm in the header).
// @param kwargs     Not used by this operator.
//
// Tracing of the input id can be enabled by compiling with -DEVAL_FLATTEN_DEBUG.
void eval_flatten::eval(vars_type &variables, const kwargs_type &kwargs) {
    // Flatten reads only its first input. The assert states the contract in
    // debug builds; the bounds-checked .at(0) keeps NDEBUG builds (where the
    // assert vanishes) from indexing an empty container.
    assert(!inputs_.empty());
    const int inputExprId = inputs_.at(0);

    // Read-only access: bind by reference instead of copying the whole tensor.
    const tensor &input = variables.at(inputExprId);

#ifdef EVAL_FLATTEN_DEBUG
    // Opt-in debug trace; previously this printed (and flushed) on every call.
    std::cout << "Input Tensor for Flatten (expr_id " << inputExprId << "):" << std::endl;
#endif

    // Assign the temporary directly so the entry can be move-assigned
    // (if tensor supports moves) rather than copied from a named local.
    variables[expr_id_] = Flatten(input);
}
// Factory hook: creates a new, independent eval_flatten for `expr` and hands it
// back through the common eval_op interface as a shared_ptr.
std::shared_ptr<eval_op> eval_flatten::clone(const expression &expr) {
    std::shared_ptr<eval_op> freshOp = std::make_shared<eval_flatten>(expr);
    return freshOp;
}