import logging
import math
import re
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset

from datamaestro import prepare_dataset

logging.basicConfig(level=logging.INFO)


# Classes for reading the data

class FolderText(Dataset):
    """Text classification dataset backed by one sub-folder per class.

    Every ``*.txt`` file found in ``folder / <class name>`` becomes one
    example whose label is the index of that class name in ``classes``.

    Args:
        classes: Class names; each one is also the name of a sub-folder of
            ``folder``. It is iterated twice, so it must be re-iterable.
        folder: Root directory containing the class sub-folders.
        tokenizer: Callable applied to the raw text of a file when an item
            is accessed.
        load: If true, file contents are read eagerly at construction time;
            otherwise only the paths are stored and files are read on access.
    """

    def __init__(self, classes, folder: Path, tokenizer, load=False):
        self.tokenizer = tokenizer
        self.files = []
        self.filelabels = []
        # Class name -> integer label, in the order given by `classes`.
        self.labels = {key: ix for ix, key in enumerate(classes)}
        for label in classes:
            for file in (folder / label).glob("*.txt"):
                # Store either the text itself or the path, depending on `load`.
                self.files.append(file.read_text() if load else file)
                self.filelabels.append(self.labels[label])

    def __len__(self):
        """Return the number of examples."""
        return len(self.filelabels)

    def __getitem__(self, ix):
        """Return ``(tokenizer(text), label)`` for the example at index ``ix``."""
        s = self.files[ix]
        # Entries are raw strings when loaded eagerly, paths otherwise.
        text = s if isinstance(s, str) else s.read_text()
        return self.tokenizer(text), self.filelabels[ix]


def get_imdb_data(embedding_size=50):
    """Load GloVe embeddings and the IMDB dataset through datamaestro.

    An extra ``"__OOV__"`` entry is appended to the vocabulary, with an
    all-zero embedding row, and is used for every token missing from it.

    Args:
        embedding_size: Embedding dimension; selects the datamaestro dataset
            ``edu.stanford.glove.6b.<embedding_size>``.

    Returns:
        A tuple ``(word2id, embeddings, train, test)`` where ``word2id`` maps
        words to row indices of ``embeddings`` and ``train`` / ``test`` are
        ``FolderText`` datasets that read files lazily.
    """
    WORDS = re.compile(r"\S+")

    # Log before the (potentially slow) load so the message reflects what is
    # actually in progress.
    logging.info("Loading embeddings")
    word2id, embeddings = prepare_dataset('edu.stanford.glove.6b.%d' % embedding_size).load()
    OOVID = len(word2id)
    word2id["__OOV__"] = OOVID
    # Zero vector for the out-of-vocabulary id, appended as the last row.
    embeddings = np.vstack((embeddings, np.zeros(embedding_size)))

    def tokenizer(t):
        """Lower-case ``t``, split it on whitespace and map tokens to ids."""
        return [word2id.get(x, OOVID) for x in WORDS.findall(t.lower())]

    logging.info("Get the IMDB dataset")
    ds = prepare_dataset("edu.stanford.aclimdb")

    train = FolderText(ds.train.classes, ds.train.path, tokenizer, load=False)
    test = FolderText(ds.test.classes, ds.test.path, tokenizer, load=False)
    return word2id, embeddings, train, test


# Exercise 2

class PositionalEncoding(nn.Module):
    """Position embeddings.

    Precomputes a table of shape ``(1, max_len, d_model)`` whose even columns
    hold sines and whose odd columns hold cosines of the position scaled by
    geometrically decreasing frequencies, and adds it to the input.

    Args:
        d_model: Size of the last dimension of the inputs. Both even and odd
            values are supported.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model, dtype=torch.float)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * -(math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so
        # trim the angle table to avoid a shape mismatch; for even d_model
        # this slice keeps every column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Registered as a buffer (not a parameter): it follows the module
        # across devices and is saved in the state dict, but is not trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the positional encodings to ``x``.

        Args:
            x: Tensor whose dimension 1 indexes positions and whose shape
                broadcasts against ``(1, x.size(1), d_model)``; expected to
                be ``(batch, length, d_model)`` with ``length <= max_len``.

        Returns:
            ``x`` plus the first ``x.size(1)`` rows of the encoding table.
        """
        x = x + self.pe[:, :x.size(1)]
        return x