Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Parse] Allow #if to guard switch case clauses #9457

Merged
merged 2 commits into from Jun 28, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Next Next commit
[Parse] Allow #if to guard switch case clauses
  • Loading branch information
rintaro committed Jun 17, 2017
commit 5d478bdb3b7638f5df6f0e1f4e574bececae9b80
26 changes: 22 additions & 4 deletions include/swift/AST/Stmt.h
Expand Up @@ -932,7 +932,7 @@ class CaseStmt final : public Stmt,

/// Switch statement.
class SwitchStmt final : public LabeledStmt,
private llvm::TrailingObjects<SwitchStmt, CaseStmt *> {
private llvm::TrailingObjects<SwitchStmt, ASTNode> {
friend TrailingObjects;

SourceLoc SwitchLoc, LBraceLoc, RBraceLoc;
Expand All @@ -953,7 +953,7 @@ class SwitchStmt final : public LabeledStmt,
static SwitchStmt *create(LabeledStmtInfo LabelInfo, SourceLoc SwitchLoc,
Expr *SubjectExpr,
SourceLoc LBraceLoc,
ArrayRef<CaseStmt*> Cases,
ArrayRef<ASTNode> Cases,
SourceLoc RBraceLoc,
ASTContext &C);

Expand All @@ -972,10 +972,28 @@ class SwitchStmt final : public LabeledStmt,
/// Get the subject expression of the switch.
Expr *getSubjectExpr() const { return SubjectExpr; }
void setSubjectExpr(Expr *e) { SubjectExpr = e; }

ArrayRef<ASTNode> getRawCases() const {
return {getTrailingObjects<ASTNode>(), CaseCount};
}

private:
struct AsCaseStmtWithSkippingIfConfig {
AsCaseStmtWithSkippingIfConfig() {}
Optional<CaseStmt*> operator()(const ASTNode &N) const {
if (auto *CS = llvm::dyn_cast_or_null<CaseStmt>(N.dyn_cast<Stmt*>()))
return CS;
return None;
}
};

public:
using AsCaseStmtRange = OptionalTransformRange<ArrayRef<ASTNode>,
AsCaseStmtWithSkippingIfConfig>;

/// Get the list of case clauses.
ArrayRef<CaseStmt*> getCases() const {
return {getTrailingObjects<CaseStmt*>(), CaseCount};
AsCaseStmtRange getCases() const {
return AsCaseStmtRange(getRawCases(), AsCaseStmtWithSkippingIfConfig());
}

static bool classof(const Stmt *S) {
Expand Down
5 changes: 4 additions & 1 deletion include/swift/Parse/Parser.h
Expand Up @@ -1269,6 +1269,8 @@ class Parser {
// Statement Parsing

bool isStartOfStmt();
bool isTerminatorForBraceItemListKind(BraceItemListKind Kind,
ArrayRef<ASTNode> ParsedDecls);
ParserResult<Stmt> parseStmt();
ParserStatus parseExprOrStmt(ASTNode &Result);
ParserResult<Stmt> parseStmtBreak();
Expand All @@ -1291,7 +1293,8 @@ class Parser {
ParserResult<Stmt> parseStmtForEach(SourceLoc ForLoc,
LabeledStmtInfo LabelInfo);
ParserResult<Stmt> parseStmtSwitch(LabeledStmtInfo LabelInfo);
ParserResult<CaseStmt> parseStmtCase();
ParserStatus parseStmtCases(SmallVectorImpl<ASTNode> &cases, bool IsActive);
ParserResult<CaseStmt> parseStmtCase(bool IsActive);

//===--------------------------------------------------------------------===//
// Generics Parsing
Expand Down
7 changes: 5 additions & 2 deletions lib/AST/ASTDumper.cpp
Expand Up @@ -1522,9 +1522,12 @@ class PrintStmt : public StmtVisitor<PrintStmt> {
void visitSwitchStmt(SwitchStmt *S) {
printCommon(S, "switch_stmt") << '\n';
printRec(S->getSubjectExpr());
for (CaseStmt *C : S->getCases()) {
for (auto N : S->getRawCases()) {
OS << '\n';
printRec(C);
if (N.is<Stmt*>())
printRec(N.get<Stmt*>());
else
printRec(N.get<Decl*>());
}
PrintWithColorRAII(OS, ParenthesisColor) << ')';
}
Expand Down
7 changes: 5 additions & 2 deletions lib/AST/ASTPrinter.cpp
Expand Up @@ -3359,8 +3359,11 @@ void PrintAST::visitSwitchStmt(SwitchStmt *stmt) {
// FIXME: print subject
Printer << "{";
Printer.printNewline();
for (CaseStmt *C : stmt->getCases()) {
visit(C);
for (auto N : stmt->getRawCases()) {
if (N.is<Stmt*>())
visit(cast<CaseStmt>(N.get<Stmt*>()));
else
visit(cast<IfConfigDecl>(N.get<Decl*>()));
}
Printer.printNewline();
indent();
Expand Down
19 changes: 13 additions & 6 deletions lib/AST/ASTWalker.cpp
Expand Up @@ -1452,12 +1452,19 @@ Stmt *Traversal::visitSwitchStmt(SwitchStmt *S) {
else
return nullptr;

for (CaseStmt *aCase : S->getCases()) {
if (Stmt *aStmt = doIt(aCase)) {
assert(aCase == aStmt && "switch case remap not supported");
(void)aStmt;
} else
return nullptr;
for (auto N : S->getRawCases()) {
if (Stmt *aCase = N.dyn_cast<Stmt*>()) {
assert(isa<CaseStmt>(aCase));
if (Stmt *aStmt = doIt(aCase)) {
assert(aCase == aStmt && "switch case remap not supported");
(void)aStmt;
} else
return nullptr;
} else {
assert(isa<IfConfigDecl>(N.get<Decl*>()));
if (doIt(N.get<Decl*>()))
return nullptr;
}
}

return S;
Expand Down
13 changes: 10 additions & 3 deletions lib/AST/Stmt.cpp
Expand Up @@ -412,15 +412,22 @@ CaseStmt *CaseStmt::create(ASTContext &C, SourceLoc CaseLoc,
SwitchStmt *SwitchStmt::create(LabeledStmtInfo LabelInfo, SourceLoc SwitchLoc,
Expr *SubjectExpr,
SourceLoc LBraceLoc,
ArrayRef<CaseStmt *> Cases,
ArrayRef<ASTNode> Cases,
SourceLoc RBraceLoc,
ASTContext &C) {
void *p = C.Allocate(totalSizeToAlloc<CaseStmt *>(Cases.size()),
#ifndef NDEBUG
for (auto N : Cases)
assert((N.is<Stmt*>() && isa<CaseStmt>(N.get<Stmt*>())) ||
(N.is<Decl*>() && isa<IfConfigDecl>(N.get<Decl*>())));
#endif

void *p = C.Allocate(totalSizeToAlloc<ASTNode>(Cases.size()),
alignof(SwitchStmt));
SwitchStmt *theSwitch = ::new (p) SwitchStmt(LabelInfo, SwitchLoc,
SubjectExpr, LBraceLoc,
Cases.size(), RBraceLoc);

std::uninitialized_copy(Cases.begin(), Cases.end(),
theSwitch->getTrailingObjects<CaseStmt *>());
theSwitch->getTrailingObjects<ASTNode>());
return theSwitch;
}
112 changes: 76 additions & 36 deletions lib/Parse/ParseStmt.cpp
Expand Up @@ -130,14 +130,25 @@ ParserStatus Parser::parseExprOrStmt(ASTNode &Result) {
return ResultExpr;
}

static bool isTerminatorForBraceItemListKind(const Token &Tok,
BraceItemListKind Kind,
ArrayRef<ASTNode> ParsedDecls) {
bool Parser::isTerminatorForBraceItemListKind(BraceItemListKind Kind,
ArrayRef<ASTNode> ParsedDecls) {
switch (Kind) {
case BraceItemListKind::Brace:
return false;
case BraceItemListKind::Case:
return Tok.is(tok::kw_case) || Tok.is(tok::kw_default);
if (Tok.is(tok::pound_if)) {
// '#if' here could be to guard 'case:' or statements in cases.
// If the next non-directive line starts with 'case' or 'default', it is
// for 'case's.
Parser::BacktrackingScope Backtrack(*this);
do {
consumeToken();
while (!Tok.isAtStartOfLine() && Tok.isNot(tok::eof))
skipSingle();
} while (Tok.isAny(tok::pound_if, tok::pound_elseif, tok::pound_else));
return Tok.isAny(tok::kw_case, tok::kw_default);
}
return Tok.isAny(tok::kw_case, tok::kw_default);
case BraceItemListKind::TopLevelCode:
// When parsing the top level executable code for a module, if we parsed
// some executable code, then we're done. We want to process (name bind,
Expand Down Expand Up @@ -247,7 +258,7 @@ ParserStatus Parser::parseBraceItems(SmallVectorImpl<ASTNode> &Entries,
Tok.isNot(tok::kw_sil_witness_table) &&
Tok.isNot(tok::kw_sil_default_witness_table) &&
(isConditionalBlock ||
!isTerminatorForBraceItemListKind(Tok, Kind, Entries))) {
!isTerminatorForBraceItemListKind(Kind, Entries))) {
if (Kind == BraceItemListKind::TopLevelLibrary &&
skipExtraTopLevelRBraces())
continue;
Expand Down Expand Up @@ -2151,36 +2162,20 @@ ParserResult<Stmt> Parser::parseStmtSwitch(LabeledStmtInfo LabelInfo) {
SourceLoc lBraceLoc = consumeToken(tok::l_brace);
SourceLoc rBraceLoc;

// If there are non-case-label statements at the start of the switch body,
// raise an error and recover by discarding them.
bool DiagnosedNotCoveredStmt = false;
while (!Tok.is(tok::kw_case) && !Tok.is(tok::kw_default)
&& !Tok.is(tok::r_brace) && !Tok.is(tok::eof)) {
if (!DiagnosedNotCoveredStmt) {
diagnose(Tok, diag::stmt_in_switch_not_covered_by_case);
DiagnosedNotCoveredStmt = true;
}
skipSingle();
}

SmallVector<CaseStmt*, 8> cases;
bool parsedDefault = false;
bool parsedBlockAfterDefault = false;
while (Tok.is(tok::kw_case) || Tok.is(tok::kw_default)) {
// We cannot have additional cases after a default clause. Complain on
// the first offender.
if (parsedDefault && !parsedBlockAfterDefault) {
parsedBlockAfterDefault = true;
diagnose(Tok, diag::case_after_default);
}

ParserResult<CaseStmt> Case = parseStmtCase();
Status |= Case;
if (Case.isNonNull()) {
cases.push_back(Case.get());
if (Case.get()->isDefault())
parsedDefault = true;
SmallVector<ASTNode, 8> cases;
Status |= parseStmtCases(cases, /*IsActive=*/true);

// We cannot have additional cases after a default clause. Complain on
// the first offender.
bool hasDefault = false;
for (auto Element : cases) {
if (!Element.is<Stmt*>()) continue;
auto *CS = cast<CaseStmt>(Element.get<Stmt*>());
if (hasDefault) {
diagnose(CS->getLoc(), diag::case_after_default);
break;
}
hasDefault |= CS->isDefault();
}

if (parseMatchingToken(tok::r_brace, rBraceLoc,
Expand All @@ -2193,6 +2188,51 @@ ParserResult<Stmt> Parser::parseStmtSwitch(LabeledStmtInfo LabelInfo) {
lBraceLoc, cases, rBraceLoc, Context));
}

ParserStatus
Parser::parseStmtCases(SmallVectorImpl<ASTNode> &cases, bool IsActive) {
ParserStatus Status;
while (Tok.isNot(tok::r_brace, tok::eof,
tok::pound_endif, tok::pound_elseif, tok::pound_else)) {
if (Tok.isAny(tok::kw_case, tok::kw_default)) {
ParserResult<CaseStmt> Case = parseStmtCase(IsActive);
Status |= Case;
if (Case.isNonNull())
cases.emplace_back(Case.get());
} else if (Tok.is(tok::pound_if)) {
// '#if' in 'case' position can enclose one or more 'case' or 'default'
// clauses.
auto IfConfigResult = parseIfConfig(
[&](SmallVectorImpl<ASTNode> &Elements, bool IsActive) {
parseStmtCases(Elements, IsActive);
});
Status |= IfConfigResult;
if (auto ICD = IfConfigResult.getPtrOrNull()) {
cases.emplace_back(ICD);

for (auto &Entry : ICD->getActiveClauseElements()) {
if (Entry.is<Decl*>() && isa<IfConfigDecl>(Entry.get<Decl*>()))
// Don't hoist nested '#if'.
continue;

assert(Entry.is<Stmt*>() && isa<CaseStmt>(Entry.get<Stmt*>()));
cases.push_back(Entry);
}
}
} else {
// If there are non-case-label statements at the start of the switch body,
// raise an error and recover by discarding them.
diagnose(Tok, diag::stmt_in_switch_not_covered_by_case);

while (Tok.isNot(tok::r_brace, tok::eof, tok::pound_elseif,
tok::pound_else, tok::pound_endif) &&
!isTerminatorForBraceItemListKind(BraceItemListKind::Case, {})) {
skipSingle();
}
}
}
return Status;
}

static ParserStatus parseStmtCase(Parser &P, SourceLoc &CaseLoc,
SmallVectorImpl<CaseLabelItem> &LabelItems,
SmallVectorImpl<VarDecl *> &BoundDecls,
Expand Down Expand Up @@ -2257,9 +2297,9 @@ parseStmtCaseDefault(Parser &P, SourceLoc &CaseLoc,
return Status;
}

ParserResult<CaseStmt> Parser::parseStmtCase() {
ParserResult<CaseStmt> Parser::parseStmtCase(bool IsActive) {
// A case block has its own scope for variables bound out of the pattern.
Scope S(this, ScopeKind::CaseVars);
Scope S(this, ScopeKind::CaseVars, !IsActive);

ParserStatus Status;

Expand Down
2 changes: 1 addition & 1 deletion lib/SILGen/SILGenPattern.cpp
Expand Up @@ -2616,7 +2616,7 @@ void SILGenFunction::emitSwitchStmt(SwitchStmt *S) {
// We use std::vector because it supports emplace_back; moving a ClauseRow is
// expensive.
std::vector<ClauseRow> clauseRows;
clauseRows.reserve(S->getCases().size());
clauseRows.reserve(S->getRawCases().size());
bool hasFallthrough = false;
for (auto caseBlock : S->getCases()) {
for (auto &labelItem : caseBlock->getCaseLabelItems()) {
Expand Down
4 changes: 2 additions & 2 deletions lib/Sema/DerivedConformanceCodingKey.cpp
Expand Up @@ -271,7 +271,7 @@ deriveBodyCodingKey_enum_stringValue(AbstractFunctionDecl *strValDecl) {
body = BraceStmt::create(C, SourceLoc(), ASTNode(returnStmt),
SourceLoc());
} else {
SmallVector<CaseStmt *, 4> cases;
SmallVector<ASTNode, 4> cases;
for (auto *elt : elements) {
auto *pat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType),
SourceLoc(), SourceLoc(),
Expand Down Expand Up @@ -336,7 +336,7 @@ deriveBodyCodingKey_init_stringValue(AbstractFunctionDecl *initDecl) {
}

auto *selfRef = createSelfDeclRef(initDecl);
SmallVector<CaseStmt *, 4> cases;
SmallVector<ASTNode, 4> cases;
for (auto *elt : elements) {
auto *litExpr = new (C) StringLiteralExpr(elt->getNameStr(), SourceRange(),
/*Implicit=*/true);
Expand Down
2 changes: 1 addition & 1 deletion lib/Sema/DerivedConformanceEquatableHashable.cpp
Expand Up @@ -88,7 +88,7 @@ static DeclRefExpr *convertEnumToIndex(SmallVectorImpl<ASTNode> &stmts,
indexPat, nullptr, funcDecl);

unsigned index = 0;
SmallVector<CaseStmt*, 4> cases;
SmallVector<ASTNode, 4> cases;
for (auto elt : enumDecl->getAllElements()) {
// generate: case .<Case>:
auto pat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType),
Expand Down
4 changes: 2 additions & 2 deletions lib/Sema/DerivedConformanceRawRepresentable.cpp
Expand Up @@ -93,7 +93,7 @@ static void deriveBodyRawRepresentable_raw(AbstractFunctionDecl *toRawDecl) {

Type enumType = parentDC->getDeclaredTypeInContext();

SmallVector<CaseStmt*, 4> cases;
SmallVector<ASTNode, 4> cases;
for (auto elt : enumDecl->getAllElements()) {
auto pat = new (C) EnumElementPattern(TypeLoc::withoutLoc(enumType),
SourceLoc(), SourceLoc(),
Expand Down Expand Up @@ -198,7 +198,7 @@ deriveBodyRawRepresentable_init(AbstractFunctionDecl *initDecl) {

auto selfDecl = cast<ConstructorDecl>(initDecl)->getImplicitSelfDecl();

SmallVector<CaseStmt*, 4> cases;
SmallVector<ASTNode, 4> cases;
for (auto elt : enumDecl->getAllElements()) {
auto litExpr = cloneRawLiteralExpr(C, elt->getRawValueExpr());
auto litPat = new (C) ExprPattern(litExpr, /*isResolved*/ true,
Expand Down
7 changes: 4 additions & 3 deletions lib/Sema/TypeCheckStmt.cpp
Expand Up @@ -863,11 +863,12 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
AddSwitchNest switchNest(*this);
AddLabeledStmt labelNest(*this, S);

for (unsigned i = 0, e = S->getCases().size(); i < e; ++i) {
auto *caseBlock = S->getCases()[i];
auto cases = S->getCases();
for (auto i = cases.begin(), e = cases.end(); i != e; ++i) {
auto *caseBlock = *i;
// Fallthrough transfers control to the next case block. In the
// final case block, it is invalid.
FallthroughDest = i+1 == e ? nullptr : S->getCases()[i+1];
FallthroughDest = std::next(i) == e ? nullptr : *std::next(i);

for (auto &labelItem : caseBlock->getMutableCaseLabelItems()) {
// Resolve the pattern in the label.
Expand Down
3 changes: 1 addition & 2 deletions lib/Sema/TypeCheckSwitchStmt.cpp
Expand Up @@ -917,8 +917,7 @@ namespace {
bool sawDowngradablePattern = false;
bool sawRedundantPattern = false;
SmallVector<Space, 4> spaces;
for (unsigned i = 0, e = Switch->getCases().size(); i < e; ++i) {
auto *caseBlock = Switch->getCases()[i];
for (auto *caseBlock : Switch->getCases()) {
for (auto &caseItem : caseBlock->getCaseLabelItems()) {
// 'where'-clauses on cases mean the case does not contribute to
// the exhaustiveness of the pattern.
Expand Down