PyTorch Data Loading and Data Preprocessing

Published: 2019-12-31 17:27:07 Source: Internet Category: python

This article describes how data is loaded and preprocessed in PyTorch; it should serve as a useful reference.

Data loading falls into two cases: loading a dataset that ships with torchvision.datasets, and loading a dataset of your own.

Datasets in torchvision.datasets

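torchvision.datasets ships with MNIST, Imagenet-12, CIFAR and other datasets. Every one of them is a subclass of torch.utils.data.Dataset and implements two methods: __len__ (returns the size of the dataset) and __getitem__ (returns the item at a given index).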

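In the source of the Dataset base class, these two methods are declared but left unimplemented; every concrete Dataset class inherits from it and implements them to suit its own data.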

The same approach therefore works for your own data: subclass torch.utils.data.Dataset and override three methods, __init__, __len__ and __getitem__. A class written this way can be passed directly to torch.utils.data.DataLoader.
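
As a minimal sketch of that recipe, a dataset that overrides the three methods and is handed to a DataLoader might look like the following. The class name SquaresDataset and its toy data are hypothetical, chosen only for illustration.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i * i)."""

    def __init__(self, n):
        # Store whatever is needed to produce samples later.
        self.n = n

    def __len__(self):
        # Number of samples in the dataset.
        return self.n

    def __getitem__(self, index):
        # Return one (sample, target) pair for the given index.
        x = torch.tensor([float(index)])
        y = index * index
        return x, y


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

num_samples = len(dataset)     # calls SquaresDataset.__len__
num_batches = len(loader)      # number of batches: ceil(10 / 4)
first_x, first_y = dataset[3]  # calls SquaresDataset.__getitem__
```

Nothing else is required of the class: DataLoader only needs the length and indexed access that these methods provide.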

Taking CIFAR10 as an example, its documented interface is:

class torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
root (string) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.
train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.
transform (callable, optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
download (bool, optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.

Loading your own dataset

torchvision.datasets provides two classes for this purpose, DatasetFolder and ImageFolder; ImageFolder inherits from DatasetFolder.

The source of the folder module shows what DatasetFolder and ImageFolder each do:

import torch.utils.data as data
from PIL import Image
import os
import os.path


def has_file_allowed_extension(filename, extensions):  # check whether the file has one of the allowed extensions
  """Checks if a file is an allowed extension.

  Args:
    filename (string): path to a file

  Returns:
    bool: True if the filename ends with a known image extension
  """
  filename_lower = filename.lower()
  return any(filename_lower.endswith(ext) for ext in extensions)


def find_classes(dir):
  classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]  # names of all subdirectories under root

  classes.sort()
  class_to_idx = {classes[i]: i for i in range(len(classes))}  # map each class name to its integer class id
  return classes, class_to_idx


def make_dataset(dir, class_to_idx, extensions):
  images = []
  dir = os.path.expanduser(dir)  # expand a leading ~ or ~user into the user's home directory
  for target in sorted(os.listdir(dir)):
    d = os.path.join(dir, target)
    if not os.path.isdir(d):
      continue

    for root, _, fnames in sorted(os.walk(d)):  # os.walk yields 3-tuples: root is the current directory path, _ is the list of its subdirectory names, fnames is the list of its file names
      for fname in sorted(fnames):
        if has_file_allowed_extension(fname, extensions):
          path = os.path.join(root, fname)
          item = (path, class_to_idx[target])
          images.append(item)  # a (sample image path, class id) tuple

  return images  # the list of those tuples


class DatasetFolder(data.Dataset):
  """A generic data loader where the samples are arranged in this way: ::

    root/class_x/xxx.ext
    root/class_x/xxy.ext
    root/class_x/xxz.ext

    root/class_y/123.ext
    root/class_y/nsdf3.ext
    root/class_y/asd932_.ext

  Args:
    root (string): Root directory path.
    loader (callable): A function to load a sample given its path.
    extensions (list[string]): A list of allowed extensions.
    transform (callable, optional): A function/transform that takes in
      a sample and returns a transformed version.
      E.g, ``transforms.RandomCrop`` for images.
    target_transform (callable, optional): A function/transform that takes
      in the target and transforms it.

   Attributes:
    classes (list): List of the class names.
    class_to_idx (dict): Dict with items (class_name, class_index).
    samples (list): List of (sample path, class_index) tuples
  """

  def __init__(self, root, loader, extensions, transform=None, target_transform=None):
    classes, class_to_idx = find_classes(root)
    samples = make_dataset(root, class_to_idx, extensions)
    if len(samples) == 0:
      raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
                "Supported extensions are: " + ",".join(extensions)))

    self.root = root
    self.loader = loader
    self.extensions = extensions

    self.classes = classes
    self.class_to_idx = class_to_idx
    self.samples = samples

    self.transform = transform
    self.target_transform = target_transform

  def __getitem__(self, index):
    """
    Fetch the sample at index and return a (sample, target) tuple. If transform
    and target_transform (typically torchvision.transforms objects) were passed
    to the constructor, they are applied to the sample and the target respectively.

    Args:
      index (int): Index

    Returns:
      tuple: (sample, target) where target is class_index of the target class.
    """
    path, target = self.samples[index]
    sample = self.loader(path)
    if self.transform is not None:
      sample = self.transform(sample)
    if self.target_transform is not None:
      target = self.target_transform(target)

    return sample, target


  def __len__(self):
    return len(self.samples)

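  # __repr__ defines how the object is displayed. __str__ only affects print() and str(), whereas
  # __repr__ is used when the object is echoed directly and is also the fallback for print() when __str__ is not defined.
  def __repr__(self):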
    fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
    fmt_str += '  Number of datapoints: {}\n'.format(self.__len__())
    fmt_str += '  Root Location: {}\n'.format(self.root)
    tmp = '  Transforms (if any): '
    fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
    tmp = '  Target Transforms (if any): '
    fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
    return fmt_str



IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']


def pil_loader(path):
  # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
  with open(path, 'rb') as f:
    img = Image.open(f)
    return img.convert('RGB')


def accimage_loader(path):
  import accimage
  try:
    return accimage.Image(path)
  except IOError:
    # Potentially a decoding problem, fall back to PIL.Image
    return pil_loader(path)


def default_loader(path):
  from torchvision import get_image_backend
  if get_image_backend() == 'accimage':
    return accimage_loader(path)
  else:
    return pil_loader(path)


class ImageFolder(DatasetFolder): 
  """A generic data loader where the images are arranged in this way: ::

    root/dog/xxx.png
    root/dog/xxy.png
    root/dog/xxz.png

    root/cat/123.png
    root/cat/nsdf3.png
    root/cat/asd932_.png

  Args:
    root (string): Root directory path.
    transform (callable, optional): A function/transform that takes in an PIL image
      and returns a transformed version. E.g, ``transforms.RandomCrop``
    target_transform (callable, optional): A function/transform that takes in the
      target and transforms it.
    loader (callable, optional): A function to load an image given its path.

   Attributes:
    classes (list): List of the class names.
    class_to_idx (dict): Dict with items (class_name, class_index).
    imgs (list): List of (image path, class_index) tuples
  """
  def __init__(self, root, transform=None, target_transform=None,
         loader=default_loader):
    super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
                     transform=transform,
                     target_transform=target_transform)
    self.imgs = self.samples
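
The comment on __repr__ in the listing above contrasts it with __str__. A small torch-free sketch shows the difference in behaviour; the classes OnlyRepr and Both are hypothetical and exist only for this illustration.

```python
class OnlyRepr:
    def __repr__(self):
        return "OnlyRepr()"


class Both:
    def __repr__(self):
        return "Both(repr)"

    def __str__(self):
        return "Both(str)"


a, b = OnlyRepr(), Both()

# str() and print() fall back to __repr__ when __str__ is not defined.
a_str, a_repr = str(a), repr(a)
# When both exist, str()/print() use __str__ while repr() and the
# interactive echo use __repr__.
b_str, b_repr = str(b), repr(b)
# Containers display their elements with __repr__.
in_list = str([b])
```

DatasetFolder defines only __repr__, so both print(dataset) and echoing the object in an interactive session show the same summary.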

Suppose the data you want to load is organized as follows:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

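That is, the training data of each class is stored in its own folder, and all of these folders sit under root (a path such as D:/animals or /usr/animals). This is the layout that ImageFolder expects: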

class torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>)


Its parameters are:

root (string) – Root directory path.
transform (callable, optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
loader – A function to load an image given its path (this is the loader defined in the source above).


__getitem__(index)
Parameters: index (int) – Index
Returns:  (sample, target) where target is class_index of the target class.
Return type:  tuple

Such data can be loaded with torchvision.datasets.ImageFolder:

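import torch
import torchvision
from torchvision import transforms

img_data = torchvision.datasets.ImageFolder('D:/bnu/database/flower',
                      transform=transforms.Compose([
                        transforms.Resize(256),  # called transforms.Scale in old torchvision releases, since deprecated in favor of Resize
                        transforms.CenterCrop(224),
                        transforms.ToTensor()])
                      )
print(len(img_data))  # number of images
data_loader = torch.utils.data.DataLoader(img_data, batch_size=20, shuffle=True)
print(len(data_loader))  # number of batches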
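
To see which labels ImageFolder would assign to such a layout, the directory scan done by find_classes and make_dataset can be reproduced with the standard library alone. The sketch below is a simplified re-implementation, not the torchvision code: the helper name scan_folder, the temporary directory and the empty placeholder files are all assumptions made for illustration.

```python
import os
import tempfile


def scan_folder(root, extensions=(".png", ".jpg")):
    """Simplified stand-in for find_classes + make_dataset."""
    classes = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for dirpath, _, fnames in sorted(os.walk(class_dir)):
            for fname in sorted(fnames):
                if fname.lower().endswith(extensions):
                    samples.append((os.path.join(dirpath, fname), class_to_idx[name]))
    return classes, class_to_idx, samples


with tempfile.TemporaryDirectory() as root:
    # Build root/dog/... and root/cat/... with empty placeholder files.
    layout = {"dog": ["xxx.png", "xxy.png"], "cat": ["123.png", "notes.txt"]}
    for cls, files in layout.items():
        os.makedirs(os.path.join(root, cls))
        for fname in files:
            open(os.path.join(root, cls, fname), "w").close()

    classes, class_to_idx, samples = scan_folder(root)
    labels = [(os.path.basename(path), idx) for path, idx in samples]

print(classes)       # class names in sorted order
print(class_to_idx)  # ids follow the sorted order of the folder names
print(labels)        # notes.txt is skipped because of its extension
```

Because class ids follow the sorted folder names, renaming or adding a class folder can change the id of every class that sorts after it.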

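If instead all training samples sit in a single folder, accompanied by a txt file in which each line gives the path of an image and the class it belongs to, a loading class can be written along the lines of the classes above: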

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


def default_loader(path):
  return Image.open(path).convert('RGB')


class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    imgs = []
    with open(txt, 'r') as fh:  # the file is closed once it has been read
      for line in fh:
        words = line.split()  # split() also discards the trailing newline
        if not words:  # skip blank lines
          continue
        imgs.append((words[0], int(words[1])))  # (image path, class id)
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    if self.transform is not None:
      img = self.transform(img)
    if self.target_transform is not None:
      label = self.target_transform(label)
    return img, label

  def __len__(self):
    return len(self.imgs)

train_data = MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor())
data_loader = DataLoader(train_data, batch_size=100, shuffle=True)
print(len(data_loader))
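
The txt file that MyDataset reads is assumed to hold one sample per line: an image path followed by an integer class id, separated by whitespace. The following torch-free sketch writes such a file and parses it in the same spirit as __init__ above. The file contents and the helper name read_annotations are made up for illustration.

```python
import os
import tempfile


def read_annotations(txt_path):
    """Parse 'path label' lines into (path, int(label)) tuples."""
    items = []
    with open(txt_path, "r") as fh:
        for line in fh:
            words = line.split()
            if not words:  # tolerate blank lines
                continue
            items.append((words[0], int(words[1])))
    return items


# Hypothetical annotation file: <image path> <class id>
content = "images/0001.png 7\nimages/0002.png 2\n\nimages/0003.png 1\n"

with tempfile.TemporaryDirectory() as tmp:
    txt_path = os.path.join(tmp, "annotations.txt")
    with open(txt_path, "w") as fh:
        fh.write(content)
    items = read_annotations(txt_path)

print(items)
```

Note that splitting on whitespace means image paths containing spaces would need a different separator.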

Anatomy of DataLoader

The class is torch.utils.data.DataLoader; see its source code.

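The main purpose of this interface is to take an existing PyTorch data interface, such as torchvision.datasets.ImageFolder, or a custom data-reading interface, and pack its samples into Tensors of batch_size samples each. In effect it adds one leading dimension of size batch_size on top of what the built-in or custom interface returns. The resulting batches can then be fed to the model as input (PyTorch releases before 0.4 required wrapping them in Variable first).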
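
The added batch dimension is easy to observe on synthetic data. The sketch below assumes nothing about a real dataset: the all-zero tensors standing in for images, their shape and the batch size are arbitrary choices for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 fake "images" of shape (3, 8, 8) with integer labels; the data is synthetic.
images = torch.zeros(10, 3, 8, 8)
labels = torch.arange(10)
dataset = TensorDataset(images, labels)

single_image, single_label = dataset[0]
loader = DataLoader(dataset, batch_size=4)
batch_images, batch_labels = next(iter(loader))

print(tuple(single_image.shape))  # shape of one sample
print(tuple(batch_images.shape))  # same shape with a leading batch dimension
print(tuple(batch_labels.shape))  # the labels are batched as well
```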

__init__ takes the following parameters:

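1. dataset: an instance of a subclass of torch.utils.data.Dataset, either a built-in class such as torchvision.datasets.ImageFolder or a custom class that inherits from torch.utils.data.Dataset
2. batch_size: the number of samples in each batch; the default is 1
3. shuffle: normally used for the training set; the default is False, and setting it to True reshuffles the training samples at every epoch
4. sampler: the strategy for drawing samples; mutually exclusive with shuffle, so if shuffle is True this parameter must be None
5. batch_sampler: a BatchSampler yields the indices of one batch at a time; mutually exclusive with sampler, shuffle and batch_size; the default is usually fine
  The source code of the Sampler classes mentioned above is here: [source code](https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py)
6. num_workers: the number of worker subprocesses used for data loading; the default is 0, meaning the data is loaded in the main process
7. collate_fn: merges a list of samples into a mini-batch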
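
How these parameters fit together can be sketched without PyTorch at all. In the simplified model below, a sampler produces an order of indices, a batch sampler groups them into lists of batch_size, and collate_fn merges the fetched samples into one mini-batch. The function names (index_order, batch_indices, simple_collate, iterate) are hypothetical, and the real DataLoader is considerably more involved.

```python
import random


def index_order(n, shuffle, seed=0):
    """Sampler stand-in: the order in which sample indices are visited."""
    order = list(range(n))
    if shuffle:
        random.Random(seed).shuffle(order)
    return order


def batch_indices(order, batch_size):
    """Batch-sampler stand-in: yield one list of indices per batch."""
    for start in range(0, len(order), batch_size):
        yield order[start:start + batch_size]


def simple_collate(samples):
    """collate_fn stand-in: turn [(x, y), ...] into ([x, ...], [y, ...])."""
    xs, ys = zip(*samples)
    return list(xs), list(ys)


def iterate(dataset, batch_size=1, shuffle=False, collate_fn=simple_collate):
    order = index_order(len(dataset), shuffle)
    for indices in batch_indices(order, batch_size):
        yield collate_fn([dataset[i] for i in indices])


# A list of (sample, target) pairs behaves like a tiny indexable dataset.
dataset = [(i * 10, i % 2) for i in range(5)]
batches = list(iterate(dataset, batch_size=2))
print(batches)

shuffled = list(iterate(dataset, batch_size=2, shuffle=True))
```

The sketch also shows why the options exclude one another: shuffle only selects which index order is generated, so supplying a custom sampler or batch sampler leaves nothing for shuffle to decide.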

A DataLoader is typically used like this:

train_data = torch.utils.data.DataLoader(...)
for i, (input, target) in enumerate(train_data):
  ...

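Iterating over a DataLoader invokes its __iter__ method. The source shows that this method simply executes return _DataLoaderIter(self), so the internal iterator class is what needs examining; its __init__ is reproduced below, where the class appears under the name DataLoaderIter.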
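
The mechanism is ordinary Python iteration: a for loop calls __iter__ on the object and then repeatedly calls __next__ on whatever is returned. A torch-free sketch, with the hypothetical classes ToyLoader and ToyLoaderIter standing in for DataLoader and its iterator, makes the division of labour visible.

```python
class ToyLoaderIter:
    """Hypothetical stand-in for the loader's internal iterator class."""

    def __init__(self, loader):
        self.data = loader.data
        self.batch_size = loader.batch_size
        self.pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos >= len(self.data):
            raise StopIteration  # ends the for loop
        batch = self.data[self.pos:self.pos + self.batch_size]
        self.pos += self.batch_size
        return batch


class ToyLoader:
    """Hypothetical stand-in for DataLoader: holds config, no iteration state."""

    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size

    def __iter__(self):
        # A fresh iterator per loop, so every epoch starts from the beginning.
        return ToyLoaderIter(self)


loader = ToyLoader(list(range(5)), batch_size=2)
epoch1 = [batch for batch in loader]
epoch2 = [batch for batch in loader]
print(epoch1)
```

Keeping the iteration state in a separate object is what allows the same loader to be looped over once per epoch.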

class DataLoaderIter(object):
  "Iterates once over the DataLoader's dataset, as specified by the sampler"

  def __init__(self, loader):
    self.dataset = loader.dataset
    self.collate_fn = loader.collate_fn
    self.batch_sampler = loader.batch_sampler
    self.num_workers = loader.num_workers
    self.pin_memory = loader.pin_memory and torch.cuda.is_available()
    self.timeout = loader.timeout
    self.done_event = threading.Event()

    self.sample_iter = iter(self.batch_sampler)

    if self.num_workers > 0:
      self.worker_init_fn = loader.worker_init_fn
      self.index_queue = multiprocessing.SimpleQueue()
      self.worker_result_queue = multiprocessing.SimpleQueue()
      self.batches_outstanding = 0
      self.worker_pids_set = False
      self.shutdown = False
      self.send_idx = 0
      self.rcvd_idx = 0
      self.reorder_dict = {}

      base_seed = torch.LongTensor(1).random_()[0]
      self.workers = [
        multiprocessing.Process(
          target=_worker_loop,
          args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
             base_seed + i, self.worker_init_fn, i))
        for i in range(self.num_workers)]

      if self.pin_memory or self.timeout > 0:
        self.data_queue = queue.Queue()
        self.worker_manager_thread = threading.Thread(
          target=_worker_manager_loop,
          args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
             torch.cuda.current_device()))
        self.worker_manager_thread.daemon = True
        self.worker_manager_thread.start()
      else:
        self.data_queue = self.worker_result_queue

      for w in self.workers:
        w.daemon = True # ensure that the worker exits on process exit
        w.start()

      _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
      _set_SIGCHLD_handler()
      self.worker_pids_set = True

      # prime the prefetch loop
      for _ in range(2 * self.num_workers):
        self._put_indices()
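
The fields send_idx, rcvd_idx and reorder_dict in the excerpt suggest how batches are kept in order even though workers may finish out of order: each request is numbered when it is sent, and a result that arrives early is parked until its turn comes. The single-process simulation below illustrates that idea only. The function in_order and the arrival sequence are invented for the example, and the real iterator also deals with queues, worker processes and shutdown.

```python
def in_order(arrivals):
    """Yield payloads by ascending index, whatever order they arrive in.

    arrivals is an iterable of (idx, payload) pairs with idx covering 0..n-1.
    """
    reorder_dict = {}  # early results, keyed by their index
    rcvd_idx = 0       # index of the next result to hand out
    for idx, payload in arrivals:
        reorder_dict[idx] = payload
        # Release every result that is now next in line.
        while rcvd_idx in reorder_dict:
            yield reorder_dict.pop(rcvd_idx)
            rcvd_idx += 1


# Hypothetical arrival order: batches 1 and 2 finish before batch 0.
arrivals = [(1, "batch-1"), (2, "batch-2"), (0, "batch-0"), (3, "batch-3")]
delivered = list(in_order(arrivals))
print(delivered)
```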


Article title: PyTorch Data Loading and Data Preprocessing
Article URL: http://www.cooldogg.com/jiaoben/python/296566.html
