use std::{collections::{BTreeMap, BTreeSet}, result, vec}; use crate::{ast::types::{BlockStmt, CompileUnit, Expr, FuncDeclStmt, GlobalDeclStmt, ReturnStmt, Statement, VarDeclStmt}, diagnostic::Diagnositics, ir::{err::IRError, types::{BinaryOp, Function, IRInstr, IRType, MoveRValue, Variable, VariableType}}}; pub struct Generator { var_manager: VariableManager, function_map: BTreeMap, current_func_return_type: Option, diagnostic: Diagnositics, } impl Generator { pub fn new() -> Self { let mut function_map = BTreeMap::new(); function_map.insert("putint".to_string(), Function { name: "putint".to_string(), parameter_types: vec![IRType::I32], return_type: IRType::Void, }); Self { var_manager: VariableManager::new(), current_func_return_type: None, diagnostic: Diagnositics::new(), function_map, } } pub fn emit(&mut self, compile_unit: CompileUnit) -> Vec { self.generate_compile_unit(compile_unit) } fn generate_compile_unit(&mut self, compile_unit: CompileUnit) -> Vec { let mut instrs = vec![]; use GlobalDeclStmt::*; for decl in compile_unit.global_decls { match decl { VarDecl(var_decl) => { instrs.extend(self.generate_var_decl(var_decl, true)); } FuncDecl(func_decl) => { instrs.extend(self.generate_func_decl(func_decl)); } } } instrs } fn generate_var_decl(&mut self, var_decl: VarDeclStmt, is_global: bool) -> Vec { let mut instrs = vec![]; let var_type = if is_global { VariableType::Global } else { VariableType::Local }; for value in var_decl.values { match self.var_manager.declare_variable(&value.name, var_type, value.var_type.into()) { Ok(var) => { if is_global { instrs.push(IRInstr::Declare(var)); } } Err(e) => { self.diagnostic.add_from_ir_error(e, value.span); } } } instrs } fn generate_func_decl(&mut self, func_decl: FuncDeclStmt) -> Vec { if self.function_map.contains_key(&func_decl.name) { self.diagnostic.add_from_ir_error(IRError::FunctionHasBeenDefined(func_decl.name.clone()), func_decl.span); return vec![]; } self.current_func_return_type = Some(func_decl.return_type.into()); self.var_manager.enter_scope(); let parameters: Vec = match func_decl.params.iter().map(|param| { match self.var_manager.declare_variable(¶m.name, VariableType::Local, param.param_type.into()) { Ok(var) => Ok(var), Err(e) => { self.diagnostic.add_from_ir_error(e, param.span); Err(()) } } }).collect() { Ok(p) => p, Err(()) => return vec![], }; let temp_parameters = parameters.iter().map(|param| self.var_manager.declare_param_temp(param.data_type)).collect::>(); let mut body_instrs = vec![]; let block_instrs = self.generate_block_stmt(func_decl.body); for var in self.var_manager.get_cur_func_variables() { if matches!(var.var_type, VariableType::ParamTemp) { continue; } body_instrs.push(IRInstr::Declare(var)); } body_instrs.push(IRInstr::Entry); parameters.iter().zip(temp_parameters.iter()).for_each(|(param, temp_param)| { body_instrs.push(IRInstr::Move(*temp_param, MoveRValue::Var(*param))); }); body_instrs.extend(block_instrs); self.var_manager.exit_scope(); self.current_func_return_type = None; self.var_manager.clear_local_counter(); let func = Function { name: func_decl.name, parameter_types: temp_parameters.iter().map(|v| v.data_type).collect(), return_type: func_decl.return_type.into(), }; self.function_map.insert(func.name.clone(), func.clone()); vec![IRInstr::DefineFunc(func, parameters, body_instrs)] } fn generate_block_stmt(&mut self, block_stmt: BlockStmt) -> Vec { let mut instrs = vec![]; for stmt in block_stmt.statements { instrs.extend(self.generate_statement(stmt)); } instrs } fn generate_statement(&mut self, stmt: Statement) -> Vec { use Statement::*; let instrs = match stmt { Return(return_stmt) => self.generate_return_stmt(return_stmt), Block(block_stmt) => { self.var_manager.enter_scope(); let block_instrs = self.generate_block_stmt(block_stmt); self.var_manager.exit_scope(); block_instrs }, Expr(expr) => self.generate_expr(expr).0, VarDecl(var_decl) => self.generate_var_decl(var_decl, false), }; instrs } fn generate_return_stmt(&mut self, return_stmt: ReturnStmt) -> Vec { let mut instrs = vec![]; match return_stmt.value { Some(expr) => { let (value_instrs, value_var) = self.generate_expr(expr); instrs.extend(value_instrs); if value_var.is_some() { instrs.push(IRInstr::Exit(value_var)); } } None => instrs.push(IRInstr::Exit(None)), } instrs } fn generate_expr(&mut self, expr: Expr) -> (Vec, Option) { use crate::ast::types::ExprValue; match expr.value { ExprValue::IntLit(i) => { // TODO: convert check let var = self.var_manager.declare_temp(IRType::I32); (vec![IRInstr::Move(var, MoveRValue::ConstInt(i as i32))], Some(var)) }, ExprValue::Var(name) => { if let Some(var) = self.var_manager.get_variable(&name) { (vec![], Some(var)) } else { self.diagnostic.add_from_ir_error(IRError::VariableNotFound(name.clone()), expr.span); (vec![], None) } }, ExprValue::Assign { lvalue, rvalue } => { // TODO: only support simple variable assignment now, need to support more complex lvalue in the future if let ExprValue::Var(name) = lvalue.value { let var = match self.var_manager.get_variable(&name) { Some(var) => var, None => { self.diagnostic.add_from_ir_error(IRError::VariableNotFound(name.clone()), lvalue.span); return (vec![], None); } }; let (mut instrs, rvalue_var) = self.generate_expr(*rvalue); if rvalue_var.is_none() { return (vec![], None); } let rvalue_var = rvalue_var.unwrap(); if var.data_type != rvalue_var.data_type { self.diagnostic.add_from_ir_error(IRError::TypeMismatch(var.data_type, rvalue_var.data_type), lvalue.span); return (vec![], None); } instrs.push(IRInstr::Move(var, MoveRValue::Var(rvalue_var))); let temp_var = self.var_manager.declare_temp(var.data_type); instrs.push(IRInstr::Move(temp_var, MoveRValue::Var(var))); (instrs, Some(temp_var)) } else { self.diagnostic.add_from_ir_error(IRError::InvalidAssignmentTarget, lvalue.span); return (vec![], None); } }, ExprValue::BinaryOp { lhs, op, rhs } => { let lhs_span = lhs.span; let rhs_span = rhs.span; let (mut instrs, left_var) = self.generate_expr(*lhs); if left_var.is_none() { return (vec![], None); } let left_var = left_var.unwrap(); let (right_instrs, right_var) = self.generate_expr(*rhs); if right_var.is_none() { return (vec![], None); } let right_var = right_var.unwrap(); instrs.extend(right_instrs); let mut has_void = false; if matches!(left_var.data_type, IRType::Void) { self.diagnostic.add_from_ir_error(IRError::InvalidOperand(IRType::Void), lhs_span); has_void = true; } if matches!(right_var.data_type, IRType::Void) { self.diagnostic.add_from_ir_error(IRError::InvalidOperand(IRType::Void), rhs_span); has_void = true; } if has_void { return (vec![], None); } if left_var.data_type != right_var.data_type { self.diagnostic.add_from_ir_error(IRError::IncompatiableOperand(left_var.data_type, right_var.data_type), lhs_span); self.diagnostic.add_from_ir_error(IRError::InvalidOperand(left_var.data_type), rhs_span); return (vec![], None); } let result_type; match Into::::into(op) { BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Mod => { result_type = left_var.data_type; }, BinaryOp::Eq | BinaryOp::Ne | BinaryOp::Lt | BinaryOp::Gt | BinaryOp::Le | BinaryOp::Ge => { result_type = IRType::I1; } } let dest_var = self.var_manager.declare_temp(result_type); instrs.push(IRInstr::Binary(dest_var, left_var, op.into(), right_var)); (instrs, Some(dest_var)) }, ExprValue::FuncCall(func_name, args) => { let mut instrs = vec![]; let mut arg_vars = vec![]; let func_def = if let Some(func_def) = self.function_map.get(&func_name) { func_def } else { self.diagnostic.add_from_ir_error(IRError::FunctionNotFound(func_name.clone()), expr.span); return (vec![], None); }.clone(); if args.len() < func_def.parameter_types.len() { self.diagnostic.add_from_ir_error(IRError::TooFewArguments(func_def.parameter_types.len(), args.len()), expr.span); return (vec![], None); } if args.len() > func_def.parameter_types.len() { self.diagnostic.add_from_ir_error(IRError::TooManyArguments(func_def.parameter_types.len(), args.len()), expr.span); return (vec![], None); } let mut has_error = false; for (i, arg) in args.into_iter().enumerate() { let (arg_instrs, arg_var) = self.generate_expr(arg); if arg_var.is_none() { has_error = true; continue; } let arg_var = arg_var.unwrap(); let parameter_type = func_def.parameter_types.get(i).unwrap(); if *parameter_type != arg_var.data_type { self.diagnostic.add_from_ir_error(IRError::TypeMismatch(*parameter_type, arg_var.data_type), expr.span); has_error = true; continue; } instrs.extend(arg_instrs); arg_vars.push(arg_var); } if has_error { return (vec![], None); } let ret_variable = if matches!(func_def.return_type, IRType::Void) { None } else { Some(self.var_manager.declare_temp(func_def.return_type)) }; instrs.push(IRInstr::FuncCall(func_def.clone(), arg_vars, ret_variable)); (instrs, ret_variable) } } } } struct VariableManager { variable_map: BTreeMap>, scopes: Vec>, global_counter: usize, local_counter: usize, local_var_type: Vec, } impl VariableManager { pub fn new() -> Self { Self { variable_map: BTreeMap::new(), scopes: vec![BTreeSet::new()], global_counter: 0, local_counter: 0, local_var_type: vec![], } } pub fn enter_scope(&mut self) { self.scopes.push(BTreeSet::new()); } pub fn exit_scope(&mut self) { let variables = self.scopes.pop().unwrap(); for var in variables { self.variable_map.get_mut(&var).unwrap().pop(); } } fn declare_variable(&mut self, name: &str, var_type: VariableType, var_data_type: IRType) -> Result { if self.scopes.last().unwrap().contains(name) { return Err(IRError::VariableHasBeenDefined(name.to_string())); } let variable = match var_type { VariableType::Global => { let var = Variable { index: self.global_counter, var_type, data_type: var_data_type }; self.global_counter += 1; var } VariableType::Local => { let var = Variable { index: self.local_counter, var_type, data_type: var_data_type }; self.local_counter += 1; var } _ => unreachable!(), }; self.variable_map.entry(name.to_string()).or_default().push(variable); self.scopes.last_mut().unwrap().insert(name.to_string()); if matches!(var_type, VariableType::Local) { self.local_var_type.push(variable); } Ok(variable) } pub fn declare_gloabal(&mut self, name: &str, var_data_type: IRType) -> Result { self.declare_variable(name, VariableType::Global, var_data_type) } pub fn declare_local(&mut self, name: &str, var_data_type: IRType) -> Result { self.declare_variable(name, VariableType::Local, var_data_type) } pub fn declare_temp(&mut self, var_data_type: IRType) -> Variable { let var = Variable { index: self.local_counter, var_type: VariableType::Temp, data_type: var_data_type }; self.local_counter += 1; self.local_var_type.push(var); var } pub fn declare_param_temp(&mut self, var_data_type: IRType) -> Variable { let var = Variable { index: self.local_counter, var_type: VariableType::ParamTemp, data_type: var_data_type }; self.local_counter += 1; self.local_var_type.push(var); var } pub fn clear_local_counter(&mut self) { self.local_counter = 0; self.local_var_type.clear(); } pub fn get_cur_func_variables(&self) -> Vec { self.local_var_type.iter().cloned().collect() } pub fn get_variable(&self, name: &str) -> Option { self.variable_map.get(name).and_then(|vars| vars.last()).cloned() } } #[cfg(test)] mod tests { use std::io::BufRead; use std::path::Path; use std::fs::File; use crate::ast::graph::AstGraphExt; use std::io::Write; use crate::frontend::lexer::Lexer; use crate::frontend::parser::Parser; use crate::utils::case_list::CaseList; use crate::utils::num_sequence::NumberSequence; pub use super::*; fn test_case(case_str: &str) { let case_sequence = NumberSequence::from_str(case_str).unwrap(); let case_list = CaseList::from_dir(&Path::new("./testcases")).unwrap(); let mut error_case_cnt = 0; for case_no in case_sequence { let case_path = case_list.get_case_path(case_no).unwrap(); println!("{}", case_path.display()); let file = File::open(&case_path).unwrap(); let mut buf_reader = std::io::BufReader::new(file); let mut lexer = Lexer::new(); let mut full_text = String::new(); loop { let mut line = String::new(); let bytes_read = buf_reader.read_line(&mut line).unwrap(); if bytes_read == 0 { break; } full_text.push_str(&line); lexer.parse_next_str(&line); } let (tokens, diagnostics) = lexer.finish(); let mut is_error = false; if !diagnostics.is_empty() { diagnostics.print(&format!("{}", case_path.display()), &full_text); is_error = true; } let mut parser = Parser::new(tokens, diagnostics); let compile_unit = parser.parse(); let case_name = case_list.get_case_name(case_no).unwrap().strip_suffix(".c").unwrap(); if !parser.diagnostics.is_empty() { parser.diagnostics.print(&format!("{}", case_path.display()), &full_text); is_error = true; } let mut generator = Generator::new(); let ir = generator.emit(compile_unit); if !generator.diagnostic.is_empty() { generator.diagnostic.print(&format!("{}", case_path.display()), &full_text); is_error = true; } let mut ir_file = File::create(format!("output/{}.ir", case_name)).unwrap(); for instr in ir { writeln!(ir_file, "{}", instr).unwrap(); } if is_error { error_case_cnt += 1; } } if error_case_cnt > 0 { panic!("Found {} cases with errors", error_case_cnt); } } #[test] fn test_expr() { test_case("0-3,14-25"); // test_case("0-3,14-25"); } }