创建一个应用,能够根据用户输入的需求,自动生成 RESTful API 的代码、文档和测试用例
2026-01-06
创建一个应用,能够根据用户输入的需求,自动生成 RESTful API 的代码、文档和测试用例。通过集成 AI,系统能够理解输入需求并输出完整的 API 设计。
思考过程
这是一个很有挑战性的项目,需要创建一个能够自动生成RESTful API代码、文档和测试用例的应用。我会设计一个完整的系统,包含以下几个核心部分:
1. 需求解析模块:使用AI理解用户输入
2. API设计生成器:生成API结构
3. 代码生成器:生成实际的API代码(使用Flask/FastAPI)
4. 文档生成器:生成OpenAPI/Swagger文档
5. 测试用例生成器:生成自动化测试代码
6. Web界面:用户交互界面
我会创建一个完整可运行的示例,包含前后端代码。
# RESTful API 自动生成系统
我将为你创建一个完整的AI驱动的API生成系统,包含需求解析、代码生成、文档生成和测试用例生成。
## 项目结构
```
api-generator/
├── backend/
│ ├── app.py
│ ├── requirements.txt
│ ├── generators/
│ │ ├── __init__.py
│ │ ├── api_designer.py
│ │ ├── code_generator.py
│ │ ├── doc_generator.py
│ │ └── test_generator.py
│ └── ai/
│ ├── __init__.py
│ └── requirement_parser.py
├── frontend/
│ ├── index.html
│ └── static/
│ ├── css/
│ │ └── style.css
│ └── js/
│ └── app.js
└── generated/
└── (生成的代码将存放在这里)
```
## 1. 后端核心代码
### requirements.txt
```txt
flask==3.0.0
flask-cors==4.0.0
openai==1.3.0
jinja2==3.1.2
pyyaml==6.0.1
```
### app.py
```python
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
import os
import json
import zipfile
from datetime import datetime
from generators.api_designer import APIDesigner
from generators.code_generator import CodeGenerator
from generators.doc_generator import DocGenerator
from generators.test_generator import TestGenerator
from ai.requirement_parser import RequirementParser
app = Flask(__name__)
CORS(app)

# Configuration: where generated projects are written, and a 16 MiB cap on
# request bodies.
app.config.update(
    GENERATED_FOLDER='generated',
    MAX_CONTENT_LENGTH=16 * 1024 * 1024,
)

# The output directory must exist before any handler writes into it.
os.makedirs(app.config['GENERATED_FOLDER'], exist_ok=True)

# Pipeline components: created once at import time and used by the handlers.
requirement_parser, api_designer = RequirementParser(), APIDesigner()
code_generator, doc_generator, test_generator = (
    CodeGenerator(),
    DocGenerator(),
    TestGenerator(),
)
@app.route('/api/health', methods=['GET'])
def health_check():
    """Liveness probe: report the service as healthy with the server's local time."""
    payload = {
        'status': 'healthy',
        'timestamp': datetime.now().isoformat(),
    }
    return jsonify(payload)
@app.route('/api/parse-requirement', methods=['POST'])
def parse_requirement():
    """Parse a free-text requirement into structured data.

    Expects a JSON body ``{"requirement": "<text>"}``.

    Returns:
        200 ``{"success": true, "data": {...}}`` on success; 400 when the body
        is missing/not a JSON object or the requirement is blank; 500 if the
        parser itself fails.
    """
    # get_json(silent=True) yields None for a missing or malformed body instead
    # of raising, so client mistakes map to 400 rather than a generic 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': '请求体必须是JSON对象'}), 400

    requirement = data.get('requirement', '')
    if not isinstance(requirement, str) or not requirement.strip():
        return jsonify({'error': '需求描述不能为空'}), 400

    try:
        parsed_data = requirement_parser.parse(requirement.strip())
    except Exception as e:
        app.logger.exception('Requirement parsing failed')
        return jsonify({'error': str(e)}), 500

    return jsonify({
        'success': True,
        'data': parsed_data
    })
@app.route('/api/design-api', methods=['POST'])
def design_api():
    """Design the API structure from an already-parsed requirement.

    Expects ``{"parsed_requirement": {...}}`` (the ``data`` returned by
    /api/parse-requirement).

    Returns:
        200 with the design on success; 400 for a missing/invalid body or
        parsed requirement; 500 if the designer fails.
    """
    # Tolerate missing/malformed JSON instead of letting request.json raise.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': '请求体必须是JSON对象'}), 400

    parsed_requirement = data.get('parsed_requirement')
    if not parsed_requirement or not isinstance(parsed_requirement, dict):
        return jsonify({'error': '解析后的需求不能为空'}), 400

    try:
        api_design = api_designer.design(parsed_requirement)
    except Exception as e:
        app.logger.exception('API design failed')
        return jsonify({'error': str(e)}), 500

    return jsonify({
        'success': True,
        'data': api_design
    })
@app.route('/api/generate', methods=['POST'])
def generate_api():
    """Generate a complete API project (code, docs, tests) and package it as a ZIP.

    Expects ``{"api_design": {...}, "options": {...}}``. The project is written
    to ``<GENERATED_FOLDER>/<project_name>_<timestamp>/``, and a ZIP of that
    directory is placed next to it.

    Returns:
        200 with the (sanitised) project name, a download URL and the generated
        file paths; 400 for a missing/invalid body; 500 if generation fails.
    """
    import re  # only this handler needs it

    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': '请求体必须是JSON对象'}), 400

    api_design = data.get('api_design')
    options = data.get('options') or {}
    if not api_design or not isinstance(api_design, dict):
        return jsonify({'error': 'API设计不能为空'}), 400
    if not isinstance(options, dict):
        options = {}

    try:
        # project_name is client-controlled and becomes part of a filesystem
        # path. Keep only word characters and '-', so values such as '../x'
        # cannot escape GENERATED_FOLDER.
        raw_name = str(api_design.get('project_name') or 'my_api')
        project_name = re.sub(r'[^\w-]', '_', raw_name) or 'my_api'

        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        project_dir = os.path.join(
            app.config['GENERATED_FOLDER'],
            f"{project_name}_{timestamp}"
        )
        os.makedirs(project_dir, exist_ok=True)

        code_files = code_generator.generate(api_design, project_dir, options)
        doc_files = doc_generator.generate(api_design, project_dir)
        test_files = test_generator.generate(api_design, project_dir)

        # Bundle the project for download; the archive sits beside the directory.
        zip_path = f"{project_dir}.zip"
        create_zip(project_dir, zip_path)

        return jsonify({
            'success': True,
            'data': {
                'project_name': project_name,
                'download_url': f'/api/download/{os.path.basename(zip_path)}',
                'files': {
                    'code': code_files,
                    'docs': doc_files,
                    'tests': test_files
                }
            }
        })
    except Exception as e:
        app.logger.exception('Project generation failed')
        return jsonify({'error': str(e)}), 500
@app.route('/api/download/<filename>', methods=['GET'])
def download_file(filename):
    """Serve a generated project archive as an attachment.

    Only regular files located directly inside GENERATED_FOLDER are served.
    Anything else, including names such as '..' and directories, gets a 404.
    """
    try:
        # Resolve to an absolute path up front. Flask's send_file may resolve
        # a relative path against app.root_path, whereas os.path checks use
        # the working directory; an absolute path removes that ambiguity.
        folder = os.path.abspath(app.config['GENERATED_FOLDER'])
        file_path = os.path.abspath(os.path.join(folder, filename))
        if os.path.dirname(file_path) != folder or not os.path.isfile(file_path):
            return jsonify({'error': '文件不存在'}), 404
        return send_file(file_path, as_attachment=True)
    except Exception as e:
        return jsonify({'error': str(e)}), 500
def create_zip(source_dir, output_path):
    """Pack every file under *source_dir* into a deflate-compressed ZIP archive.

    Member names are stored relative to *source_dir*, so extracting the archive
    does not recreate the parent directory. Empty directories are not stored.
    """
    with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for dirpath, _subdirs, filenames in os.walk(source_dir):
            for name in filenames:
                full_path = os.path.join(dirpath, name)
                archive.write(full_path, os.path.relpath(full_path, source_dir))
if __name__ == '__main__':
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # so it is opt-in (FLASK_DEBUG=1) and the default bind address is loopback.
    app.run(
        debug=os.environ.get('FLASK_DEBUG') == '1',
        host=os.environ.get('HOST', '127.0.0.1'),
        port=int(os.environ.get('PORT', '5000')),
    )
```
### ai/requirement_parser.py
```python
import json
import re
from typing import Dict, List, Any
class RequirementParser:
    """Turn a free-text (Chinese) API requirement into structured data.

    A keyword/rule engine stands in for an AI model here. ``parse`` can be
    replaced by an OpenAI call as long as it returns the same dict shape.
    """

    def __init__(self):
        # Reference regexes for the concepts the parser recognises. The keyword
        # matching below does not use them; they are kept for callers or a
        # future regex/AI-based implementation.
        self.patterns = {
            'entities': r'(?:管理|创建|查询|删除|更新)[\s]*([^\s,。、]+)',
            'operations': r'(创建|查询|更新|删除|列表|搜索|登录|注册)',
            'fields': r'([^\s,。、]+)(?:字段|属性|信息)',
            'auth': r'(登录|注册|认证|授权|JWT|token)',
            'database': r'(MySQL|PostgreSQL|MongoDB|SQLite)',
        }

    def parse(self, requirement: str) -> Dict[str, Any]:
        """Parse a requirement description.

        Returns:
            Dict with ``raw_requirement``, ``entities`` (each with ``name``,
            ``name_en`` and ``fields``), ``operations``, ``auth_required``,
            ``database`` (lower-case, defaulting to ``'sqlite'``) and
            ``api_version``.
        """
        entities = self._extract_entities(requirement)
        operations = self._extract_operations(requirement)
        auth_required = self._check_auth_required(requirement)
        database = self._extract_database(requirement)

        # The generated auth routes import a ``User`` model, so a project with
        # authentication needs a user entity even when none was mentioned.
        if auth_required and not any(e['name'] == '用户' for e in entities):
            entities.append({
                'name': '用户',
                'name_en': self._translate_to_english('用户'),
                'fields': self._infer_fields('用户', requirement),
            })

        return {
            'raw_requirement': requirement,
            'entities': entities,
            'operations': operations,
            'auth_required': auth_required,
            # Lower-case, like every value _extract_database returns.
            'database': database or 'sqlite',
            'api_version': 'v1'
        }

    def _extract_entities(self, text: str) -> List[Dict[str, Any]]:
        """Return one entity dict per known keyword contained in *text*.

        Entities are listed in the fixed keyword order below, not in the order
        in which they occur in the text.
        """
        keywords = ['用户', '商品', '订单', '文章', '评论', '分类', '标签']
        return [
            {
                'name': keyword,
                'name_en': self._translate_to_english(keyword),
                'fields': self._infer_fields(keyword, text),
            }
            for keyword in keywords
            if keyword in text
        ]

    def _extract_operations(self, text: str) -> List[str]:
        """Return the operations mentioned in *text*, in a stable order.

        Falls back to full CRUD plus ``list`` when no operation is mentioned.
        """
        operation_map = {
            '创建': 'create',
            '查询': 'read',
            '更新': 'update',
            '删除': 'delete',
            '列表': 'list',
            '搜索': 'search'
        }
        # The map's values are unique, so no de-duplication is needed; building
        # the list in map order keeps the result deterministic.
        operations = [en for cn, en in operation_map.items() if cn in text]
        return operations or ['create', 'read', 'update', 'delete', 'list']

    def _check_auth_required(self, text: str) -> bool:
        """Return True if *text* mentions login/registration/auth concepts.

        The match is case-insensitive, so 'jwt' and 'Token' also count.
        """
        auth_keywords = ['登录', '注册', '认证', '授权', 'JWT', 'token', '权限']
        lowered = text.lower()
        return any(keyword.lower() in lowered for keyword in auth_keywords)

    def _extract_database(self, text: str) -> "str | None":
        """Return the lower-case database type named in *text*, or None."""
        # 'postgres' also matches 'PostgreSQL'.
        databases = (
            ('mysql', 'mysql'),
            ('postgres', 'postgresql'),
            ('mongodb', 'mongodb'),
            ('sqlite', 'sqlite'),
        )
        lowered = text.lower()
        for needle, db_type in databases:
            if needle in lowered:
                return db_type
        return None

    def _translate_to_english(self, chinese_word: str) -> str:
        """Map a known Chinese entity keyword to its English identifier ('item' if unknown)."""
        translations = {
            '用户': 'user',
            '商品': 'product',
            '订单': 'order',
            '文章': 'article',
            '评论': 'comment',
            '分类': 'category',
            '标签': 'tag'
        }
        return translations.get(chinese_word, 'item')

    def _infer_fields(self, entity: str, text: str) -> List[Dict[str, Any]]:
        """Return a preset field list for *entity*.

        *text* is currently unused; entities without a preset get a generic
        id/name/created_at schema.
        """
        common_fields = {
            '用户': [
                {'name': 'id', 'type': 'integer', 'required': True},
                {'name': 'username', 'type': 'string', 'required': True},
                {'name': 'email', 'type': 'string', 'required': True},
                {'name': 'password', 'type': 'string', 'required': True},
                {'name': 'created_at', 'type': 'datetime', 'required': False},
            ],
            '商品': [
                {'name': 'id', 'type': 'integer', 'required': True},
                {'name': 'name', 'type': 'string', 'required': True},
                {'name': 'description', 'type': 'text', 'required': False},
                {'name': 'price', 'type': 'decimal', 'required': True},
                {'name': 'stock', 'type': 'integer', 'required': True},
                {'name': 'created_at', 'type': 'datetime', 'required': False},
            ],
            '订单': [
                {'name': 'id', 'type': 'integer', 'required': True},
                {'name': 'user_id', 'type': 'integer', 'required': True},
                {'name': 'total_amount', 'type': 'decimal', 'required': True},
                {'name': 'status', 'type': 'string', 'required': True},
                {'name': 'created_at', 'type': 'datetime', 'required': False},
            ]
        }
        return common_fields.get(entity, [
            {'name': 'id', 'type': 'integer', 'required': True},
            {'name': 'name', 'type': 'string', 'required': True},
            {'name': 'created_at', 'type': 'datetime', 'required': False},
        ])
```
### generators/api_designer.py
```python
from typing import Dict, List, Any
class APIDesigner:
    """Derive a REST API design (models, endpoints, auth) from parsed requirements."""

    def design(self, parsed_requirement: Dict[str, Any]) -> Dict[str, Any]:
        """Build the API design consumed by the code, doc and test generators.

        Args:
            parsed_requirement: Output of ``RequirementParser.parse``.

        Returns:
            Dict with ``project_name``, ``description``, ``version``,
            ``database``, ``auth``, ``models`` and ``endpoints``.
        """
        return {
            'project_name': self._generate_project_name(parsed_requirement),
            'description': parsed_requirement.get('raw_requirement', ''),
            'version': parsed_requirement.get('api_version', 'v1'),
            'database': parsed_requirement.get('database', 'sqlite'),
            'auth': self._design_auth(parsed_requirement),
            'models': self._design_models(parsed_requirement),
            'endpoints': self._design_endpoints(parsed_requirement)
        }

    def _generate_project_name(self, requirement: Dict) -> str:
        """Name the project after the first entity (e.g. ``user_api``), else ``my_api``."""
        entities = requirement.get('entities', [])
        if entities:
            return f"{entities[0]['name_en']}_api"
        return "my_api"

    @staticmethod
    def _pluralize(word: str) -> str:
        """Return a simple English plural for URL segments and table names.

        Covers the common suffix rules (category -> categories, box -> boxes).
        Irregular plurals are not handled.
        """
        if len(word) > 1 and word.endswith('y') and word[-2] not in 'aeiou':
            return word[:-1] + 'ies'
        if word.endswith(('s', 'x', 'z', 'ch', 'sh')):
            return word + 'es'
        return word + 's'

    def _design_endpoints(self, requirement: Dict) -> List[Dict[str, Any]]:
        """Create CRUD endpoint descriptions for every entity.

        Only the operations listed in ``requirement['operations']`` produce
        endpoints ('list', 'create', 'read', 'update', 'delete').
        """
        endpoints = []
        operations = requirement.get('operations', [])
        version = requirement.get('api_version', 'v1')

        for entity in requirement.get('entities', []):
            entity_name = entity['name_en']
            plural = self._pluralize(entity_name)
            # Schema names must match the component names used for models,
            # i.e. the capitalised model name (see _design_models).
            schema = entity_name.capitalize()
            collection_path = f'/api/{version}/{plural}'
            item_path = f'{collection_path}/{{id}}'

            if 'list' in operations:
                endpoints.append({
                    'path': collection_path,
                    'method': 'GET',
                    'summary': f'Get all {plural}',
                    'description': f'Retrieve a list of {plural}',
                    'parameters': [
                        {'name': 'page', 'type': 'integer', 'in': 'query'},
                        {'name': 'limit', 'type': 'integer', 'in': 'query'}
                    ],
                    'responses': {
                        '200': {'description': 'Success', 'schema': f'{schema}List'}
                    }
                })

            if 'create' in operations:
                endpoints.append({
                    'path': collection_path,
                    'method': 'POST',
                    'summary': f'Create a new {entity_name}',
                    'description': f'Create a new {entity_name}',
                    'request_body': {'schema': schema},
                    'responses': {
                        '201': {'description': 'Created', 'schema': schema}
                    }
                })

            if 'read' in operations:
                endpoints.append({
                    'path': item_path,
                    'method': 'GET',
                    'summary': f'Get a {entity_name} by ID',
                    'description': f'Retrieve a single {entity_name}',
                    'parameters': [
                        {'name': 'id', 'type': 'integer', 'in': 'path', 'required': True}
                    ],
                    'responses': {
                        '200': {'description': 'Success', 'schema': schema},
                        '404': {'description': 'Not found'}
                    }
                })

            if 'update' in operations:
                endpoints.append({
                    'path': item_path,
                    'method': 'PUT',
                    'summary': f'Update a {entity_name}',
                    'description': f'Update an existing {entity_name}',
                    'parameters': [
                        {'name': 'id', 'type': 'integer', 'in': 'path', 'required': True}
                    ],
                    'request_body': {'schema': schema},
                    'responses': {
                        '200': {'description': 'Success', 'schema': schema},
                        '404': {'description': 'Not found'}
                    }
                })

            if 'delete' in operations:
                endpoints.append({
                    'path': item_path,
                    'method': 'DELETE',
                    'summary': f'Delete a {entity_name}',
                    'description': f'Delete an existing {entity_name}',
                    'parameters': [
                        {'name': 'id', 'type': 'integer', 'in': 'path', 'required': True}
                    ],
                    'responses': {
                        '204': {'description': 'Deleted'},
                        '404': {'description': 'Not found'}
                    }
                })

        return endpoints

    def _design_models(self, requirement: Dict) -> List[Dict[str, Any]]:
        """Create one model per entity: capitalised class name, plural table name."""
        return [
            {
                'name': entity['name_en'].capitalize(),
                'table_name': self._pluralize(entity['name_en']),
                'fields': entity['fields']
            }
            for entity in requirement.get('entities', [])
        ]

    def _design_auth(self, requirement: Dict) -> Dict[str, Any]:
        """Describe JWT login/registration endpoints when auth is required."""
        if requirement.get('auth_required', False):
            return {
                'enabled': True,
                'type': 'JWT',
                'endpoints': [
                    {
                        'path': '/api/auth/login',
                        'method': 'POST',
                        'summary': 'User login'
                    },
                    {
                        'path': '/api/auth/register',
                        'method': 'POST',
                        'summary': 'User registration'
                    }
                ]
            }
        return {'enabled': False}
```
### generators/code_generator.py
```python
import os
from typing import Dict, List, Any
from jinja2 import Template
class CodeGenerator:
    """Generate a runnable Flask + Flask-SQLAlchemy project from an API design.

    The design dict is the one produced by ``APIDesigner.design``. It must
    provide ``models``, ``auth``, ``version`` and ``database``, plus
    ``project_name``/``description``/``endpoints`` for the README.
    """

    def generate(self, api_design: Dict[str, Any], output_dir: str, options: Dict = None) -> List[str]:
        """Write the project files into *output_dir* and return their paths.

        Args:
            api_design: Design dict (see the class docstring).
            output_dir: Existing directory that receives the files.
            options: Optional settings. ``framework`` is read, but only Flask
                output is implemented, so it currently has no effect.

        Returns:
            Paths of every file written, in generation order.
        """
        options = options or {}
        framework = options.get('framework', 'flask')
        files = [
            self._generate_main_app(api_design, output_dir, framework),
            # `db` lives in its own module so app.py, the models and the routes
            # all share one SQLAlchemy instance (see _generate_extensions).
            self._generate_extensions(output_dir),
            self._generate_models(api_design, output_dir),
        ]
        files.extend(self._generate_routes(api_design, output_dir))
        files.append(self._generate_config(api_design, output_dir))
        files.append(self._generate_requirements(api_design, output_dir, framework))
        files.append(self._generate_readme(api_design, output_dir))
        return files

    @staticmethod
    def _write(file_path: str, content: str) -> str:
        """Write *content* to *file_path* as UTF-8 and return the path."""
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)
        return file_path

    @staticmethod
    def _template(source: str):
        """Compile a code template whose {% %} tag lines leave no stray whitespace."""
        return Template(source, trim_blocks=True, lstrip_blocks=True)

    def _generate_main_app(self, design: Dict, output_dir: str, framework: str) -> str:
        """Write app.py: the application factory plus a dev-server entry point."""
        template = self._template('''\
import os

from flask import Flask
from flask_cors import CORS

from config import Config
from extensions import db


def create_app(config_class=Config):
    app = Flask(__name__)
    app.config.from_object(config_class)

    # Initialise extensions
    db.init_app(app)
    CORS(app)

    # Register blueprints (importing the routes also imports the models)
{% for model in models %}
    from routes.{{ model.name.lower() }}_routes import {{ model.name.lower() }}_bp
    app.register_blueprint({{ model.name.lower() }}_bp)
{% endfor %}
{% if auth.enabled %}
    from routes.auth_routes import auth_bp
    app.register_blueprint(auth_bp)
{% endif %}

    # Create database tables
    with app.app_context():
        db.create_all()

    @app.route('/health')
    def health():
        return {'status': 'healthy'}

    return app


if __name__ == '__main__':
    app = create_app()
    # The Werkzeug debugger is opt-in (FLASK_DEBUG=1); bind to loopback by default.
    app.run(
        debug=os.environ.get('FLASK_DEBUG') == '1',
        host=os.environ.get('HOST', '127.0.0.1'),
        port=int(os.environ.get('PORT', '5000')),
    )
''')
        return self._write(os.path.join(output_dir, 'app.py'), template.render(design))

    def _generate_extensions(self, output_dir: str) -> str:
        """Write extensions.py holding the shared SQLAlchemy instance.

        If ``db`` were defined in app.py and imported back with ``from app
        import db``, running ``python app.py`` would load app.py a second time
        (once as ``__main__``, once as ``app``). The models and routes would
        then be bound to a second, uninitialised ``db``.
        """
        content = (
            'from flask_sqlalchemy import SQLAlchemy\n'
            '\n'
            '# Shared SQLAlchemy instance; bound to the app in create_app().\n'
            'db = SQLAlchemy()\n'
        )
        return self._write(os.path.join(output_dir, 'extensions.py'), content)

    def _generate_models(self, design: Dict, output_dir: str) -> str:
        """Write models.py with one db.Model subclass per design model.

        Unknown field types fall back to String(255). ``password`` fields are
        left out of ``to_dict`` so hashes are never returned to clients.
        """
        template = self._template('''\
from datetime import datetime

from extensions import db
{% for model in models %}


class {{ model.name }}(db.Model):
    __tablename__ = '{{ model.table_name }}'

{% for field in model.fields %}
{% if field.type == 'integer' and field.name == 'id' %}
    {{ field.name }} = db.Column(db.Integer, primary_key=True, autoincrement=True)
{% elif field.type == 'integer' %}
    {{ field.name }} = db.Column(db.Integer, nullable={{ not field.required }})
{% elif field.type == 'text' %}
    {{ field.name }} = db.Column(db.Text, nullable={{ not field.required }})
{% elif field.type == 'decimal' %}
    {{ field.name }} = db.Column(db.Numeric(10, 2), nullable={{ not field.required }})
{% elif field.type == 'datetime' %}
    {{ field.name }} = db.Column(db.DateTime, default=datetime.utcnow)
{% else %}
    {{ field.name }} = db.Column(db.String(255), nullable={{ not field.required }})
{% endif %}
{% endfor %}

    def to_dict(self):
        # Password hashes are never serialised back to clients.
        return {
{% for field in model.fields if field.name != 'password' %}
            '{{ field.name }}': str(self.{{ field.name }}) if isinstance(self.{{ field.name }}, datetime) else self.{{ field.name }},
{% endfor %}
        }

    def __repr__(self):
        return f'<{{ model.name }} {self.id}>'
{% endfor %}
''')
        return self._write(os.path.join(output_dir, 'models.py'), template.render(design))

    def _generate_routes(self, design: Dict, output_dir: str) -> List[str]:
        """Write routes/<model>_routes.py for each model, plus auth routes if enabled."""
        routes_dir = os.path.join(output_dir, 'routes')
        os.makedirs(routes_dir, exist_ok=True)

        files = [self._generate_model_routes(model, design, routes_dir) for model in design['models']]
        if design['auth']['enabled']:
            files.append(self._generate_auth_routes(design, routes_dir))
        return files

    def _generate_model_routes(self, model: Dict, design: Dict, routes_dir: str) -> str:
        """Write the CRUD blueprint for one model.

        Writable fields are every field except ``id`` and ``created_at``. A
        ``password`` field is stored hashed, never in plaintext.
        """
        template = self._template('''\
from flask import Blueprint, request, jsonify
{% if has_password %}
from werkzeug.security import generate_password_hash
{% endif %}

from extensions import db
from models import {{ model.name }}

{{ bp }} = Blueprint('{{ name }}', __name__, url_prefix='/api/{{ version }}/{{ model.table_name }}')


@{{ bp }}.route('', methods=['GET'])
def get_{{ model.table_name }}():
    """List {{ model.name }} records, paginated with ?page= and ?limit=."""
    page = request.args.get('page', 1, type=int)
    limit = request.args.get('limit', 10, type=int)
    query = {{ model.name }}.query.paginate(page=page, per_page=limit, error_out=False)
    return jsonify({
        'data': [item.to_dict() for item in query.items],
        'total': query.total,
        'page': page,
        'pages': query.pages
    })


@{{ bp }}.route('', methods=['POST'])
def create_{{ name }}():
    """Create a {{ model.name }} from the JSON body."""
    data = request.get_json(silent=True) or {}
    try:
        item = {{ model.name }}(
{% for field in writable %}
{% if field.name == 'password' %}
            password=generate_password_hash(data['password']) if data.get('password') else None,
{% else %}
            {{ field.name }}=data.get('{{ field.name }}'),
{% endif %}
{% endfor %}
        )
        db.session.add(item)
        db.session.commit()
        return jsonify(item.to_dict()), 201
    except Exception as e:
        db.session.rollback()
        return jsonify({'error': str(e)}), 400


@{{ bp }}.route('/<int:id>', methods=['GET'])
def get_{{ name }}(id):
    """Fetch a single {{ model.name }} (404 if missing)."""
    item = {{ model.name }}.query.get_or_404(id)
    return jsonify(item.to_dict())


@{{ bp }}.route('/<int:id>', methods=['PUT'])
def update_{{ name }}(id):
    """Update a {{ model.name }}; only keys present in the body are changed."""
    item = {{ model.name }}.query.get_or_404(id)
    data = request.get_json(silent=True) or {}
    try:
{% for field in writable %}
        if '{{ field.name }}' in data:
{% if field.name == 'password' %}
            item.password = generate_password_hash(data['password'])
{% else %}
            item.{{ field.name }} = data['{{ field.name }}']
{% endif %}
{% endfor %}
        db.session.commit()
        return jsonify(item.to_dict())
    except Exception as e:
        db.session.rollback()
        return jsonify({'error': str(e)}), 400


@{{ bp }}.route('/<int:id>', methods=['DELETE'])
def delete_{{ name }}(id):
    """Delete a {{ model.name }} (404 if missing)."""
    item = {{ model.name }}.query.get_or_404(id)
    try:
        db.session.delete(item)
        db.session.commit()
        return '', 204
    except Exception as e:
        db.session.rollback()
        return jsonify({'error': str(e)}), 400
''')
        name = model['name'].lower()
        writable = [f for f in model['fields'] if f['name'] not in ('id', 'created_at')]
        content = template.render(
            model=model,
            version=design['version'],
            name=name,
            bp=f'{name}_bp',
            writable=writable,
            has_password=any(f['name'] == 'password' for f in writable),
        )
        return self._write(os.path.join(routes_dir, f"{name}_routes.py"), content)

    def _generate_auth_routes(self, design: Dict, routes_dir: str) -> str:
        """Write routes/auth_routes.py with JWT register/login endpoints.

        The generated code imports a ``User`` model with ``username``,
        ``email`` and ``password`` columns. Tokens are signed with the app's
        ``SECRET_KEY`` from config.py rather than a second hardcoded secret.
        """
        content = '''\
import datetime

import jwt
from flask import Blueprint, current_app, request, jsonify
from werkzeug.security import generate_password_hash, check_password_hash

from extensions import db
from models import User

auth_bp = Blueprint('auth', __name__, url_prefix='/api/auth')


@auth_bp.route('/register', methods=['POST'])
def register():
    """Register a user; the password is stored as a salted hash."""
    data = request.get_json(silent=True) or {}
    username = data.get('username')
    password = data.get('password')
    if not username or not password:
        return jsonify({'error': 'Username and password are required'}), 400
    if User.query.filter_by(username=username).first():
        return jsonify({'error': 'Username already exists'}), 400
    user = User(
        username=username,
        email=data.get('email'),
        password=generate_password_hash(password)
    )
    db.session.add(user)
    db.session.commit()
    return jsonify({'message': 'User created successfully'}), 201


@auth_bp.route('/login', methods=['POST'])
def login():
    """Exchange valid credentials for a JWT that expires after 24 hours."""
    data = request.get_json(silent=True) or {}
    user = User.query.filter_by(username=data.get('username')).first()
    if not user or not check_password_hash(user.password, data.get('password') or ''):
        return jsonify({'error': 'Invalid credentials'}), 401
    token = jwt.encode({
        'user_id': user.id,
        'exp': datetime.datetime.utcnow() + datetime.timedelta(hours=24)
    }, current_app.config['SECRET_KEY'], algorithm='HS256')
    return jsonify({'token': token})
'''
        return self._write(os.path.join(routes_dir, 'auth_routes.py'), content)

    def _generate_config(self, design: Dict, output_dir: str) -> str:
        """Write config.py; SECRET_KEY and DATABASE_URL can be overridden via env vars."""
        db_uri = self._get_database_uri(design['database'])
        content = f'''\
import os


class Config:
    SECRET_KEY = os.environ.get('SECRET_KEY') or 'dev-secret-key-change-in-production'
    SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL') or '{db_uri}'
    SQLALCHEMY_TRACK_MODIFICATIONS = False
'''
        return self._write(os.path.join(output_dir, 'config.py'), content)

    def _get_database_uri(self, db_type: str) -> str:
        """Return a placeholder connection URI for *db_type* (SQLite if unknown)."""
        uris = {
            'sqlite': 'sqlite:///app.db',
            'mysql': 'mysql+pymysql://user:password@localhost/dbname',
            'postgresql': 'postgresql://user:password@localhost/dbname',
            # NOTE(review): the generated project uses Flask-SQLAlchemy, which
            # presumably cannot connect to MongoDB; confirm before relying on it.
            'mongodb': 'mongodb://localhost:27017/dbname'
        }
        return uris.get(db_type, uris['sqlite'])

    def _generate_requirements(self, design: Dict, output_dir: str, framework: str) -> str:
        """Write requirements.txt, adding a DB driver for MySQL/PostgreSQL."""
        requirements = [
            'flask==3.0.0',
            'flask-cors==4.0.0',
            'flask-sqlalchemy==3.1.1',
            'PyJWT==2.8.0',
            # The generated tests use pytest; `pythonpath` in pytest.ini needs >= 7.
            'pytest>=7.0',
        ]
        if design['database'] == 'mysql':
            requirements.append('pymysql==1.1.0')
        elif design['database'] == 'postgresql':
            requirements.append('psycopg2-binary==2.9.9')

        content = '\n'.join(requirements) + '\n'
        return self._write(os.path.join(output_dir, 'requirements.txt'), content)

    def _generate_readme(self, design: Dict, output_dir: str) -> str:
        """Write README.md with setup steps, endpoints, database and model fields."""
        template = self._template('''\
# {{ project_name }}

{{ description }}

## 安装

```bash
pip install -r requirements.txt
```

## 运行

```bash
python app.py
```

服务将在 http://localhost:5000 启动

## 测试

```bash
pytest
```

## API端点
{% for endpoint in endpoints %}

### {{ endpoint.method }} {{ endpoint.path }}

{{ endpoint.description }}
{% endfor %}

## 数据库

使用 {{ database }} 数据库

## 模型
{% for model in models %}

### {{ model.name }}

字段:
{% for field in model.fields %}
- {{ field.name }} ({{ field.type }})
{% endfor %}
{% endfor %}
''')
        return self._write(os.path.join(output_dir, 'README.md'), template.render(design))
```
### generators/doc_generator.py
```python
import os
import yaml
from typing import Dict, Any
class DocGenerator:
    """Generate API documentation (OpenAPI 3.0 YAML and Markdown) from a design."""

    def generate(self, api_design: Dict[str, Any], output_dir: str) -> list:
        """Write openapi.yaml and API_DOCUMENTATION.md into *output_dir*; return their paths."""
        return [
            self._generate_openapi_spec(api_design, output_dir),
            self._generate_markdown_doc(api_design, output_dir),
        ]

    def _generate_openapi_spec(self, design: Dict, output_dir: str) -> str:
        """Build an OpenAPI 3.0 document and save it as openapi.yaml."""
        spec = {
            'openapi': '3.0.0',
            'info': {
                'title': design['project_name'],
                'description': design['description'],
                'version': design['version']
            },
            'servers': [
                {'url': 'http://localhost:5000', 'description': 'Development server'}
            ],
            'paths': {},
            'components': {
                'schemas': {},
                'securitySchemes': {}
            }
        }

        if design['auth']['enabled']:
            spec['components']['securitySchemes']['bearerAuth'] = {
                'type': 'http',
                'scheme': 'bearer',
                'bearerFormat': 'JWT'
            }

        schemas = spec['components']['schemas']
        for model in design['models']:
            schemas[model['name']] = self._create_schema(model)
            # List endpoints reference '<Model>List': the paginated envelope.
            schemas[f"{model['name']}List"] = self._create_list_schema(model['name'])

        for endpoint in design['endpoints']:
            method = endpoint['method'].lower()
            spec['paths'].setdefault(endpoint['path'], {})[method] = self._create_path_item(endpoint, design)

        file_path = os.path.join(output_dir, 'openapi.yaml')
        with open(file_path, 'w', encoding='utf-8') as f:
            yaml.dump(spec, f, sort_keys=False, allow_unicode=True)
        return file_path

    def _create_schema(self, model: Dict) -> Dict:
        """Create an object schema from a model's fields.

        ``required`` is omitted when empty: OpenAPI 3.0 requires a non-empty
        array there.
        """
        properties = {}
        required = []
        for field in model['fields']:
            properties[field['name']] = self._field_to_openapi_type(field)
            if field.get('required'):
                required.append(field['name'])

        schema = {'type': 'object', 'properties': properties}
        if required:
            schema['required'] = required
        return schema

    @staticmethod
    def _create_list_schema(model_name: str) -> Dict:
        """Schema for the paginated list envelope: data/total/page/pages."""
        return {
            'type': 'object',
            'properties': {
                'data': {
                    'type': 'array',
                    'items': {'$ref': f'#/components/schemas/{model_name}'}
                },
                'total': {'type': 'integer'},
                'page': {'type': 'integer'},
                'pages': {'type': 'integer'}
            }
        }

    @staticmethod
    def _schema_ref(name: str, design: Dict) -> Dict:
        """Build a ``$ref`` to *name*, matched case-insensitively against model schemas.

        Designs may spell schema names in a different case from the registered
        model names (e.g. 'user' vs 'User'). Unknown names are used verbatim.
        """
        known = []
        for model in design.get('models', []):
            known.extend((model['name'], f"{model['name']}List"))
        if name not in known:
            lowered = name.lower()
            name = next((k for k in known if k.lower() == lowered), name)
        return {'$ref': f'#/components/schemas/{name}'}

    def _field_to_openapi_type(self, field: Dict) -> Dict:
        """Map a design field type onto an OpenAPI type (default: string)."""
        type_mapping = {
            'integer': {'type': 'integer'},
            'string': {'type': 'string'},
            'text': {'type': 'string'},
            'decimal': {'type': 'number', 'format': 'double'},
            'datetime': {'type': 'string', 'format': 'date-time'}
        }
        return type_mapping.get(field['type'], {'type': 'string'})

    def _create_path_item(self, endpoint: Dict, design: Dict) -> Dict:
        """Build the OpenAPI operation object for one design endpoint."""
        path_item = {
            'summary': endpoint['summary'],
            'description': endpoint.get('description', ''),
            'responses': {}
        }

        if 'parameters' in endpoint:
            path_item['parameters'] = [
                {
                    'name': param['name'],
                    'in': param['in'],
                    'required': param.get('required', False),
                    'schema': {'type': param['type']}
                }
                for param in endpoint['parameters']
            ]

        if 'request_body' in endpoint:
            path_item['requestBody'] = {
                'required': True,
                'content': {
                    'application/json': {
                        'schema': self._schema_ref(endpoint['request_body']['schema'], design)
                    }
                }
            }

        for status_code, response in endpoint['responses'].items():
            entry = {'description': response['description']}
            if 'schema' in response:
                entry['content'] = {
                    'application/json': {
                        'schema': self._schema_ref(response['schema'], design)
                    }
                }
            path_item['responses'][status_code] = entry

        # Every non-auth endpoint is documented as requiring a bearer token.
        if design['auth']['enabled'] and endpoint['path'] not in ('/api/auth/login', '/api/auth/register'):
            path_item['security'] = [{'bearerAuth': []}]

        return path_item

    def _generate_markdown_doc(self, design: Dict, output_dir: str) -> str:
        """Write a human-readable Markdown API reference (API_DOCUMENTATION.md)."""
        lines = [
            f"# {design['project_name']} API Documentation\n",
            f"{design['description']}\n",
            f"**Version:** {design['version']}\n",
            "## Authentication\n"
        ]

        if design['auth']['enabled']:
            lines.append("This API uses JWT Bearer token authentication.\n")
            lines.append("Include the token in the Authorization header: `Authorization: Bearer <token>`\n")
        else:
            lines.append("No authentication required.\n")

        lines.append("\n## Endpoints\n")
        for endpoint in design['endpoints']:
            lines.append(f"\n### {endpoint['method']} {endpoint['path']}\n")
            lines.append(f"{endpoint.get('description', '')}\n")
            if 'parameters' in endpoint:
                lines.append("\n**Parameters:**\n")
                for param in endpoint['parameters']:
                    lines.append(f"- `{param['name']}` ({param['type']}) - {param['in']}\n")
            if 'request_body' in endpoint:
                lines.append("\n**Request Body:** JSON\n")
            lines.append("\n**Responses:**\n")
            for status, response in endpoint['responses'].items():
                lines.append(f"- `{status}`: {response['description']}\n")

        lines.append("\n## Models\n")
        for model in design['models']:
            lines.append(f"\n### {model['name']}\n")
            lines.append("\n| Field | Type | Required |\n")
            lines.append("|-------|------|----------|\n")
            for field in model['fields']:
                required = '✓' if field.get('required') else ''
                lines.append(f"| {field['name']} | {field['type']} | {required} |\n")

        file_path = os.path.join(output_dir, 'API_DOCUMENTATION.md')
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(''.join(lines))
        return file_path
```
### generators/test_generator.py
```python
import os
from typing import Dict, Any
from jinja2 import Template
class TestGenerator:
"""测试用例生成器"""
def generate(self, api_design: Dict[str, Any], output_dir: str) -> list:
"""生成测试用例"""
files = []
# 生成pytest配置
files.append(self._generate_pytest_config(api_design, output_dir))
# 生成测试文件
test_dir = os.path.join(output_dir, 'tests')
os.makedirs(test_dir, exist_ok=True)
# 生成基础测试文件
files.append(self._generate_test_base(api_design, test_dir))
# 为每个模型生成测试
for model in api_design['models']:
files.append(self._generate_model_tests(model, api_design, test_dir))
# 如果有认证,生成认证测试
if api_design['auth']['enabled']:
files.append(self._generate_auth_tests(api_design, test_dir))
return files
def _generate_pytest_config(self, design: Dict, output_dir: str) -> str:
"""生成pytest配置"""
content = '''
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
'''
file_path = os.path.join(output_dir, 'pytest.ini')
with open(file_path, 'w', encoding='utf-8') as f:
f.write(content)
return file_path
def _generate_test_base(self, design: Dict, test_dir: str) -> str:
"""生成测试基类"""
content = '''
import pytest
from app import create_app, db
@pytest.fixture
def app():
"""创建测试应用"""
app = create_app()
app.config['TESTING'] = True
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///:memory:'
with app.app_context():
db.create_all()
yield app
db.session.remove()
db.drop_all()
@pytest.fixture
def client(app):
"""创建测试客户端"""
return app.test_client()
@pytest.fixture
def runner(app):
"""创建CLI runner"""
return app.test_cli_runner()
'''
file_path = os.path.join(test_dir, 'conftest.py')
with open(file_path, 'w', encoding='utf-8') as f:
f.write(content)
return file_path
def _generate_model_tests(self, model: Dict, design: Dict, test_dir: str) -> str:
"""生成模型测试"""
template = Template('''
import pytest
import json
from models import {{ model.name }}
class Test{{ model.name }}API:
"""{{ model.name }} API测试"""
base_url = '/api/{{ version }}/{{ model.table_name }}'
def test_create_{{ model.name.lower() }}(self, client):
"""测试创建{{ model.name }}"""
data = {
{% for field in model.fields if field.name != 'id' and field.name != 'created_at' %}
'{{ field.name }}': self._get_test_value('{{ field.type }}'),
{% endfor %}
}
response = client.post(
self.base_url,
data=json.dumps(data),
content_type='application/json'
)
assert response.status_code == 201
result = json.loads(response.data)
assert 'id' in result
def test_get_{{ model.table_name }}(self, client):
"""测试获取{{ model.name }}列表"""
response = client.get(self.base_url)
assert response.status_code == 200
result = json.loads(response.data)
assert 'data' in result
assert isinstance(result['data'], list)
def test_get_{{ model.name.lower() }}_by_id(self, client):
"""测试根据ID获取{{ model.name }}"""
# 先创建一个
data = {
{% for field in model.fields if field.name != 'id' and field.name != 'created_at' %}
'{{ field.name }}': self._get_test_value('{{ field.type }}'),
{% endfor %}
}
create_response = client.post(
self.base_url,
data=json.dumps(data),
content_type='application/json'
)
created = json.loads(create_response.data)
# 获取
response = client.get(f"{self.base_url}/{created['id']}")
assert response.status_code == 200
result = json.loads(response.data)
assert result['id'] == created['id']
def test_update_{{ model.name.lower() }}(self, client):
"""测试更新{{ model.name }}"""
# 先创建一个
data = {
{% for field in model.fields if field.name != 'id' and field.name != 'created_at' %}
'{{ field.name }}': self._get_test_value('{{ field.type }}'),
{% endfor %}
}
create_response = client.post(
self.base_url,
data=json.dumps(data),
content_type='application/json'
)
created = json.loads(create_response.data)
# 更新
update_data = data.copy()
{% for field in model.fields if field.name != 'id' and field.name != 'created_at' %}
{% if loop.first %}
update_data['{{ field.name }}'] = self._get_test_value('{{ field.type }}', different=True)
{% endif %}
{% endfor %}
response = client.put(
f"{self.base_url}/{created['id']}",
data=json.dumps(update_data),
content_type='application/json'
)
assert response.status_code == 200
def test_delete_{{ model.name.lower() }}(self, client):
"""测试删除{{ model.name }}"""
# 先创建一个
data = {
{% for field in model.fields if field.name != 'id' and field.name != 'created_at' %}
'{{ field.name }}': self._get_test_value('{{ field.type }}'),
{% endfor %}
}
create_response = client.post(
self.base_url,
data=json.dumps(data),
content_type='application/json'
)
created = json.loads(create_response.data)
# 删除
response = client.delete(f"{self.base_url}/{created['id']}")
assert response.status_code == 204
# 验证已删除
get_response = client.get(f"{self.base_url}/{created['id']}")
assert get_response.status_code == 404
def _get_test_value(self, field_type, different=False):
"""获取测试值"""
values = {
'integer': 123 if not different else 456,
'string': 'test' if not different else 'updated',
'text': 'test description' if not different else 'updated description',
'decimal': 99.99 if not different else 199.99,
}
return values.get(field_type, 'test')
''')
content = template.render(model=model, version=design['version'])
file_path = os.path.join(test_dir, f"test_{model['name'].lower()}.py")
with open(file_path, 'w', encoding='utf-8') as f:
f.write(content)
return file_path
def _generate_auth_tests(self, design: Dict, test_dir: str) -> str:
"""生成认证测试"""
content = '''
import pytest
import json
class TestAuthAPI:
"""认证API测试"""
def test_register(self, client):
"""测试用户注册"""
data = {
'username': 'testuser',
'email': 'test@example.com',
'password': 'password123'
}
response = client.post(
'/api/auth/register',
data=json.dumps(data),
content_type='application/json'
)
assert response.status_code == 201
result = json.loads(response.data)
assert 'message' in result
def test_login(self, client):
"""测试用户登录"""
# 先注册
register_data = {
'username': 'testuser',
'email': 'test@example.com',
'password': 'password123'
}
client.post(
'/api/auth/register',
data=json.dumps(register_data),
content_type='application/json'
)
# 登录
login_data = {
'username': 'testuser',
'password': 'password123'
}
response = client.post(
'/api/auth/login',
data=json.dumps(login_data),
content_type='application/json'
)
assert response.status_code == 200
result = json.loads(response.data)
assert 'token' in result
def test_login_invalid_credentials(self, client):
"""测试无效凭据登录"""
data = {
'username': 'nonexistent',
'password': 'wrongpassword'
}
response = client.post(
'/api/auth/login',
data=json.dumps(data),
content_type='application/json'
)
assert response.status_code == 401
'''
file_path = os.path.join(test_dir, 'test_auth.py')
with open(file_path, 'w', encoding='utf-8') as f:
f.write(content)
return file_path
```
## 2. 前端界面
### frontend/index.html
```html
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>RESTful API 自动生成器</title>
    <link rel="stylesheet" href="static/css/style.css">
</head>
<body>
    <div class="container">
        <header>
            <h1>🚀 RESTful API 自动生成器</h1>
            <p>基于AI的智能API代码生成工具</p>
        </header>
        <!-- Three-step wizard; only the .step with class "active" is shown (goToStep in app.js). -->
        <div class="wizard">
            <!-- Step 1: 输入需求 -->
            <div class="step active" id="step1">
                <h2>步骤 1: 描述你的API需求</h2>
                <div class="form-group">
                    <label for="requirement">需求描述</label>
                    <textarea
                        id="requirement"
                        rows="10"
                        placeholder="例如:我需要一个用户管理系统,包含用户的创建、查询、更新和删除功能。用户包含用户名、邮箱、密码等字段。需要JWT认证。使用MySQL数据库。"
                    ></textarea>
                    <div class="examples">
                        <p>示例需求:</p>
                        <button class="example-btn" onclick="fillExample(1)">📝 用户管理系统</button>
                        <button class="example-btn" onclick="fillExample(2)">🛒 商品管理系统</button>
                        <button class="example-btn" onclick="fillExample(3)">📚 博客系统</button>
                    </div>
                </div>
                <button class="btn btn-primary" onclick="parseRequirement()">
                    解析需求 →
                </button>
            </div>
            <!-- Step 2: 确认API设计 -->
            <div class="step" id="step2">
                <h2>步骤 2: 确认API设计</h2>
                <!-- Filled by displayAPIDesign() in app.js. -->
                <div id="apiDesignPreview"></div>
                <div class="button-group">
                    <button class="btn btn-secondary" onclick="goToStep(1)">← 返回</button>
                    <button class="btn btn-primary" onclick="generateCode()">
                        生成代码 →
                    </button>
                </div>
            </div>
            <!-- Step 3: 下载生成的代码 -->
            <div class="step" id="step3">
                <h2>步骤 3: 下载生成的代码</h2>
                <div class="success-message">
                    <!-- Decorative icon; the heading carries the meaning. -->
                    <span class="icon" aria-hidden="true">✅</span>
                    <h3>API代码生成成功!</h3>
                </div>
                <!-- Filled by displayGeneratedFiles() in app.js. -->
                <div id="generatedFiles"></div>
                <div class="button-group">
                    <button class="btn btn-secondary" onclick="location.reload()">
                        重新开始
                    </button>
                    <button class="btn btn-success" id="downloadBtn" onclick="downloadProject()">
                        📥 下载完整项目
                    </button>
                </div>
            </div>
        </div>
        <!-- Loading overlay: toggled via the "active" class; announced as a status region. -->
        <div class="loading-overlay" id="loadingOverlay" role="status" aria-live="polite">
            <div class="spinner" aria-hidden="true"></div>
            <p id="loadingText">处理中...</p>
        </div>
        <!-- Toast notification: live region so success/error messages are read out by screen readers. -->
        <div class="toast" id="toast" role="status" aria-live="polite" aria-atomic="true"></div>
    </div>
    <script src="static/js/app.js"></script>
</body>
</html>
```
### frontend/static/css/style.css
```css
/* Global reset: drop default margins/padding and use border-box sizing everywhere. */
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
/* Page background: purple diagonal gradient filling at least the viewport height. */
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
/* Centered content column. */
.container {
max-width: 1200px;
margin: 0 auto;
}
header {
text-align: center;
color: white;
margin-bottom: 40px;
}
header h1 {
font-size: 2.5rem;
margin-bottom: 10px;
text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
}
header p {
font-size: 1.1rem;
opacity: 0.9;
}
/* White card that hosts the three wizard steps. */
.wizard {
background: white;
border-radius: 16px;
box-shadow: 0 20px 60px rgba(0,0,0,0.3);
padding: 40px;
min-height: 500px;
}
/* Steps are hidden by default; goToStep() in app.js adds .active to the visible one. */
.step {
display: none;
}
.step.active {
display: block;
animation: fadeIn 0.5s;
}
/* Slide-up fade used when a step becomes active. */
@keyframes fadeIn {
from {
opacity: 0;
transform: translateY(20px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
h2 {
color: #333;
margin-bottom: 30px;
font-size: 1.8rem;
}
/* Requirement input area (step 1). */
.form-group {
margin-bottom: 30px;
}
label {
display: block;
margin-bottom: 10px;
font-weight: 600;
color: #555;
}
textarea {
width: 100%;
padding: 15px;
border: 2px solid #e0e0e0;
border-radius: 8px;
font-size: 1rem;
font-family: inherit;
transition: border-color 0.3s;
resize: vertical;
}
/* The default focus outline is replaced by a colored border. */
textarea:focus {
outline: none;
border-color: #667eea;
}
/* Row of sample-requirement buttons (they call fillExample()). */
.examples {
margin-top: 15px;
}
.examples p {
margin-bottom: 10px;
color: #666;
font-size: 0.9rem;
}
.example-btn {
background: #f5f5f5;
border: none;
padding: 8px 16px;
border-radius: 6px;
margin-right: 10px;
cursor: pointer;
font-size: 0.9rem;
transition: all 0.3s;
}
.example-btn:hover {
background: #667eea;
color: white;
transform: translateY(-2px);
}
/* Base button; color variants follow (primary / secondary / success). */
.btn {
padding: 14px 32px;
border: none;
border-radius: 8px;
font-size: 1rem;
font-weight: 600;
cursor: pointer;
transition: all 0.3s;
box-shadow: 0 4px 12px rgba(0,0,0,0.15);
}
.btn-primary {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
}
.btn-primary:hover {
transform: translateY(-2px);
box-shadow: 0 6px 20px rgba(102, 126, 234, 0.4);
}
.btn-secondary {
background: #e0e0e0;
color: #333;
}
.btn-secondary:hover {
background: #d0d0d0;
}
.btn-success {
background: #10b981;
color: white;
}
.btn-success:hover {
background: #059669;
transform: translateY(-2px);
box-shadow: 0 6px 20px rgba(16, 185, 129, 0.4);
}
/* Horizontal row of navigation buttons at the bottom of a step (stacked on mobile, see media query). */
.button-group {
display: flex;
gap: 15px;
margin-top: 30px;
}
/* Full-screen loading overlay; hidden until showLoading() adds .active (hideLoading() removes it). */
.loading-overlay {
display: none;
position: fixed;
top: 0;
left: 0;
right: 0;
bottom: 0;
background: rgba(0, 0, 0, 0.7);
z-index: 1000;
justify-content: center;
align-items: center;
flex-direction: column;
}
/* Flex (not block) so the spinner and text are centered via the properties above. */
.loading-overlay.active {
display: flex;
}
/* Ring spinner: one opaque border segment rotated continuously. */
.spinner {
width: 60px;
height: 60px;
border: 4px solid rgba(255, 255, 255, 0.3);
border-top-color: white;
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
to { transform: rotate(360deg); }
}
.loading-overlay p {
color: white;
margin-top: 20px;
font-size: 1.1rem;
}
/* Green banner shown at the top of step 3. */
.success-message {
text-align: center;
padding: 30px;
background: linear-gradient(135deg, #10b981 0%, #059669 100%);
border-radius: 12px;
color: white;
margin-bottom: 30px;
}
.success-message .icon {
font-size: 4rem;
display: block;
margin-bottom: 15px;
}
.success-message h3 {
color: white;
font-size: 1.5rem;
}
/* Sections rendered by displayAPIDesign() in step 2 (project info, models, endpoints). */
.preview-section {
margin-bottom: 30px;
padding: 20px;
background: #f9fafb;
border-radius: 8px;
border-left: 4px solid #667eea;
}
.preview-section h3 {
color: #667eea;
margin-bottom: 15px;
font-size: 1.2rem;
}
/* One card per data model. */
.model-card {
background: white;
padding: 15px;
border-radius: 6px;
margin-bottom: 15px;
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
}
.model-card h4 {
color: #333;
margin-bottom: 10px;
}
.field-list {
list-style: none;
}
.field-list li {
padding: 5px 0;
color: #666;
font-size: 0.9rem;
}
/* One row per endpoint: method badge followed by the path. */
.endpoint-card {
background: white;
padding: 15px;
border-radius: 6px;
margin-bottom: 10px;
display: flex;
align-items: center;
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
}
.method-badge {
padding: 4px 12px;
border-radius: 4px;
font-weight: 600;
font-size: 0.8rem;
margin-right: 15px;
}
/* Selected by the class `method-${endpoint.method}` built in displayAPIDesign(); other verbs (e.g. PATCH) get no color. */
.method-GET { background: #10b981; color: white; }
.method-POST { background: #3b82f6; color: white; }
.method-PUT { background: #f59e0b; color: white; }
.method-DELETE { background: #ef4444; color: white; }
/* Dark, monospace panel listing the generated project's files (step 3). */
.file-tree {
background: #1e293b;
color: #e2e8f0;
padding: 20px;
border-radius: 8px;
font-family: 'Courier New', monospace;
font-size: 0.9rem;
line-height: 1.6;
}
/* NOTE(review): .folder / .file are not referenced by the visible app.js; possibly unused. */
.file-tree .folder {
color: #fbbf24;
}
.file-tree .file {
color: #60a5fa;
}
/* Transient notification shown by showToast() in app.js. It stays in the
   layout at all times and is only faded in/out via opacity and the .show class. */
.toast {
    position: fixed;
    bottom: 30px;
    right: 30px;
    background: #333;
    color: white;
    padding: 15px 25px;
    border-radius: 8px;
    box-shadow: 0 4px 12px rgba(0,0,0,0.3);
    opacity: 0;
    transform: translateY(20px);
    transition: all 0.3s;
    z-index: 1001;
    /* When hidden, the toast keeps its box (and its last message text), so
       without this it would invisibly swallow clicks on whatever is beneath it. */
    pointer-events: none;
}
.toast.show {
    opacity: 1;
    transform: translateY(0);
    pointer-events: auto;
}
/* Color variants selected by showToast's `type` argument. */
.toast.success {
    background: #10b981;
}
.toast.error {
    background: #ef4444;
}
/* Small screens: shrink the title and wizard padding, and stack full-width buttons vertically. */
@media (max-width: 768px) {
header h1 {
font-size: 1.8rem;
}
.wizard {
padding: 20px;
}
.button-group {
flex-direction: column;
}
.btn {
width: 100%;
}
}
```
### frontend/static/js/app.js
```javascript
// Base URL of the Flask backend. The frontend may be served separately
// (e.g. `python -m http.server 8000`), so the backend origin is explicit.
const API_BASE_URL = 'http://localhost:5000/api';

// Wizard state carried between steps.
let parsedRequirement = null,   // `data` returned by /parse-requirement
    apiDesign = null,           // `data` returned by /design-api
    downloadUrl = null;         // `download_url` returned by /generate

// Sample requirements, keyed by the id passed to fillExample().
const examples = {
    1: '我需要一个用户管理系统,包含用户的创建、查询、更新和删除功能。用户包含ID、用户名、邮箱、密码、创建时间等字段。需要JWT认证。使用SQLite数据库。',
    2: '创建一个商品管理系统,支持商品的增删改查。商品包含ID、名称、描述、价格、库存、创建时间等字段。使用MySQL数据库。',
    3: '开发一个博客系统API,包含文章管理功能。文章包含标题、内容、作者、发布时间等字段。需要用户认证。使用PostgreSQL数据库。'
};
// 填充示例
// Copy the sample requirement with the given id into the requirement textarea.
function fillExample(exampleId) {
    const textarea = document.getElementById('requirement');
    textarea.value = examples[exampleId];
}
// 显示加载状态
// Show the full-screen loading overlay with the given status text.
function showLoading(text = '处理中...') {
    const label = document.getElementById('loadingText');
    const overlay = document.getElementById('loadingOverlay');
    label.textContent = text;
    overlay.classList.add('active');
}
// 隐藏加载状态
// Hide the loading overlay; a no-op when it is already hidden.
function hideLoading() {
    const overlay = document.getElementById('loadingOverlay');
    overlay.classList.remove('active');
}
// 显示Toast消息
// Show `message` in the #toast element for `duration` ms.
// `type` selects the color variant ('success' by default, or 'error').
// `duration` is an optional third parameter; existing two-argument calls keep the 3 s behaviour.
function showToast(message, type = 'success', duration = 3000) {
    const toast = document.getElementById('toast');
    // Cancel the previous toast's pending hide. Otherwise an earlier timer
    // would hide this newer message before its full duration has elapsed.
    clearTimeout(showToast.hideTimer);
    toast.textContent = message;
    toast.className = `toast ${type} show`;
    showToast.hideTimer = setTimeout(() => {
        toast.classList.remove('show');
    }, duration);
}
// 切换步骤
// Make wizard step `stepNumber` (the element with id "step<N>") the only visible step.
function goToStep(stepNumber) {
    const targetId = `step${stepNumber}`;
    for (const step of document.querySelectorAll('.step')) {
        step.classList.remove('active');
    }
    document.getElementById(targetId).classList.add('active');
}
// 解析需求
// Step 1 handler: send the textarea contents to /parse-requirement, store the
// parsed result, then continue straight into designAPI().
// Any failure is reported via an error toast; the overlay is always hidden afterwards.
async function parseRequirement() {
    const requirement = document.getElementById('requirement').value.trim();
    if (requirement === '') {
        showToast('请输入需求描述', 'error');
        return;
    }

    showLoading('正在解析需求...');
    try {
        const requestInit = {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ requirement })
        };
        const response = await fetch(`${API_BASE_URL}/parse-requirement`, requestInit);
        const result = await response.json();
        if (!result.success) {
            throw new Error(result.error);
        }
        parsedRequirement = result.data;
        await designAPI();
    } catch (error) {
        showToast('解析需求失败: ' + error.message, 'error');
    } finally {
        hideLoading();
    }
}
// 设计API
// Send `parsedRequirement` to /design-api, store the returned design, render
// it and move the wizard to step 2. Failures surface as an error toast.
async function designAPI() {
    showLoading('正在设计API...');
    try {
        const requestInit = {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ parsed_requirement: parsedRequirement })
        };
        const response = await fetch(`${API_BASE_URL}/design-api`, requestInit);
        const result = await response.json();
        if (!result.success) {
            throw new Error(result.error);
        }
        apiDesign = result.data;
        displayAPIDesign();
        goToStep(2);
    } catch (error) {
        showToast('设计API失败: ' + error.message, 'error');
    } finally {
        hideLoading();
    }
}
// 显示API设计
// Render `apiDesign` (project info, data models, endpoints) into #apiDesignPreview.
// Every interpolated value comes from the AI/user-driven design, so it is
// HTML-escaped before it is placed into innerHTML.
function displayAPIDesign() {
    const preview = document.getElementById('apiDesignPreview');

    // Escape text for safe use in HTML content and quoted attribute values.
    const esc = (value) => String(value ?? '')
        .replace(/&/g, '&amp;')
        .replace(/</g, '&lt;')
        .replace(/>/g, '&gt;')
        .replace(/"/g, '&quot;')
        .replace(/'/g, '&#39;');

    const models = apiDesign.models || [];
    const endpoints = apiDesign.endpoints || [];
    const authEnabled = Boolean(apiDesign.auth && apiDesign.auth.enabled);

    const fieldItem = (field) => `
        <li>
            ${esc(field.name)} (${esc(field.type)})
            ${field.required ? '<span style="color: #ef4444">*</span>' : ''}
        </li>`;

    const modelCard = (model) => `
        <div class="model-card">
            <h4>${esc(model.name)}</h4>
            <ul class="field-list">
                ${(model.fields || []).map(fieldItem).join('')}
            </ul>
        </div>`;

    // The badge class `method-<VERB>` picks the verb's color from style.css.
    const endpointCard = (endpoint) => `
        <div class="endpoint-card">
            <span class="method-badge method-${esc(endpoint.method)}">${esc(endpoint.method)}</span>
            <span>${esc(endpoint.path)}</span>
        </div>`;

    preview.innerHTML = `
        <div class="preview-section">
            <h3>📦 项目信息</h3>
            <p><strong>项目名称:</strong> ${esc(apiDesign.project_name)}</p>
            <p><strong>版本:</strong> ${esc(apiDesign.version)}</p>
            <p><strong>数据库:</strong> ${esc(apiDesign.database)}</p>
            <p><strong>认证:</strong> ${authEnabled ? '✅ JWT认证' : '❌ 无需认证'}</p>
        </div>
        <div class="preview-section">
            <h3>📊 数据模型</h3>
            ${models.map(modelCard).join('')}
        </div>
        <div class="preview-section">
            <h3>🔌 API端点</h3>
            ${endpoints.map(endpointCard).join('')}
        </div>
    `;
}
// 生成代码
// Step 2 handler: ask /generate to build a Flask project from `apiDesign`.
// On success, remember the download URL, list the generated files and go to step 3.
async function generateCode() {
    showLoading('正在生成代码...');
    try {
        const payload = {
            api_design: apiDesign,
            options: { framework: 'flask' }
        };
        const response = await fetch(`${API_BASE_URL}/generate`, {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify(payload)
        });
        const result = await response.json();
        if (!result.success) {
            throw new Error(result.error);
        }
        const generated = result.data;
        downloadUrl = generated.download_url;
        displayGeneratedFiles(generated);
        goToStep(3);
        showToast('代码生成成功!');
    } catch (error) {
        showToast('生成代码失败: ' + error.message, 'error');
    } finally {
        hideLoading();
    }
}
// 显示生成的文件
// Render the generated project's file tree and a quick-start guide into #generatedFiles.
// `data` is the /generate response payload. It provides `project_name`, and
// `files.code` / `files.tests`, which are assumed to be lists of file paths
// (the test generator returns the paths it writes -- TODO confirm for code files).
function displayGeneratedFiles(data) {
    const filesDiv = document.getElementById('generatedFiles');

    // Escape text for safe use in HTML content.
    const esc = (value) => String(value ?? '')
        .replace(/&/g, '&amp;')
        .replace(/</g, '&lt;')
        .replace(/>/g, '&gt;')
        .replace(/"/g, '&quot;')
        .replace(/'/g, '&#39;');
    // Reduce a server path (either separator) to its file name.
    const baseName = (path) => String(path).split(/[\\/]/).pop();
    // Prefix entries with tree connectors; the last one gets the closing connector.
    const branch = (names, indent) => names.map((name, i) =>
        `${indent}${i === names.length - 1 ? '└──' : '├──'} 📄 ${name}`);

    const files = data.files || {};
    const projectName = data.project_name;
    // Use the real file names rather than guessing them from apiDesign.models by
    // index: that mismatched entries (e.g. test_auth.py) and printed "undefined".
    const routeFiles = (files.code || [])
        .filter(f => String(f).includes('routes'))
        .map(baseName);
    const testFiles = (files.tests || [])
        .map(baseName)
        .filter(name => name !== 'conftest.py');

    const treeLines = [
        `📁 ${projectName}/`,
        '├── 📄 app.py',
        '├── 📄 models.py',
        '├── 📄 config.py',
        '├── 📄 requirements.txt',
        '├── 📄 README.md',
        '├── 📄 openapi.yaml',
        '├── 📄 API_DOCUMENTATION.md',
        '├── 📄 pytest.ini',
        '├── 📁 routes/',
        ...branch(routeFiles, '│   '),
        '└── 📁 tests/',
        ...branch(['conftest.py', ...testFiles], '    ')
    ];

    const quickStart = [
        '# 1. 解压下载的文件',
        `unzip ${projectName}.zip`,
        `cd ${projectName}`,
        '# 2. 安装依赖',
        'pip install -r requirements.txt',
        '# 3. 运行应用',
        'python app.py',
        '# 4. 运行测试',
        'pytest',
        '# 5. 查看API文档: openapi.yaml / API_DOCUMENTATION.md',
        '# 健康检查: http://localhost:5000/health'
    ];

    // <pre> keeps the tree's leading spaces; in a <div> they would collapse.
    filesDiv.innerHTML = `
        <pre class="file-tree">${esc(treeLines.join('\n'))}</pre>
        <div style="margin-top: 20px; padding: 15px; background: #f0f9ff; border-radius: 8px; border-left: 4px solid #3b82f6;">
            <h4 style="color: #1e40af; margin-bottom: 10px;">📚 快速开始</h4>
            <pre style="background: #1e293b; color: #e2e8f0; padding: 15px; border-radius: 6px; overflow-x: auto;">${esc(quickStart.join('\n'))}</pre>
        </div>
    `;
}
// 下载项目
// Start downloading the generated project archive via `downloadUrl`, which
// is a path returned by /generate and resolved against the backend's root URL.
function downloadProject() {
    if (!downloadUrl) {
        showToast('下载链接不可用,请重新生成代码', 'error');
        return;
    }
    // Strip only a trailing "/api" segment to get the server root.
    // A plain .replace('/api', '') removes the FIRST "/api" anywhere, which
    // breaks hosts such as "http://api.example.com/api".
    const serverRoot = API_BASE_URL.replace(/\/api\/?$/, '');
    window.location.href = serverRoot + downloadUrl;
    showToast('开始下载...');
}
```
## 3. 运行说明
### 安装依赖
```bash
cd backend
pip install -r requirements.txt
```
### 启动后端服务
```bash
python app.py
```
### 访问前端页面
在浏览器中直接打开 `frontend/index.html`(后端已通过 Flask-CORS 允许跨域请求),或在 `frontend` 目录下使用 Python 内置的 HTTP 服务器托管前端:
```bash
cd frontend
python -m http.server 8000
```
然后访问 `http://localhost:8000`
## 4. 功能特点
✅ **智能需求解析** - AI理解自然语言需求描述
✅ **自动API设计** - 生成RESTful API结构
✅ **代码生成** - 生成Flask完整项目代码
✅ **文档生成** - 生成OpenAPI/Swagger文档
✅ **测试生成** - 自动生成pytest测试用例
✅ **多数据库支持** - 支持SQLite/MySQL/PostgreSQL
✅ **认证集成** - 可选的JWT认证
✅ **一键下载** - 打包下载完整项目
这个系统可以极大提升API开发效率,适合快速原型开发和标准化API项目!