利用python脚本对文件夹中的大量文件划分训练集train、验证集val和测试集test。source_dir为源文件夹,source_dir目录中可以包含不同种类的文件夹。
import os
import shutil
import random
from pathlib import Pathdef split_dataset(source_dir, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1):"""将数据集按照指定比例分割为训练集、验证集和测试集参数:source_dir: 原始数据集目录train_ratio: 训练集比例val_ratio: 验证集比例test_ratio: 测试集比例"""# 确保比例之和为1assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "比例之和必须为1"# 创建目标目录train_dir = os.path.join(os.path.dirname(source_dir), "train")val_dir = os.path.join(os.path.dirname(source_dir), "val")test_dir = os.path.join(os.path.dirname(source_dir), "test")for dir_path in [train_dir, val_dir, test_dir]:if not os.path.exists(dir_path):os.makedirs(dir_path)print(f"创建目录: {dir_path}")# 遍历源目录中的所有文件和子目录for root, dirs, files in os.walk(source_dir):# 跳过空目录if not files:continue# 为当前目录在目标目录中创建相应的子目录结构relative_path = os.path.relpath(root, source_dir)for dir_path in [train_dir, val_dir, test_dir]:target_dir = os.path.join(dir_path, relative_path)if not os.path.exists(target_dir):os.makedirs(target_dir)# 随机打乱文件顺序random.shuffle(files)total_files = len(files)# 计算各集合的文件数量train_count = int(total_files * train_ratio)val_count = int(total_files * val_ratio)# 测试集数量 = 剩余的文件test_count = total_files - train_count - val_count# 分配文件到各个集合train_files = files[:train_count]val_files = files[train_count:train_count + val_count]test_files = files[train_count + val_count:]# 复制文件到相应的目录for file in train_files:src = os.path.join(root, file)dst = os.path.join(train_dir, relative_path, file)shutil.copy2(src, dst)for file in val_files:src = os.path.join(root, file)dst = os.path.join(val_dir, relative_path, file)shutil.copy2(src, dst)for file in test_files:src = os.path.join(root, file)dst = os.path.join(test_dir, relative_path, file)shutil.copy2(src, dst)print(f"处理目录: {relative_path}")print(f" 训练集: {len(train_files)} 个文件")print(f" 验证集: {len(val_files)} 个文件")print(f" 测试集: {len(test_files)} 个文件")print("数据集分割完成!")if __name__ == "__main__":# 设置源数据集目录# 请将此处替换为你的原始数据集目录source_directory = input("请输入原始数据集目录路径: ").strip()# 检查源目录是否存在if not os.path.isdir(source_directory):print(f"错误: 目录 '{source_directory}' 不存在!")else:# 按7:2:1的比例分割数据集split_dataset(source_directory, 0.7, 0.2, 0.1)