diff --git a/Cargo.lock b/Cargo.lock index 409f6bd..c252c9b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -433,6 +433,7 @@ checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ "bitflags 2.8.0", "libc", + "redox_syscall", ] [[package]] @@ -491,6 +492,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "numtoa" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6aa2c4e539b869820a2b82e1aef6ff40aa85e65decdd5185e83fb4b1249cd00f" + [[package]] name = "once_cell" version = "1.20.3" @@ -505,6 +512,7 @@ dependencies = [ "enum_dispatch", "font-kit", "raylib", + "termion", "test-case", ] @@ -626,6 +634,12 @@ dependencies = [ "bitflags 2.8.0", ] +[[package]] +name = "redox_termios" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20145670ba436b55d91fc92d25e71160fbfbdd57831631c8d7d36377a476f1cb" + [[package]] name = "redox_users" version = "0.4.6" @@ -794,6 +808,18 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "termion" +version = "4.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7eaa98560e51a2cf4f0bb884d8b2098a9ea11ecf3b7078e9c68242c74cc923a7" +dependencies = [ + "libc", + "libredox", + "numtoa", + "redox_termios", +] + [[package]] name = "test-case" version = "3.3.1" diff --git a/Cargo.toml b/Cargo.toml index dca49d3..c894d61 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,12 +9,20 @@ clay-layout = { path = "../clay-ui-rust" } enum_dispatch = "0.3.13" font-kit = "0.14.2" raylib = { version = "5.0.2", features = ["wayland"] } +termion = "4.0.3" test-case = "3.3.1" +[features] +async = [] + [lib] name = "libopenbirch" path = "src/lib/lib.rs" +# [[bin]] +# name = "openbirch-gui" +# path = "src/app/bin.rs" + [[bin]] -name = "openbirch_iced" -path = "src/app/bin.rs" +name = "openbirch-repl" +path = "src/cli-repl.rs" diff --git a/flake.lock b/flake.lock index e666a7d..ce2c29c 100644 --- a/flake.lock +++ b/flake.lock @@ -2,11 +2,11 @@ "nodes": { "nixpkgs": { "locked": { - "lastModified": 1739020877, - "narHash": "sha256-mIvECo/NNdJJ/bXjNqIh8yeoSjVLAuDuTUzAo7dzs8Y=", + "lastModified": 1739866667, + "narHash": "sha256-EO1ygNKZlsAC9avfcwHkKGMsmipUk1Uc0TbrEZpkn64=", "owner": "nixos", "repo": "nixpkgs", - "rev": "a79cfe0ebd24952b580b1cf08cd906354996d547", + "rev": "73cf49b8ad837ade2de76f87eb53fc85ed5d4680", "type": "github" }, "original": { @@ -29,11 +29,11 @@ ] }, "locked": { - "lastModified": 1736216977, - "narHash": "sha256-EMueGrzBpryM8mgOyoyJ7DdNRRk09ug1ggcLLp0WrCQ=", + "lastModified": 1739932111, + "narHash": "sha256-WkayjH0vuGw0hx2gmjTUGFRvMKpM17gKcpL/U8EUUw0=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "bbe7e4e7a70d235db4bbdcabbf8a2f6671881dd7", + "rev": "75b2271c5c087d830684cd5462d4410219acc367", "type": "github" }, "original": { diff --git a/src/cli-repl.rs b/src/cli-repl.rs new file mode 100644 index 0000000..0c86d17 --- /dev/null +++ b/src/cli-repl.rs @@ -0,0 +1,200 @@ +use std::io::{self, StdoutLock, Write, stdout}; + +use libopenbirch::environment::Environment; +use libopenbirch::node::Node; +use libopenbirch::parser::{Lexer, Parser}; +#[cfg(feature = "async")] +use termion::AsyncReader; +use termion::color; +use termion::event::Key; +use termion::input::TermReadEventsAndRaw; +use termion::raw::{IntoRawMode, RawTerminal}; + +pub struct Input { + #[cfg(not(feature = "async"))] + stdin: std::io::Stdin, + + stdout: RawTerminal>, + buffer: String, + current_char: usize, +} +impl Input { + pub fn new() -> Self { + #[cfg(not(feature = "async"))] + let stdin = std::io::stdin(); + let stdout = stdout().lock().into_raw_mode().unwrap(); + + Self { + stdin, + stdout, + buffer: "".into(), + current_char: 0, + } + } + + #[inline] + fn draw_line( + buffer: &String, + stdout: &mut RawTerminal>, + current_char: usize, + ) { + print!("\r{}> {}", termion::clear::CurrentLine, buffer,); + let left_diff = buffer.len() - current_char; + if left_diff > 0 { + print!("{}", termion::cursor::Left(left_diff.try_into().unwrap())); + } + let _ = stdout.flush(); + } + + /// Gets input from `io::stdin`. + /// Returns `None` when the user presses or types `:quit` + pub fn get(&mut self) -> Result, io::Error> { + Self::draw_line(&self.buffer, &mut self.stdout, self.current_char); + loop { + for b in (&mut self.stdin).events_and_raw() { + let (event, _slice) = b?; + + match event { + termion::event::Event::Key(key) => match key { + Key::Char(char) if char == '\n' => { + if &self.buffer == ":quit" { + println!("\r"); + return Err(io::Error::from(io::ErrorKind::Interrupted)); + } + + let r = self.buffer.clone(); + self.buffer.clear(); + self.current_char = 0; + print!("\n\r"); + return Ok(Some(r)); + } + Key::Backspace => { + if self.current_char == 0 { + continue; + } + self.current_char -= 1; + self.buffer.remove(self.current_char); + Self::draw_line(&self.buffer, &mut self.stdout, self.current_char); + let _ = self.stdout.flush(); + } + Key::Delete => { + if self.current_char == self.buffer.len() { + continue; + } + self.buffer.remove(self.current_char); + Self::draw_line(&self.buffer, &mut self.stdout, self.current_char); + let _ = self.stdout.flush(); + } + Key::Left => { + if self.current_char == 0 { + continue; + } + self.current_char -= 1; + print!("{}", termion::cursor::Left(1)); + let _ = self.stdout.flush(); + } + Key::Right => { + if self.current_char == self.buffer.len() { + continue; + } + self.current_char += 1; + print!("{}", termion::cursor::Right(1)); + let _ = self.stdout.flush(); + } + Key::Char(char) => { + self.buffer.insert(self.current_char, char); + self.current_char += 1; + Self::draw_line(&self.buffer, &mut self.stdout, self.current_char); + } + Key::Ctrl(c) if c == 'c' => { + println!("\r"); + return Err(io::Error::from(io::ErrorKind::Interrupted)); + } + _ => { + #[cfg(debug_assertions)] + return Err(io::Error::other(format!( + "Key {key:?} is not implemented" + ))); + } + }, + termion::event::Event::Mouse(mouse_event) => {} + termion::event::Event::Unsupported(items) => {} + } + } + } + } + + pub fn disable_raw(&mut self) { + self.stdout.suspend_raw_mode(); + } + + pub fn enable_raw(&mut self) { + self.stdout.activate_raw_mode(); + } +} + +fn print_err(i: usize, exp: String) { + println!("\r{}{}^", " ".repeat(i + 2), color::Fg(color::Yellow)); + println!( + "\r{}{}{}", + color::Fg(color::Red), + exp, + color::Fg(color::Reset) + ); +} + +fn main() -> Result<(), io::Error> { + let mut input = Input::new(); + + let mut env = Environment::new(); + + while let Some(source) = input.get()? { + let mut lexer = Lexer::new(&source); + let tokens_result = lexer.lex(); + + if tokens_result.is_err() { + match tokens_result.err().unwrap() { + libopenbirch::parser::LexerError::UnexpectedChar(i, exp) => print_err(i, exp), + } + continue; + } + + #[cfg(debug_assertions)] + input.disable_raw(); + + let tokens = tokens_result.unwrap(); + let mut parser = Parser::new(tokens); + + #[cfg(debug_assertions)] + input.enable_raw(); + + let nodes = match parser.parse() { + Ok(nodes) => nodes, + Err(err) => { + match err { + libopenbirch::parser::ParserError::UnexpectedEndOfTokens(exp) => { + print_err(source.len(), exp) + } + libopenbirch::parser::ParserError::UnexpectedToken(i, exp) => print_err(i, exp), + } + continue; + } + }; + + print!("{}", color::Fg(color::Blue)); + + for node in nodes { + let evaluated = node.evaluate(&mut env); + match evaluated { + Ok(result) => println!("\r\t{}", result.as_string(None)), + Err(exp) => print_err(0, exp), + } + } + + print!("{}", color::Fg(color::Reset)); + + // println!("\t{}{source}{}", termion::color::Fg(termion::color::Blue), termion::color::Fg(termion::color::Reset)); + } + + Ok(()) +} diff --git a/src/lib/lib.rs b/src/lib/lib.rs index 6b442e3..44619f1 100644 --- a/src/lib/lib.rs +++ b/src/lib/lib.rs @@ -2,6 +2,7 @@ use std::io::{self, stdout, BufRead, Write}; pub mod node; pub mod environment; +pub mod parser; #[cfg(test)] mod tests; diff --git a/src/lib/node/add.rs b/src/lib/node/add.rs index a91944d..be0607b 100644 --- a/src/lib/node/add.rs +++ b/src/lib/node/add.rs @@ -41,10 +41,10 @@ impl Node for Add { } impl Add { - pub fn new(left: NodeEnum, right: NodeEnum) -> Self { + pub fn new(left: Rc, right: Rc) -> Self { Self { - left: Rc::new(left), - right: Rc::new(right), + left, + right, } } diff --git a/src/lib/node/call.rs b/src/lib/node/call.rs index 16c9f02..b470c5e 100644 --- a/src/lib/node/call.rs +++ b/src/lib/node/call.rs @@ -1,57 +1,67 @@ use std::rc::Rc; -use crate::environment::Environment; +use crate::{environment::Environment, node::function::FunctionType}; use super::{Node, NodeEnum, Precedence, symbol::Symbol}; #[derive(Debug, Clone, PartialEq, PartialOrd)] pub struct Call { - left: Rc, - right: Vec>, + function: Rc, + arguments: Vec>, } impl Node for Call { fn evaluate(&self, env: &mut Environment) -> Result, String> { - todo!(); - let function = match self.left.as_ref() { - NodeEnum::Symbol(symbol) => match symbol.evaluate(env)?.as_ref() { - _ => { - return Err(format!( - "Cannot call {} as a function", - symbol.as_string(Some(env)) - )); - } - - NodeEnum::Function(function) => function, - }, - NodeEnum::Function(function) => function, - NodeEnum::Call(call) => match call.evaluate(env)?.as_ref() { - NodeEnum::Function(function) => function, - - // FIXME: This might fail for long chains of calls - _ => { - return Err(format!( - "Cannot call {} as a function", - self.left.as_string(None) - )); - } - }, - _ => { - return Err(format!( - "Cannot call {} as a function", - self.left.as_string(None) - )); - } + // Evaluate callee and error if its not a function + let evaluated = self.function.evaluate(env)?; + let func = if let NodeEnum::Function(func) = evaluated.as_ref() { + func + } else { + return Err(format!( + "Cannot call {} as a function", + evaluated.as_string(Some(env)) + )); }; - // match function { - // - // } - todo!() + // Check if argument counts match + let fargs = func.get_arguments(); + if fargs.len() != self.arguments.len() { + return Err(format!( + "Error calling function. Expected {} arguments, but got {}", + func.get_arguments().len(), + self.arguments.len() + )); + } + + // Call function body with arguments + match func.get_body() { + // Pass arguments to native function + FunctionType::Native(_name, native_function) => native_function(&self.arguments), + FunctionType::UserFunction(node_enum) => { + // TODO: Push scope + // Define variables + fargs + .iter() + .zip(&self.arguments) + .for_each(|(symbol, value)| { + env.insert(symbol.get_value(), value.clone()); + }); + let ev = node_enum.evaluate(env); + // TODO: Pop scope + // Return evaluated return value for function + ev + } + } } - fn as_string(&self, _: Option<&Environment>) -> String { - todo!() + fn as_string(&self, env: Option<&Environment>) -> String { + let arguments = self + .arguments + .iter() + .map(|x| x.as_string(env)) + .reduce(|a, b| a + ", " + &b) + .unwrap(); + format!("{}({})", self.function.as_string(env), arguments) } fn precedence(&self) -> Precedence { diff --git a/src/lib/node/constant.rs b/src/lib/node/constant.rs index c1dd79c..3797100 100644 --- a/src/lib/node/constant.rs +++ b/src/lib/node/constant.rs @@ -2,9 +2,11 @@ use std::rc::Rc; use super::{Environment, Node, Precedence}; +pub type ConstantValue = f64; + #[derive(Clone, Debug, PartialEq, PartialOrd)] pub struct Constant { - value: f64, + value: ConstantValue, } impl Node for Constant { @@ -22,15 +24,15 @@ impl Node for Constant { } impl Constant { - pub fn new(value: f64) -> Self { + pub fn new(value: ConstantValue) -> Self { Self { value } } - pub fn get_value(&self) -> f64 { + pub fn get_value(&self) -> ConstantValue { self.value } - pub fn set_value(&mut self, value: f64) { + pub fn set_value(&mut self, value: ConstantValue) { self.value = value; } } diff --git a/src/lib/node/divide.rs b/src/lib/node/divide.rs index 8c45245..f2bcf83 100644 --- a/src/lib/node/divide.rs +++ b/src/lib/node/divide.rs @@ -41,10 +41,10 @@ impl Node for Divide { } impl Divide { - pub fn new(left: NodeEnum, right: NodeEnum) -> Self { + pub fn new(left: Rc, right: Rc) -> Self { Self { - left: Rc::new(left), - right: Rc::new(right), + left, + right } } diff --git a/src/lib/node/function.rs b/src/lib/node/function.rs index 64d8bf4..f4897c9 100644 --- a/src/lib/node/function.rs +++ b/src/lib/node/function.rs @@ -8,7 +8,7 @@ use super::{Node, NodeEnum, Precedence, symbol::Symbol}; pub enum FunctionType { Native( &'static str, - fn(Vec>) -> Result, String>, + fn(&Vec>) -> Result, String>, ), UserFunction(Rc), } @@ -50,4 +50,12 @@ impl Function { // pub fn new(t: FunctionType, ) -> Self { // // } + + pub fn get_body(&self) -> &FunctionType { + &self.function + } + + pub fn get_arguments(&self) -> &Vec { + &self.arguments + } } diff --git a/src/lib/node/if_else.rs b/src/lib/node/if_else.rs index 2ea3e09..07e7495 100644 --- a/src/lib/node/if_else.rs +++ b/src/lib/node/if_else.rs @@ -8,6 +8,24 @@ pub enum Bool { False, } +impl Node for Bool { + fn evaluate(&self, _: &mut Environment) -> Result, String> { + Ok(Rc::new(self.clone().into())) + } + + fn as_string(&self, _: Option<&Environment>) -> String { + match self { + Bool::True => "true", + Bool::False => "false", + } + .into() + } + + fn precedence(&self) -> Precedence { + Precedence::Primary + } +} + #[derive(Clone, Debug, PartialEq, PartialOrd)] pub enum ElseBranchEnum { ElseIf(Rc), @@ -25,12 +43,37 @@ impl Node for IfElse { fn evaluate(&self, env: &mut Environment) -> Result, String> { let condition_evaluated = self.condition.evaluate(env)?; - if let super::NodeEnum::Bool(bool) = condition_evaluated { + let condition = if let NodeEnum::Bool(bool) = condition_evaluated.as_ref() { + bool } else { return Err(format!( "Cannot evaluate {} to a bool", condition_evaluated.as_string(Some(env)) )); + }; + + fn evaluate_block( + block: &Vec>, + env: &mut Environment, + ) -> Result, String> { + // TODO: Push new scope + if let Some((last, to_evaluate)) = block.split_last() { + for expr in to_evaluate { + expr.evaluate(env)?; + } + last.evaluate(env) + } else { + Err("Empty if statemenent true branch")? + } + // TODO: Pop scope + } + + match condition { + Bool::True => evaluate_block(&self.true_branch, env), + Bool::False => match &self.else_branch { + ElseBranchEnum::ElseIf(if_else) => if_else.evaluate(env), + ElseBranchEnum::Block(node_enums) => evaluate_block(node_enums, env), + }, } } diff --git a/src/lib/node/mod.rs b/src/lib/node/mod.rs index ca5a2a6..7d68104 100644 --- a/src/lib/node/mod.rs +++ b/src/lib/node/mod.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, fmt::Display, rc::Rc}; +use std::{fmt::Display, rc::Rc}; use add::Add; use assign::Assign; @@ -10,7 +10,6 @@ use enum_dispatch::enum_dispatch; use function::Function; use if_else::{Bool, IfElse}; use multiply::Multiply; -use node_ref::NodeRef; use subtract::Subtract; use symbol::Symbol; @@ -51,10 +50,12 @@ pub enum NodeEnum { Bool, IfElse, - // Logical + // Logical operators // In, - // Or, + // Where, // Not, + // Or, + // And } #[derive(PartialEq, Eq, PartialOrd, Ord)] diff --git a/src/lib/node/multiply.rs b/src/lib/node/multiply.rs index 7a4a28e..ffe632a 100644 --- a/src/lib/node/multiply.rs +++ b/src/lib/node/multiply.rs @@ -41,10 +41,10 @@ impl Node for Multiply { } impl Multiply { - pub fn new(left: NodeEnum, right: NodeEnum) -> Self { + pub fn new(left: Rc, right: Rc) -> Self { Self { - left: Rc::new(left), - right: Rc::new(right), + left, + right } } diff --git a/src/lib/node/subtract.rs b/src/lib/node/subtract.rs index 3315ba3..70ca3bc 100644 --- a/src/lib/node/subtract.rs +++ b/src/lib/node/subtract.rs @@ -41,10 +41,10 @@ impl Node for Subtract { } impl Subtract { - pub fn new(left: NodeEnum, right: NodeEnum) -> Self { + pub fn new(left: Rc, right: Rc) -> Self { Self { - left: Rc::new(left), - right: Rc::new(right), + left, + right, } } diff --git a/src/lib/parser/mod.rs b/src/lib/parser/mod.rs new file mode 100644 index 0000000..c553034 --- /dev/null +++ b/src/lib/parser/mod.rs @@ -0,0 +1,231 @@ +use std::{collections::HashMap, iter::Peekable, rc::Rc, slice::Iter, vec::IntoIter}; + +use crate::node::{ + NodeEnum, + add::Add, + constant::{Constant, ConstantValue}, + divide::Divide, + multiply::Multiply, + subtract::Subtract, +}; + +#[derive(Debug)] +pub struct Token(usize, TokenType); + +#[derive(Debug)] +pub enum TokenType { + // Space, + Number(ConstantValue), + + Plus, + Minus, + Star, + Slash, + + RParen, + LParen, + + If, + Then, + Else, + End, +} + +pub struct Lexer<'a> { + source: &'a String, +} + +#[derive(Debug)] +pub enum LexerError { + UnexpectedChar(usize, String), +} + +impl<'a> Lexer<'a> { + pub fn new(source: &'a String) -> Self { + Self { source } + } + + pub fn lex(&'a mut self) -> Result, LexerError> { + let mut src = self.source.chars().peekable(); + let mut i = 0; + let mut tokens = vec![]; + + while let Some(c) = src.next() { + match c { + // Collapse spaces into a single Space token + ' ' => { + while src.peek() == Some(&' ') { + src.next(); + i += 1; + } + // tokens.push(Token(i, TokenType::Space)); + } + // Comments with // + '/' if src.peek() == Some(&'/') => { + while src.next() != Some('\n') { + i += 1; + } + } + // Numbers with decimal points + '0'..'9' | '.' => { + let mut digit = String::from(c); + loop { + let d = src.peek(); + let mut has_decimal = c == '.'; + match d { + Some('0'..'9') => { + digit.push(*d.unwrap()); + src.next(); + i += 1; + } + #[allow(unused_assignments)] // For some reason it thinks has_decimal + // is never read + Some('.') => { + if has_decimal { + return Err(LexerError::UnexpectedChar( + i, + "Invalid digit with multiple decimal points".into(), + )); + } + + digit.push(*d.unwrap()); + has_decimal = true; + } + _ => { + break; + } + } + } + let number = digit.parse::().unwrap(); + + tokens.push(Token(i, TokenType::Number(number))); + } + + '+' => tokens.push(Token(i, TokenType::Plus)), + '-' => tokens.push(Token(i, TokenType::Minus)), + '*' => tokens.push(Token(i, TokenType::Star)), + '/' => tokens.push(Token(i, TokenType::Slash)), + + '(' => tokens.push(Token(i, TokenType::LParen)), + ')' => tokens.push(Token(i, TokenType::RParen)), + + _ => { + return Err(LexerError::UnexpectedChar( + i, + format!("Unexpected char {}", c), + )); + } + } + i += 1; + } + + Ok(tokens) + } +} + +pub enum ParserError { + UnexpectedEndOfTokens(String), + UnexpectedToken(usize, String), +} + +/// Recursive descent parser +pub struct Parser { + tokens: Peekable>, +} + +type Tokens<'a> = Peekable>; + +impl Parser { + pub fn new(tokens: Vec) -> Self { + // #[cfg(debug_assertions)] + // println!("\r{tokens:#?}"); + Self { + tokens: tokens.into_iter().peekable(), + } + } + + // Parse tokens recursively and descendentantly + pub fn parse(&mut self) -> Result>, ParserError> { + let mut expressions = vec![]; + + while self.tokens.peek().is_some() { + expressions.push(self.expression()?); + } + + Ok(expressions) + } + + fn expression(&mut self) -> Result, ParserError> { + self.equality() + } + + fn equality(&mut self) -> Result, ParserError> { + // TODO: Implement equality + self.comparison() + } + + fn comparison(&mut self) -> Result, ParserError> { + // TODO: Implement comparison + self.term() + } + + fn term(&mut self) -> Result, ParserError> { + let expr = self.factor()?; + if let Some(Token(_, TokenType::Plus)) = self.tokens.peek() { + self.tokens.next(); + Ok(Rc::new(Add::new(expr, self.comparison()?).into())) + } else if let Some(Token(_, TokenType::Minus)) = self.tokens.peek() { + self.tokens.next(); + Ok(Rc::new(Subtract::new(expr, self.comparison()?).into())) + } else { + Ok(expr) + } + } + + fn factor(&mut self) -> Result, ParserError> { + let expr = self.unary()?; + if let Some(Token(_, TokenType::Star)) = self.tokens.peek() { + self.tokens.next(); + Ok(Rc::new(Multiply::new(expr, self.comparison()?).into())) + } else if let Some(Token(_, TokenType::Slash)) = self.tokens.peek() { + self.tokens.next(); + Ok(Rc::new(Divide::new(expr, self.comparison()?).into())) + } else { + Ok(expr) + } + } + + fn unary(&mut self) -> Result, ParserError> { + self.exponent() + } + + fn exponent(&mut self) -> Result, ParserError> { + self.call() + } + + fn call(&mut self) -> Result, ParserError> { + self.function() + } + + fn function(&mut self) -> Result, ParserError> { + self.primary() + } + + fn primary(&mut self) -> Result, ParserError> { + let (i, token) = if let Some(Token(i, token)) = self.tokens.next() { + (i, token) + } else { + return Err(ParserError::UnexpectedEndOfTokens( + "Expected a Primary here".into(), + )); + }; + + match token { + TokenType::Number(value) => Ok(Rc::new(Constant::new(value).into())), + _ => Err(ParserError::UnexpectedToken( + i, + format!("Unexpected token {token:?}"), + )), + } + } +} diff --git a/src/lib/tests/mod.rs b/src/lib/tests/mod.rs index 9883234..96e1570 100644 --- a/src/lib/tests/mod.rs +++ b/src/lib/tests/mod.rs @@ -2,7 +2,7 @@ mod arithmetic { use std::rc::Rc; use crate::{environment, node::{ - add::Add, constant::Constant, divide::Divide, multiply::Multiply, subtract::Subtract, Node, NodeEnum + add::Add, constant::{Constant, ConstantValue}, divide::Divide, multiply::Multiply, subtract::Subtract, Node, NodeEnum }}; use environment::Environment; @@ -11,13 +11,13 @@ mod arithmetic { #[test_case(69.0, 420.0, 489.0 ; "when both are positive")] #[test_case(-2.0, -4.0, -6.0 ; "when both are negative")] #[test_case(0.0, 0.0, 0.0 ; "when both are zero")] - #[test_case(f64::INFINITY, 0.0, f64::INFINITY ; "infinity")] - // #[test_case(f64::NAN, 0.0, f64::NAN ; "NaN")] // cant test NaN because NaN != NaN - fn addition(a: f64, b: f64, e: f64) { + #[test_case(ConstantValue::INFINITY, 0.0, ConstantValue::INFINITY ; "infinity")] + // #[test_case(ConstantValue::NAN, 0.0, ConstantValue::NAN ; "NaN")] // cant test NaN because NaN != NaN + fn addition(a: ConstantValue, b: ConstantValue, e: ConstantValue) { let mut env = Environment::new(); - let a = Constant::new(a).into(); - let b = Constant::new(b).into(); + let a = Rc::new(Constant::new(a).into()); + let b = Rc::new(Constant::new(b).into()); let d: NodeEnum = Add::new(a, b).into(); let value = Rc::::try_unwrap(d.evaluate(&mut env).unwrap()).unwrap(); @@ -33,15 +33,15 @@ mod arithmetic { #[test_case(69.0, 420.0, -351.0 ; "when both are positive")] #[test_case(-2.0, -4.0, 2.0 ; "when both are negative")] #[test_case(0.0, 0.0, 0.0 ; "when both are zero")] - #[test_case(f64::INFINITY, 0.0, f64::INFINITY ; "infinity")] - // #[test_case(f64::NAN, 0.0, f64::NAN ; "NaN")] // cant test NaN because NaN != NaN - fn subtraction(aa: f64, bb: f64, e: f64) { + #[test_case(ConstantValue::INFINITY, 0.0, ConstantValue::INFINITY ; "infinity")] + // #[test_case(ConstantValue::NAN, 0.0, ConstantValue::NAN ; "NaN")] // cant test NaN because NaN != NaN + fn subtraction(aa: ConstantValue, bb: ConstantValue, e: ConstantValue) { let mut env = Environment::new(); - let a = Constant::new(0.0).into(); - let b = Constant::new(aa).into(); - let c = Constant::new(bb).into(); - let d: NodeEnum = Subtract::new(Add::new(a, b).into(), c).into(); + let a = Rc::new(Constant::new(0.0).into()); + let b = Rc::new(Constant::new(aa).into()); + let c = Rc::new(Constant::new(bb).into()); + let d: NodeEnum = Subtract::new(Rc::new(Add::new(a, b).into()), c).into(); let value = Rc::::try_unwrap(d.evaluate(&mut env).unwrap()).unwrap(); @@ -58,11 +58,11 @@ mod arithmetic { #[test_case(5.0, -10.0, -50.0 ; "when right is negative")] #[test_case(-5.0, -10.0, 50.0 ; "when both are negative")] #[test_case(2734589235234.23, 0.0, 0.0 ; "when 0 is involved")] - fn multiplication(a: f64, b: f64, e: f64) { + fn multiplication(a: ConstantValue, b: ConstantValue, e: ConstantValue) { let mut env = Environment::new(); - let a = Constant::new(a).into(); - let b = Constant::new(b).into(); + let a = Rc::new(Constant::new(a).into()); + let b = Rc::new(Constant::new(b).into()); let d: NodeEnum = Multiply::new(a, b).into(); let value = Rc::::try_unwrap(d.evaluate(&mut env).unwrap()).unwrap(); @@ -79,11 +79,11 @@ mod arithmetic { #[test_case(-5.0, 10.0, -0.5 ; "when left is negative")] #[test_case(5.0, -10.0, -0.5 ; "when right is negative")] #[test_case(-5.0, -10.0, 0.5 ; "when both are negative")] - fn division(a: f64, b: f64, e: f64) { + fn division(a: ConstantValue, b: ConstantValue, e: ConstantValue) { let mut env = Environment::new(); - let a = Constant::new(a).into(); - let b = Constant::new(b).into(); + let a = Rc::new(Constant::new(a).into()); + let b = Rc::new(Constant::new(b).into()); let d: NodeEnum = Divide::new(a, b).into(); let value = Rc::::try_unwrap(d.evaluate(&mut env).unwrap()).unwrap(); @@ -107,8 +107,10 @@ mod functions { mod expected_errors { use test_case::test_case; + use crate::node::constant::ConstantValue; + #[test_case(5.0, 0.0 ; "divide by zero")] - fn division(a: f64, b: f64) { + fn division(a: ConstantValue, b: ConstantValue) { let _ = a+b; } } @@ -116,13 +118,15 @@ mod expected_errors { mod misc { use test_case::test_case; + use crate::node::constant::ConstantValue; + #[test_case(30, '+', 60, '-', 20, "(30+60)-20" ; "add and subtract")] fn convert_to_string( - a: impl Into, + a: impl Into, op1: char, - b: impl Into, + b: impl Into, op2: char, - c: impl Into, + c: impl Into, e: &'static str, ) { }